def build_engine()

in nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py


def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=False, calib=None):
    """
    This function builds a TensorRT engine from an ONNX model, optionally running
    an int8 calibration process.

    Parameters
    ----------
    model_file : str
        The path of the ONNX model
    config : dict
        Config recording the quantization bit numbers and the names of the
        layers to quantize
    extra_layer_bits : int
        Layers that are not in config will be quantized to this bit number
    strict_datatype : bool
        Whether to strictly constrain layer bits to the numbers given in config.
        If true, all the layers will be set to the given bits strictly. Otherwise,
        these layers will be set automatically by TensorRT
    calib : tensorrt.IInt8Calibrator
        The calibrator used to calibrate the quantized model

    Returns
    -------
    tensorrt.ICudaEngine
        An ICudaEngine for executing inference on a built network
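
    Examples
    --------
    A minimal, illustrative sketch of the ``config`` layout this function expects
    (the layer name and tracked ranges below are hypothetical; ``Precision_Dict``
    is assumed to map 8/16/32 to the corresponding TensorRT data types)::

        config = {
            'Conv_0': {
                'weight_bits': 8,
                'output_bits': 8,
                'tracked_min_input': -1.0,
                'tracked_max_input': 1.0,
                'tracked_min_output': -4.0,
                'tracked_max_output': 4.0,
            },
        }
        engine = build_engine('model.onnx', config=config)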
    """
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, \
        trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as trt_config:
        # Note: max_batch_size must be set to 1 because of the implementation of
        # allocate_buffer.
        trt_version = int(trt.__version__.split('.')[0])
        assert trt_version in (TRT7, TRT8), \
            "The TensorRT version is too old, please update TensorRT to version >= 7.0"
        if trt_version == TRT7:
            logger.warning("TensorRT 7 is deprecated and may be removed in a future release.")

        builder.max_batch_size = 1
        if trt_version == TRT8:
            trt_config.max_workspace_size = common.GiB(4)
        else:
            builder.max_workspace_size = common.GiB(4)

        if extra_layer_bits == 32 and config is None:
            # no quantization requested, keep the whole model in fp32
            pass
        elif extra_layer_bits == 16 and config is None:
            if trt_version == TRT8:
                trt_config.set_flag(trt.BuilderFlag.FP16)
            else:
                builder.fp16_mode = True
        elif extra_layer_bits == 8 and config is None:
            # entire model in 8-bit mode
            if trt_version == TRT8:
                trt_config.set_flag(trt.BuilderFlag.INT8)
            else:
                builder.int8_mode = True
        else:
            # Mixed precision: per-layer bit numbers come from config, so enable both
            # int8 and fp16, and optionally force strict per-layer types.
            if trt_version == TRT8:
                trt_config.set_flag(trt.BuilderFlag.INT8)
                trt_config.set_flag(trt.BuilderFlag.FP16)
                if strict_datatype:
                    trt_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
            else:
                builder.int8_mode = True
                builder.fp16_mode = True
                builder.strict_type_constraints = strict_datatype

        valid_config(config)

        # Parse onnx model
        with open(model_file, 'rb') as model:
            if not parser.parse(model.read()):
                logger.error('ERROR: Failed to parse the ONNX file.')
                for error_idx in range(parser.num_errors):
                    logger.error(parser.get_error(error_idx))
                return None

        if calib is not None:
            if trt_version == TRT8:
                trt_config.int8_calibrator = calib
            else:
                builder.int8_calibrator = calib
            # Note: this design may not be correct if a layer has more than one output.
            if config is not None:
                for i in range(network.num_layers):
                    layer = network.get_layer(i)
                    if layer.name in config:
                        w_bits = config[layer.name]['weight_bits']
                        a_bits = config[layer.name]['output_bits']
                        layer.precision = Precision_Dict[w_bits]
                        layer.set_output_type(0, Precision_Dict[a_bits])
        elif config is not None:
            # When config is None there are no low-bit layers to set, so the original
            # model is kept. Note that this implementation may be incorrect when a
            # layer has more than one output.
            for i in range(network.num_layers):
                layer = network.get_layer(i)
                if layer.name not in config:
                    continue
                # The number of layers produced for a Gemm op changes during the
                # PyTorch->ONNX conversion, so Gemm needs special handling.
                if layer.name.startswith("Gemm"):
                    handle_gemm(network, i, config)
                    continue

                # If weight_bits exists in config, set layer precision and layer's input tensor dynamic range.
                if 'weight_bits' in config[layer.name]:
                    assert 'tracked_min_input' in config[layer.name]
                    assert 'tracked_max_input' in config[layer.name]
                    w_bits = config[layer.name]['weight_bits']
                    tracked_min_input = config[layer.name]['tracked_min_input']
                    tracked_max_input = config[layer.name]['tracked_max_input']
                    layer.precision = Precision_Dict[w_bits]
                    in_tensor = layer.get_input(0)
                    in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)

                # If output_bits exists in config, set the layer output type and the output tensor's dynamic range.
                if 'output_bits' in config[layer.name]:
                    assert 'tracked_min_output' in config[layer.name]
                    assert 'tracked_max_output' in config[layer.name]
                    a_bits = config[layer.name]['output_bits']
                    tracked_min_output = config[layer.name]['tracked_min_output']
                    tracked_max_output = config[layer.name]['tracked_max_output']
                    layer.set_output_type(0, Precision_Dict[a_bits])
                    out_tensor = layer.get_output(0)
                    out_tensor.dynamic_range = (tracked_min_output, tracked_max_output)
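
                # Illustrative note: for a symmetric dynamic range TensorRT derives the
                # int8 scale as roughly max(abs(min), abs(max)) / 127, so a range of
                # (-4.0, 4.0) corresponds to a scale of about 0.0315.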

        # Build engine and do int8 calibration.
        if trt_version == TRT8:
            engine = builder.build_engine(network, trt_config)
        else:
            engine = builder.build_cuda_engine(network)
        return engine
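

A minimal usage sketch (illustrative only; the model path 'model.onnx' and the
serialization step are assumptions, not part of the original source):

    engine = build_engine('model.onnx', extra_layer_bits=16)
    if engine is not None:
        with open('model.trt', 'wb') as f:
            f.write(engine.serialize())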