# Export an ONNX model to a serialized TensorRT engine file.
#
# Parses the ONNX file, prints the parsed network's inputs/outputs, enables
# FP16 when the builder reports fast FP16 support, builds a serialized
# engine and writes it to disk. Exits with status 1 if parsing or building
# fails.
import sys

import tensorrt as trt
import onnx  # noqa: F401  kept from the original script; not used below
import torch

# NOTE(review): this device object is not used by the export below.
device = torch.device('cuda:0')

# Input ONNX model. Named ``onnx_path`` (not ``onnx``) so it does not shadow
# the imported ``onnx`` module.
onnx_path = "yolo-crowd-n.onnx"
# Output path for the serialized TensorRT engine.
f = "yolov5-crowd-n.engine"
# Builder scratch-memory limit: 1 << 30 bytes (1 GiB).
workspace_size = 1 << 30

print(f"starting export of onnx to tensorrt engine... with TensorRT version: {trt.__version__}\n")

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
config = builder.create_builder_config()
# ``max_workspace_size`` is deprecated since TensorRT 8.4 and removed in 10.x;
# use the memory-pool API when this TensorRT version provides it.
if hasattr(config, "set_memory_pool_limit"):
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size)
else:
    config.max_workspace_size = workspace_size

# ONNX models must be parsed into an explicit-batch network.
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(onnx_path):
    print('ERROR: Failed to parse the ONNX file.')
    for error in range(parser.num_errors):
        print(parser.get_error(error))
    sys.exit(1)

# Report the parsed network's I/O tensors for a quick sanity check.
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:
    print(f'input: {inp.name}, shape: {inp.shape} {inp.dtype}')
for out in outputs:
    print(f'output: {out.name}, shape: {out.shape} {out.dtype}')

if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

# build_serialized_network returns None on failure instead of raising, so
# check explicitly rather than crashing later on a None buffer.
engine = builder.build_serialized_network(network, config)
if engine is None:
    print('ERROR: Failed to build the TensorRT engine.')
    sys.exit(1)

# The serialized engine (IHostMemory) exposes the buffer protocol, so it can
# be written directly; the file handle is closed by the ``with`` block.
with open(f, "wb") as t:
    t.write(engine)