onnx2trt.py

import tensorrt as trt
import onnx
import torch

device = torch.device('cuda:0')

# Paths of the input ONNX model and the output TensorRT engine.
onnx_path = "yolo-crowd-n.onnx"
engine_path = "yolov5-crowd-n.engine"

print(f"starting export of onnx to tensorrt engine... with TensorRT version: {trt.__version__}\n")

# Create the builder, its config, and an explicit-batch network definition.
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30  # 1 GiB workspace (on TensorRT >= 8.4 use config.set_memory_pool_limit instead)
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)

# Parse the ONNX model into the TensorRT network definition.
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(onnx_path):
    print('ERROR: Failed to parse the ONNX file.')
    for error in range(parser.num_errors):
        print(parser.get_error(error))
    exit(1)

# Report the network's input and output tensors.
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:
    print(f'input: {inp.name}, shape: {inp.shape} {inp.dtype}')
for out in outputs:
    print(f'output: {out.name}, shape: {out.shape} {out.dtype}')

# Enable FP16 precision when the GPU supports it.
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

# Build the serialized engine and write it to disk.
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine is None:
    print('ERROR: Failed to build the TensorRT engine.')
    exit(1)
with open(engine_path, "wb") as t:
    t.write(serialized_engine)
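
# Optional sanity check (a minimal sketch, not part of the original export flow):
# deserialize the engine file we just wrote to confirm TensorRT can load it.
# It reuses the `logger` and `engine_path` defined above.
with open(engine_path, "rb") as t, trt.Runtime(logger) as runtime:
    loaded_engine = runtime.deserialize_cuda_engine(t.read())
    print(f"engine deserialized successfully: {loaded_engine is not None}")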