"""Convert an ONNX model into a serialized TensorRT engine file.

Parses ``yolo-crowd-n.onnx`` with TensorRT's ONNX parser, prints the network's
input/output tensors, enables FP16 when the builder reports fast FP16 support,
and writes the serialized engine to ``yolov5-crowd-n.engine``.

Exits with status 1 if the ONNX file cannot be parsed or the engine build fails.
"""
import sys

import onnx  # noqa: F401  # NOTE(review): not used below; possibly imported for side effects — confirm before removing
import tensorrt as trt
import torch  # noqa: F401  # NOTE(review): not used below; possibly imported for side effects — confirm before removing

# Input model and output engine locations (relative to the working directory).
ONNX_PATH = "yolo-crowd-n.onnx"
ENGINE_PATH = "yolov5-crowd-n.engine"
# Upper bound on scratch memory the builder may use while choosing tactics.
WORKSPACE_SIZE = 1 << 30

print(f"starting export of onnx to tensorrt engine... with TensorRT version: {trt.__version__}\n")

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# `IBuilderConfig.max_workspace_size` was deprecated in TensorRT 8.4 and removed
# in 10.0; use the memory-pool API when present, else fall back for old versions.
if hasattr(config, "set_memory_pool_limit"):
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, WORKSPACE_SIZE)
else:
    config.max_workspace_size = WORKSPACE_SIZE

# The ONNX parser needs an explicit-batch network on older TensorRT releases.
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)

if not parser.parse_from_file(ONNX_PATH):
    print('ERROR: Failed to parse the ONNX file.', file=sys.stderr)
    for error_index in range(parser.num_errors):
        print(parser.get_error(error_index), file=sys.stderr)
    sys.exit(1)

# Report the network's I/O tensors so the exported signature can be checked.
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:
    print(f'input: {inp.name}, shape: {inp.shape} {inp.dtype}')
for out in outputs:
    print(f'output: {out.name}, shape: {out.shape} {out.dtype}')

# Only request FP16 kernels when the builder reports native fast FP16 support.
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

# Returns a host-memory buffer, or None if the build failed.
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine is None:
    print("ERROR: Failed to build the TensorRT engine.", file=sys.stderr)
    sys.exit(1)

with open(ENGINE_PATH, "wb") as engine_file:
    engine_file.write(serialized_engine)