in nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py [0:0]
def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=False, calib=None):
"""
This function builds a TensorRT engine from an ONNX model and optionally runs an int8 calibration process.
Parameters
----------
model_file : str
The path of the ONNX model
config : dict
Config mapping layer names to their quantization bit widths and tracked input/output value ranges
extra_layer_bits : int
Layers that are not listed in config will be quantized to this bit width
strict_datatype : bool
Whether to strictly constrain layer precision to the bit widths given in config. If True, all
listed layers will be set to the given bits strictly. Otherwise, their precision will be chosen
automatically by TensorRT
calib : tensorrt.IInt8Calibrator
The calibrator used to calibrate the quantized model at build time
Returns
-------
tensorrt.ICudaEngine
An ICudaEngine for executing inference on a built network
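Examples
--------
A minimal usage sketch; the ONNX path, layer name and config values below are
illustrative assumptions only.

>>> config = {'Conv_0': {'weight_bits': 8, 'output_bits': 8,
...                      'tracked_min_input': -1.0, 'tracked_max_input': 1.0,
...                      'tracked_min_output': -4.0, 'tracked_max_output': 4.0}}
>>> engine = build_engine('model.onnx', config, extra_layer_bits=32, strict_datatype=True)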
"""
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, \
trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as trt_config:
# Note: max_batch_size must be set to 1 because of the implementation of allocate_buffer
trt_version = int(trt.__version__[0])
assert trt_version in (TRT7, TRT8), \
"Version of TensorRT is too old, please update TensorRT to version >= 7.0"
if trt_version == TRT7:
logger.warning("TensorRT7 is deprecated and may be removed in the following release.")
builder.max_batch_size = 1
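# Give the builder up to 4 GiB of temporary GPU workspace for engine building; the API for setting this limit differs between TensorRT 7 and TensorRT 8.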
if trt_version == TRT8:
trt_config.max_workspace_size = common.GiB(4)
else:
builder.max_workspace_size = common.GiB(4)
if extra_layer_bits == 32 and config is None:
pass
elif extra_layer_bits == 16 and config is None:
if trt_version == TRT8:
trt_config.set_flag(trt.BuilderFlag.FP16)
else:
builder.fp16_mode = True
elif extra_layer_bits == 8 and config is None:
# entire model in 8bit mode
if trt_version == TRT8:
trt_config.set_flag(trt.BuilderFlag.INT8)
else:
builder.int8_mode = True
else:
if trt_version == TRT8:
trt_config.set_flag(trt.BuilderFlag.INT8)
trt_config.set_flag(trt.BuilderFlag.FP16)
if strict_datatype:
trt_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
else:
builder.int8_mode = True
builder.fp16_mode = True
builder.strict_type_constraints = strict_datatype
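# Sanity-check the user-provided per-layer bit settings before applying them.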
valid_config(config)
# Parse onnx model
with open(model_file, 'rb') as model:
if not parser.parse(model.read()):
logger.error('ERROR: Failed to parse the ONNX file.')
for error in range(parser.num_errors):
logger.error(parser.get_error(error))
return None
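# If a calibrator is given, TensorRT collects activation ranges at build time; otherwise the tracked ranges from config are applied per layer below.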
if calib is not None:
if trt_version == TRT8:
trt_config.int8_calibrator = calib
else:
builder.int8_calibrator = calib
# This design may be incorrect for layers with more than one output
for i in range(network.num_layers):
if config is None:
break
layer = network.get_layer(i)
if layer.name in config:
w_bits = config[layer.name]['weight_bits']
a_bits = config[layer.name]['output_bits']
layer.precision = Precision_Dict[w_bits]
layer.set_output_type(0, Precision_Dict[a_bits])
else:
# This implementation may be incorrect when a layer has more than one output
for i in range(network.num_layers):
if config is None:
# no low-bit layers need to be set, keep the original model
break
layer = network.get_layer(i)
if layer.name not in config:
continue
# the layer numbering of Gemm changes during the PyTorch->ONNX conversion, so it needs special handling
if layer.name.startswith("Gemm"):
handle_gemm(network, i, config)
continue
# If weight_bits exists in config, set layer precision and layer's input tensor dynamic range.
if 'weight_bits' in config[layer.name]:
assert 'tracked_min_input' in config[layer.name]
assert 'tracked_max_input' in config[layer.name]
w_bits = config[layer.name]['weight_bits']
tracked_min_input = config[layer.name]['tracked_min_input']
tracked_max_input = config[layer.name]['tracked_max_input']
layer.precision = Precision_Dict[w_bits]
in_tensor = layer.get_input(0)
in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
# If output_bits exists in config, set the layer's output type and its output tensor's dynamic range.
if 'output_bits' in config[layer.name]:
assert 'tracked_min_output' in config[layer.name]
assert 'tracked_max_output' in config[layer.name]
a_bits = config[layer.name]['output_bits']
tracked_min_output = config[layer.name]['tracked_min_output']
tracked_max_output = config[layer.name]['tracked_max_output']
layer.set_output_type(0, Precision_Dict[a_bits])
out_tensor = layer.get_output(0)
out_tensor.dynamic_range = (tracked_min_output, tracked_max_output)
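# TensorRT derives the int8 scale for these tensors from the given dynamic range (symmetric quantization, roughly scale = max(|min|, |max|) / 127), so no calibration pass is needed for them.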
# Build engine and do int8 calibration.
if trt_version == TRT8:
engine = builder.build_engine(network, trt_config)
else:
engine = builder.build_cuda_engine(network)
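# Note: both build_engine and build_cuda_engine return None if the engine could not be built.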
return engine