VideoDemo.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. #include <fstream>
  2. #include <opencv2/opencv.hpp>
  3. #include <iostream>
  4. #include <string>
  5. #include <vector>
  6. #include <sstream>
  7. #include <cstdlib>
  8. // libcommon
  9. #include "Logger.h"
  10. #include "OSTime.h"
  11. #include "SysUtils.h"
  12. #include "Semaphore.h"
  13. // libmascommon
  14. #include "MemPool.h"
  15. // deps
  16. #include "DetUtils.h"
  17. #include "ilogger.hpp"
  18. #include "trt_infer.hpp"
  19. // libyolocrowddetector
  20. #include "DetectorAPI.h"
  21. #include "YoloCrowdDetector.h"
  22. // libheadcountstrategy
  23. #include "StrategyAPI.h"
  24. #include "HeadCountStrategy.h"
  25. #define RUN_TEXT_DRAWING 0
  26. struct CallbackContext
  27. {
  28. masd::Strategy* strategy;
  29. cv::Mat* img;
  30. cv::VideoWriter* videoWriter;
  31. };
  32. tzc::Semaphore SEMA;
  /// Splits `str` on `delimiter` and returns the non-blank pieces.
  ///
  /// Leading and trailing spaces/tabs are stripped from every piece. Pieces that
  /// are empty or contain only spaces/tabs are dropped, so "a, ,b," yields
  /// {"a", "b"}.
  ///
  /// @param str       Text to split.
  /// @param delimiter Character that separates the pieces.
  /// @return The trimmed, non-empty tokens in their original order.
  std::vector<std::string> splitString(const std::string& str, char delimiter)
  {
      static const char* const kBlanks = " \t";

      std::vector<std::string> tokens;
      std::istringstream input(str);
      std::string item;
      while (std::getline(input, item, delimiter))
      {
          const std::string::size_type first = item.find_first_not_of(kBlanks);
          if (first == std::string::npos)
          {
              continue; // empty or all-blank piece: nothing to keep
          }
          // At least one non-blank character exists, so this search cannot fail.
          const std::string::size_type last = item.find_last_not_of(kBlanks);
          tokens.push_back(item.substr(first, last - first + 1));
      }
      return tokens;
  }
  56. void SaveDetectorInfo(const char* info)
  57. {
  58. std::ofstream outFile("detector_info.json", std::ios::out | std::ios::app);
  59. if (outFile.is_open())
  60. {
  61. outFile << "{\n\t\"Detector Information\": \"" << info << "\"\n}" << std::endl;
  62. outFile.close();
  63. TZLogInfo("Detector information saved to detector_info.json~~~");
  64. }
  65. else
  66. {
  67. TZLogError("Failed to open file for writing!!!");
  68. }
  69. }
  70. TZ_INT DetectionCallback(SPtr<masd::StreamInfo>& media, void* ctx)
  71. {
  72. TZLogInfo("Detection callback triggered!~~~");
  73. // Retrieve the context
  74. CallbackContext* callbackContext = reinterpret_cast<CallbackContext*>(ctx);
  75. if (!callbackContext || !callbackContext->img || callbackContext->strategy == nullptr)
  76. {
  77. TZLogError("Error: Invalid context or image pointer!!!");
  78. return -1;
  79. }
  80. cv::Mat* img = callbackContext->img;
  81. masd::Strategy* strategy = callbackContext->strategy;
  82. if (img->empty())
  83. {
  84. TZLogError("Error: Invalid image pointer!!!");
  85. return -1;
  86. }
  87. auto allDetRst = media->GetAllDetRst();
  88. int imgWidth = img->cols, imgHeight = img->rows;
  89. for (auto it = allDetRst.begin(); it != allDetRst.end(); ++it)
  90. {
  91. const std::string& detKey = it->first;
  92. const SPtr<masd::DetProducing>& detProducing = it->second;
  93. TZLogInfo("Detection Key: %s~~~", detKey.c_str());
  94. std::cout << "Detection Result: " << detProducing->Result << std::endl;
  95. if (!detProducing->Draw.Rects.empty())
  96. {
  97. TZLogInfo("Processing Draw Info...~~~");
  98. for (const auto& rect : detProducing->Draw.Rects)
  99. {
  100. TZLogInfo("Rect: LTX: %.2f, LTY: %.2f, "
  101. "RBX: %.2f, RBY: %.2f, Color: %s, Thickness: %d~~~",
  102. rect.LTX, rect.LTY, rect.RBX, rect.RBY,
  103. rect.Color.c_str(), rect.Thickness);
  104. if (!rect.Text.Text.empty())
  105. {
  106. TZLogInfo("Text: %s~~~", rect.Text.Text.c_str());
  107. }
  108. cv::Scalar color;
  109. {
  110. std::stringstream colorStream(rect.Color);
  111. int r, g, b;
  112. char comma;
  113. colorStream >> r >> comma >> g >> comma >> b;
  114. color = cv::Scalar(b, g, r);
  115. }
  116. cv::Point topLeft(rect.LTX * imgWidth, rect.LTY * imgHeight);
  117. cv::Point bottomRight(rect.RBX * imgWidth, rect.RBY * imgHeight);
  118. cv::rectangle(*img, topLeft, bottomRight, color, rect.Thickness);
  119. #if RUN_TEXT_DRAWING
  120. if (!rect.Text.Text.empty())
  121. {
  122. cv::putText(*img, rect.Text.Text,
  123. cv::Point(topLeft.x, topLeft.y - 10),
  124. cv::FONT_HERSHEY_SIMPLEX, 0.8, color, 2);
  125. }
  126. #endif
  127. }
  128. }
  129. else
  130. {
  131. TZLogInfo("No Draw Info available.~~~");
  132. }
  133. if (detProducing->DetMedia)
  134. {
  135. const auto& media = detProducing->DetMedia;
  136. TZLogInfo("Media Length: %d~~~", media->Length);
  137. TZLogInfo("Media DataType: %d~~~", media->DataType);
  138. TZLogInfo("Media Height: %d, Width: %d~~~", media->Height, media->Width);
  139. }
  140. }
  141. // Process the strategy inside the callback
  142. TZ_INT strategyResult = strategy->DoStrategy(media);
  143. if (strategyResult != masd::MEC_OK)
  144. {
  145. TZLogError("Headcount strategy failed to process the stream info!!!");
  146. return -1;
  147. }
  148. TZLogInfo("Headcount strategy processed the stream info successfully~~~");
  149. SPtr<masd::StraProducing> straProducing = media->GetStraProducing();
  150. if (straProducing)
  151. {
  152. std::string headcountResult = straProducing->Result.RstName;
  153. TZLogInfo("Detected headcount: %s", headcountResult.c_str());
  154. int fontFace = cv::FONT_HERSHEY_SIMPLEX;
  155. double fontScale = 0.8;
  156. int thickness = 2;
  157. cv::Scalar textColor(0, 255, 0);
  158. cv::Point textPosition(10, 30);
  159. cv::putText(*img, headcountResult, textPosition, fontFace, fontScale, textColor, thickness);
  160. }
  161. else
  162. {
  163. TZLogWarn("No headcount result found in the stream info.");
  164. }
  165. callbackContext->videoWriter->write(*img);
  166. SEMA.Signal();
  167. return 0;
  168. }
  /// Prints the command-line help text to standard output.
  ///
  /// The defaults quoted here (model path, confidence threshold) mirror the
  /// defaults that main() applies when the options are omitted.
  ///
  /// @param programName Name shown in the usage and example lines (normally
  ///                    argv[0]); "VideoDemo" is shown if it is null.
  void printUsage(const char* programName)
  {
      const char* const shownName = (programName != nullptr) ? programName : "VideoDemo";

      std::cout << "Usage: " << shownName << " --videos <video_paths> [--model <model_path>] [--confidence <confidence_threshold>]\n"
                << " --videos : Comma-separated list of video file paths (required)\n"
                << " --model : Path to the detection model (optional, default: ../../models/yolo-crowd-ft-e60.trt)\n"
                << " --confidence : Confidence threshold (optional, range: 0.0 - 1.0, default: 0.25)\n"
                << "Example:\n"
                << " " << shownName << " --videos video1.mp4,video2.mp4 --model ../../models/yolo-crowd-ft-e60.trt --confidence 0.3\n";
  }
  178. int main(int argc, char* argv[])
  179. {
  180. // Initialize log
  181. INITIALIZE_LOGGER_NORMAL("test", "./test.log", 1, 100, 6, 1, 1);
  182. // Parse command-line arguments
  183. std::string videoPathsStr;
  184. std::string modelPath = "../../models/yolo-crowd-ft-e60.trt"; // Default model path
  185. float confidenceThreshold = 0.25f; // Default confidence threshold
  186. if (argc < 3)
  187. {
  188. std::cout << "Error: Insufficient arguments provided.\n";
  189. printUsage(argv[0]);
  190. return -1;
  191. }
  192. for (int i = 1; i < argc; ++i)
  193. {
  194. std::string arg = argv[i];
  195. if (arg == "--videos" && i + 1 < argc)
  196. {
  197. videoPathsStr = argv[++i];
  198. }
  199. else if (arg == "--model" && i + 1 < argc)
  200. {
  201. modelPath = argv[++i];
  202. }
  203. else if (arg == "--confidence" && i + 1 < argc)
  204. {
  205. try
  206. {
  207. confidenceThreshold = std::stof(argv[++i]);
  208. if (confidenceThreshold < 0.0f || confidenceThreshold > 1.0f)
  209. {
  210. std::cout << "Error: Confidence threshold must be between 0.0 and 1.0.\n";
  211. return -1;
  212. }
  213. }
  214. catch (const std::invalid_argument& e)
  215. {
  216. std::cout << "Error: Invalid confidence threshold value.\n";
  217. return -1;
  218. }
  219. catch (const std::out_of_range& e)
  220. {
  221. std::cout << "Error: Confidence threshold value out of range.\n";
  222. return -1;
  223. }
  224. }
  225. else
  226. {
  227. std::cout << "Error: Unknown or incomplete argument '" << arg << "'.\n";
  228. printUsage(argv[0]);
  229. return -1;
  230. }
  231. }
  232. // Check if --videos was provided
  233. if (videoPathsStr.empty())
  234. {
  235. std::cout << "Error: --videos argument is required.\n";
  236. printUsage(argv[0]);
  237. return -1;
  238. }
  239. // Split video paths
  240. std::vector<std::string> videoPaths = splitString(videoPathsStr, ',');
  241. if (videoPaths.empty())
  242. {
  243. std::cout << "Error: No video paths provided.\n";
  244. printUsage(argv[0]);
  245. return -1;
  246. }
  247. // Display the received parameters
  248. TZLogInfo("Received parameters:");
  249. TZLogInfo(" Video Paths:");
  250. for (const auto& path : videoPaths)
  251. {
  252. TZLogInfo(" %s", path.c_str());
  253. }
  254. TZLogInfo(" Model Path: %s", modelPath.c_str());
  255. TZLogInfo(" Confidence Threshold: %.2f", confidenceThreshold);
  256. // Initialize memory pool
  257. masd::MemPool *pool = masd::MEMPOOL;
  258. if (pool->Initialize() != masd::MEC_OK)
  259. {
  260. TZLogError("Memory pool initialization failed!!!");
  261. return -1;
  262. }
  263. // Iterate over each video path
  264. for (const auto& videoPath : videoPaths)
  265. {
  266. TZLogInfo("Processing video: %s", videoPath.c_str());
  267. // Step 1:
  268. // Initialize the SDK
  269. TZ_INT initResult = Initialize();
  270. if (initResult != masd::MEC_OK)
  271. {
  272. TZLogError("Failed to initialize the SDK!!!");
  273. return -1;
  274. }
  275. TZLogInfo("SDK Initialized Successfully~~~");
  276. // Step 2:
  277. // Build yolo-crowd detector
  278. masd::Detector* detector = BuildDetector();
  279. if (detector == nullptr)
  280. {
  281. TZLogError("Failed to build yolo-crowd detector!!!");
  282. Dispose();
  283. return -1;
  284. }
  285. TZLogInfo("Yolo-crowd detector built successfully~~~");
  286. // Build headcount strategy
  287. masd::Strategy* strategy = BuildStrategy();
  288. if(strategy == nullptr)
  289. {
  290. TZLogError("Failed to build headcount strategy!!!");
  291. DestroyDetector(detector);
  292. Dispose();
  293. return -1;
  294. }
  295. TZLogInfo("Headcount strategy built successfully~~~");
  296. // Step 3:
  297. // Initialize the yolo-crowd detector with configuration parameters
  298. std::stringstream initParamStream;
  299. initParamStream << "{"
  300. << "\"gpu_id\": 0, "
  301. << "\"max_objects\": 1024, "
  302. << "\"confidence_threshold\": " << confidenceThreshold << ", "
  303. << "\"nms_threshold\": 0.5, "
  304. << "\"model_path\": \"" << modelPath << "\""
  305. << "}";
  306. std::string initParam = initParamStream.str();
  307. TZ_INT initDetResult = detector->Initialize(initParam);
  308. if (initDetResult != masd::MEC_OK)
  309. {
  310. TZLogError("Failed to initialize the yolo-crowd detector!!!");
  311. DestroyDetector(detector);
  312. DestroyStrategy(strategy);
  313. Dispose();
  314. return -1;
  315. }
  316. TZLogInfo("Yolo-crowd detector initialized successfully~~~");
  317. // Initialize the headcount strategy
  318. TZ_INT initStraResult = strategy->Initialize();
  319. if (initStraResult != masd::MEC_OK)
  320. {
  321. TZLogError("Failed to initialize the headcount strategy!!!");
  322. DestroyDetector(detector);
  323. DestroyStrategy(strategy);
  324. Dispose();
  325. return -1;
  326. }
  327. TZLogInfo("Headcount strategy initialized successfully~~~");
  328. // Step 4:
  329. // Set yolo-crowd detection configuration (optional)
  330. std::stringstream detectConfigStream;
  331. detectConfigStream << "{"
  332. << "\"freq\": 0, "
  333. << "\"target\": {"
  334. << "\"target_class\": 0, "
  335. << "\"target_threshold\": " << confidenceThreshold
  336. << "}, "
  337. << "\"focusArea\": ["
  338. << "{ \"LTX\": 0.0, \"LTY\": 0.0, \"RBX\": 1.0, \"RBY\": 1.0 }"
  339. << "], "
  340. << "\"ignoreArea\": ["
  341. << "{ \"LTX\": 0.0, \"LTY\": 0.0, \"RBX\": 0.0, \"RBY\": 0.0 }"
  342. << "]"
  343. << "}";
  344. std::string detectConfig = detectConfigStream.str();
  345. TZ_INT setDetCfgResult = detector->SetDetectCfg(detectConfig);
  346. if (setDetCfgResult != masd::MEC_OK)
  347. {
  348. TZLogError("Failed to set yolo-crowd detection configuration!!!");
  349. DestroyDetector(detector);
  350. DestroyStrategy(strategy);
  351. Dispose();
  352. return -1;
  353. }
  354. TZLogInfo("Yolo-crowd detection configuration set successfully~~~");
  355. // Set headcount strategy configuration (optional)
  356. std::string headcountConfig = R"({
  357. "TimeThreshold": 5
  358. })";
  359. TZ_INT setStraCfgResult = strategy->SetStrategyCfg(headcountConfig);
  360. if (setStraCfgResult != masd::MEC_OK)
  361. {
  362. TZLogError("Failed to set headcount strategy configuration!!!");
  363. DestroyDetector(detector);
  364. DestroyStrategy(strategy);
  365. Dispose();
  366. return -1;
  367. }
  368. TZLogInfo("Headcount strategy configuration set successfully~~~");
  369. // Step 5:
  370. // Open the video file
  371. cv::VideoCapture videoCapture(videoPath);
  372. if(!videoCapture.isOpened())
  373. {
  374. TZLogError("Failed to load video: %s!!!", videoPath.c_str());
  375. DestroyDetector(detector);
  376. DestroyStrategy(strategy);
  377. Dispose();
  378. return -1;
  379. }
  380. TZ_INT videoWidth = static_cast<TZ_INT>(videoCapture.get(cv::CAP_PROP_FRAME_WIDTH));
  381. TZ_INT videoHeight = static_cast<TZ_INT>(videoCapture.get(cv::CAP_PROP_FRAME_HEIGHT));
  382. TZ_INT videoFPS = static_cast<TZ_INT>(videoCapture.get(cv::CAP_PROP_FPS));
  383. // Generate output video file name based on input video file name
  384. std::string outputVideoPath = "output_" + videoPath.substr(videoPath.find_last_of("/\\") + 1);
  385. cv::VideoWriter videoWriter(outputVideoPath,
  386. cv::VideoWriter::fourcc('X', '2', '6', '4'),
  387. videoFPS, cv::Size(videoWidth, videoHeight));
  388. if (!videoWriter.isOpened())
  389. {
  390. TZLogError("Failed to open video writer for %s!!!", outputVideoPath.c_str());
  391. DestroyDetector(detector);
  392. DestroyStrategy(strategy);
  393. Dispose();
  394. return -1;
  395. }
  396. TZLogInfo("Processing video: %s", videoPath.c_str());
  397. TZLogInfo("Output video will be saved as: %s", outputVideoPath.c_str());
  398. cv::Mat frame;
  399. while(videoCapture.read(frame))
  400. {
  401. if(frame.empty())
  402. {
  403. TZLogError("Failed to read frame from video: %s!!!", videoPath.c_str());
  404. break;
  405. }
  406. TZ_INT length = frame.total() * frame.elemSize();
  407. SPtr<masd::Media> mediaResource = std::make_shared<masd::Media>(length);
  408. mediaResource->Width = frame.cols;
  409. mediaResource->Height = frame.rows;
  410. mediaResource->DataType = frame.type();
  411. mediaResource->Mem = frame.data;
  412. SPtr<masd::StreamInfo> streamInfo = std::make_shared<masd::StreamInfo>();
  413. streamInfo->SetMediaRsc(mediaResource);
  414. CallbackContext callbackContext{strategy, &frame, &videoWriter};
  415. detector->DoDetect(streamInfo, DetectionCallback, &callbackContext);
  416. SEMA.Wait();
  417. }
  418. // Step 6:
  419. // Print DetGetInformation
  420. char detectorInfo[4096];
  421. TZ_INT infoResult = GetInformation(detectorInfo);
  422. if (infoResult != masd::MEC_OK)
  423. {
  424. TZLogError("Failed to get detector information!!!");
  425. DestroyDetector(detector);
  426. DestroyStrategy(strategy);
  427. Dispose();
  428. return -1;
  429. }
  430. SaveDetectorInfo(detectorInfo);
  431. // Step 7:
  432. // Destroy the detector
  433. DestroyDetector(detector);
  434. TZLogInfo("Detector destroyed successfully~~~");
  435. // Destroy the strategy
  436. DestroyStrategy(strategy);
  437. TZLogInfo("Detector strategy destroyed successfully~~~");
  438. // Step 8:
  439. // Dispose the SDK
  440. TZ_INT disposeResult = Dispose();
  441. if (disposeResult != masd::MEC_OK)
  442. {
  443. TZLogError("Failed to dispose the SDK!!!");
  444. return -1;
  445. }
  446. TZLogInfo("SDK disposed successfully~~~");
  447. }
  448. return 0;
  449. }