clshuaige 1 mese fa
parent
commit
c89936e146
1 ha cambiato i file con 10 aggiunte e 26 eliminazioni
  1. 10 26
      yolov5_trt.cpp

+ 10 - 26
yolov5_trt.cpp

@@ -392,8 +392,7 @@ int main()
 
         size_t size{ 0 };
 
-
-
+        // 读取engine文件
         std::ifstream file("yolov5-crowd-n.engine", std::ios::binary);
 
         if (file.good()) {
@@ -433,11 +432,11 @@ int main()
 
 
         // generate input data
-
         float data[BATCH_SIZE * 3 * IN_H * IN_W];
 
         // Run inference
         int num_total = 3 * (IN_H/4 * IN_W/4 + IN_H/8 * IN_W/8 + IN_H/16 * IN_W/16) * 6;
+        // 存储推理结果
         float prob[num_total];
         printf("num_total: %d\n", num_total);
 
@@ -465,36 +464,19 @@ int main()
     float height_scale = height / 640.0;
     float width_scale = width / 640.0;
 
-        doInference(*context, data, prob, BATCH_SIZE);
-        // 推理结果
-        // float *prob_ptr = prob;
-        // int count = 0;
-        // for (int i = 0; i < num_total / 6; i++)
-        // {
-        //     {
-        //         for (int j = 0; j < 6; j++)
-        //         {
-        //             //printf("%f ", prob_ptr[j]);
-        //             count++;
-        //         }
-        //         //printf("\n");
-                
-        //     }
-        //     prob_ptr += 6;
+    doInference(*context, data, prob, BATCH_SIZE);
 
-        // }
-        // printf("count: %d\n", count);
-        printf("inference done");
+    printf("inference done");
 
 
 
-        // Destroy the engine
+    // Destroy the engine
 
-        context->destroy();
+    context->destroy();
 
-        engine->destroy();
+    engine->destroy();
 
-        runtime->destroy();
+    runtime->destroy();
 
     // 将输出根据grid和anchor进行计算,得到预测框的坐标和置信度
     float* result = postprocess(prob);  // prob -> 604800
@@ -526,7 +508,9 @@ int main()
 
     // 根据置信度 获得有效框
     vector<vector<float>> info = get_info(result, 0.25, 6);
+    // 将推理的xywh 转换为原图中的 xyxy
     info_simplify(info, width_scale, height_scale);
+    // 将推理结果按类别分割
     vector<vector<vector<float>>> info_split = split_info(info);
 
     printf("info size: %ld\n", info_split.size());