// create_trt.cpp — builds a minimal TensorRT engine (a single 2x2 max-pool
// network over a fixed 1x3x224x224 input) and serializes it to "model.engine".
#include <cassert>
#include <fstream>
#include <iostream>

#include <NvInfer.h>
#include </home/cl/package/TensorRT-8.6.1.6/samples/common/logger.h>

using namespace nvinfer1;
using namespace sample;
  7. const char* IN_NAME = "input";
  8. const char* OUT_NAME = "output";
  9. static const int IN_H = 224;
  10. static const int IN_W = 224;
  11. static const int BATCH_SIZE = 1;
  12. static const int EXPLICIT_BATCH = 1 << (int)(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  13. int main(int argc, char** argv)
  14. {
  15. // Create builder
  16. Logger m_logger;
  17. IBuilder* builder = createInferBuilder(m_logger);
  18. IBuilderConfig* config = builder->createBuilderConfig();
  19. // Create model to populate the network
  20. INetworkDefinition* network = builder->createNetworkV2(EXPLICIT_BATCH);
  21. ITensor* input_tensor = network->addInput(IN_NAME, DataType::kFLOAT, Dims4{ BATCH_SIZE, 3, IN_H, IN_W });
  22. IPoolingLayer* pool = network->addPoolingNd(*input_tensor, PoolingType::kMAX, DimsHW{ 2, 2 });
  23. pool->setStrideNd(DimsHW{ 2, 2 });
  24. pool->getOutput(0)->setName(OUT_NAME);
  25. network->markOutput(*pool->getOutput(0));
  26. // Build engine
  27. IOptimizationProfile* profile = builder->createOptimizationProfile();
  28. profile->setDimensions(IN_NAME, OptProfileSelector::kMIN, Dims4(BATCH_SIZE, 3, IN_H, IN_W));
  29. profile->setDimensions(IN_NAME, OptProfileSelector::kOPT, Dims4(BATCH_SIZE, 3, IN_H, IN_W));
  30. profile->setDimensions(IN_NAME, OptProfileSelector::kMAX, Dims4(BATCH_SIZE, 3, IN_H, IN_W));
  31. config->setMaxWorkspaceSize(1 << 20);
  32. ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
  33. // Serialize the model to engine file
  34. IHostMemory* modelStream{ nullptr };
  35. assert(engine != nullptr);
  36. modelStream = engine->serialize();
  37. std::ofstream p("model.engine", std::ios::binary);
  38. if (!p) {
  39. std::cerr << "could not open output file to save model" << std::endl;
  40. return -1;
  41. }
  42. p.write(reinterpret_cast<const char*>(modelStream->data()), modelStream->size());
  43. std::cout << "generating file done!" << std::endl;
  44. // Release resources
  45. modelStream->destroy();
  46. network->destroy();
  47. engine->destroy();
  48. builder->destroy();
  49. config->destroy();
  50. return 0;
  51. }