in tftrt/blog_posts/Leveraging TensorFlow-TensorRT integration for Low latency Inference/tf2_inference.py [0:0]
import copy
import pprint

from tensorflow.python.compiler.tensorrt import trt_convert as trt


def load_with_converter(path, precision, batch_size):
    """Creates a TF-TRT converter for the SavedModel at `path` and returns it.

    The conversion itself runs later, when the caller invokes
    `converter.convert()` on the returned object.
    """
    # Start from the library defaults and override only the fields we need.
    params = copy.deepcopy(trt.DEFAULT_TRT_CONVERSION_PARAMS)

    # Map the requested precision string to the TF-TRT precision mode.
    if precision == 'int8':
        precision_mode = trt.TrtPrecisionMode.INT8
    elif precision == 'fp16':
        precision_mode = trt.TrtPrecisionMode.FP16
    else:
        precision_mode = trt.TrtPrecisionMode.FP32

    params = params._replace(
        precision_mode=precision_mode,
        max_workspace_size_bytes=2 << 32,  # 8,589,934,592 bytes (8 GiB)
        maximum_cached_engines=100,
        minimum_segment_size=3,
        allow_build_at_runtime=True
    )
    # Print the final conversion parameters for easy debugging.
    print("%" * 85)
    pprint.pprint(params)
    print("%" * 85)
    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=path,
        conversion_params=params,
    )
    return converter
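

# Example usage (illustrative sketch, not part of the original script): the
# SavedModel path, precision string, and output directory below are
# placeholder values.
#
#   converter = load_with_converter(
#       path='./resnet50_saved_model', precision='fp16', batch_size=8)
#   trt_func = converter.convert()          # builds the TF-TRT graph
#   converter.save('./resnet50_trt_fp16')   # writes the converted SavedModel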