def load_with_converter()

in tftrt/blog_posts/Leveraging TensorFlow-TensorRT integration for Low latency Inference/tf2_inference.py


import copy
import pprint

from tensorflow.python.compiler.tensorrt import trt_convert as trt


def load_with_converter(path, precision, batch_size):
    """Builds a TF-TRT converter for the SavedModel at `path` and returns it.

    The model is not converted here; that happens when the caller invokes
    converter.convert(). `batch_size` is accepted but unused in this function.
    """

    # Start from the library's default conversion parameters.
    params = copy.deepcopy(trt.DEFAULT_TRT_CONVERSION_PARAMS)

    # Map the requested precision string onto the TF-TRT enum; anything
    # other than 'int8' or 'fp16' falls back to full FP32 precision.
    if precision == 'int8':
        precision_mode = trt.TrtPrecisionMode.INT8
    elif precision == 'fp16':
        precision_mode = trt.TrtPrecisionMode.FP16
    else:
        precision_mode = trt.TrtPrecisionMode.FP32

    params = params._replace(
        precision_mode=precision_mode,
        # GPU scratch space TensorRT may use while building engines:
        # 2 << 32 bytes == 8,589,934,592 bytes (8 GiB).
        max_workspace_size_bytes=2 << 32,
        # Cache up to 100 engines per TRTEngineOp to handle varying shapes.
        maximum_cached_engines=100,
        # Only convert subgraphs containing at least 3 nodes.
        minimum_segment_size=3,
        # Permit engines to be built at runtime for shapes not built offline.
        allow_build_at_runtime=True
    )

    # Echo the final conversion parameters so the run is easy to audit.
    print("%" * 85)
    pprint.pprint(params)
    print("%" * 85)

    # Create the converter; the actual graph rewrite is deferred until the
    # caller invokes converter.convert().
    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=path,
        conversion_params=params,
    )

    return converter
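

For context, here is a minimal usage sketch of the returned converter. The paths, the FP16 choice, and the ResNet-style input shape are placeholders for illustration, not part of the original script; convert(), build(), and save() are the standard TrtGraphConverterV2 entry points.

import numpy as np

# Hypothetical paths, for illustration only.
saved_model_dir = '/tmp/resnet50_saved_model'
output_dir = '/tmp/resnet50_trt_fp16'

converter = load_with_converter(saved_model_dir, precision='fp16', batch_size=8)

# convert() performs the TF-TRT graph rewrite. For precision='int8' a
# calibration_input_fn yielding representative batches must be passed here.
converter.convert()

# Optionally pre-build engines for a known input shape so the first
# inference call does not pay the engine-build cost at runtime.
def input_fn():
    yield (np.zeros((8, 224, 224, 3), dtype=np.float32),)

converter.build(input_fn=input_fn)

# Write the converted SavedModel to disk.
converter.save(output_dir)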