machine_learning/ml_infrastructure/inference-server-performance/server/scripts/tensorrt-optimization.py
import multiprocessing

import tensorflow as tf
# TF 1.x TF-TRT API; provides create_inference_graph and
# calib_graph_to_infer_graph.
import tensorflow.contrib.tensorrt as trt

# `get_calibration_files` and `preprocess` are helpers defined elsewhere in
# this script.


def convert_int8(
    input_model_dir, output_model_dir, batch_size, precision_mode,
    calib_image_dir, input_tensor, output_tensor, epochs):
  """Calibrates a SavedModel and writes out a TF-TRT INT8 SavedModel."""
  # TODO: Check whether a Tesla T4 GPU is required for the conversion.
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True

  # Get the paths to the calibration data.
  calibration_files = get_calibration_files(calib_image_dir, 'validation*')

  # Create the dataset and apply preprocessing.
  # TODO: Set num_parallel_calls from the number of available CPUs.
  dataset = tf.data.TFRecordDataset(calibration_files)
  dataset = dataset.apply(
      tf.contrib.data.map_and_batch(
          map_func=preprocess, batch_size=batch_size,
          num_parallel_calls=multiprocessing.cpu_count()))
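  # `preprocess` is assumed to yield (image, label) pairs, so each dataset
  # element is an (image_batch, label_batch) tuple; only the image batch is
  # fed to the calibration graph below.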
"""
Step 1: Creating the calibration graph.
"""
# Create TF-TRT INT8 calibration graph.
trt_int8_calib_graph = trt.create_inference_graph(
input_graph_def=None,
outputs=[output_tensor],
max_batch_size=batch_size,
input_saved_model_dir=input_model_dir,
precision_mode=precision_mode)
  # Run the calibration data through the graph.
  with tf.Session(graph=tf.Graph(), config=config) as sess:
    tf.logging.info('Preparing calibration data...')
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    tf.logging.info('Loading INT8 calibration graph...')
    output_node = tf.import_graph_def(
        trt_int8_calib_graph, return_elements=[output_tensor], name='')

    tf.logging.info('Calibrating model with calibration data...')
    for _ in range(epochs):
      sess.run(output_node,
               feed_dict={input_tensor: sess.run(next_element)[0]})
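  # Each sess.run in the loop above pulls one batch from the pipeline and
  # executes the calibration graph, letting TF-TRT collect the activation
  # statistics it needs to pick INT8 quantization ranges; `epochs` is
  # effectively the number of calibration batches.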
"""
Step 2: Converting the calibration graph to inference graph
"""
tf.logging.info('Creating TF-TRT INT8 inference engine...')
trt_int8_calibrated_graph = trt.calib_graph_to_infer_graph(
trt_int8_calib_graph)
  # Copy the MetaGraph from the base model, skipping collections that refer
  # to training-only state not present in the frozen TF-TRT graph.
  with tf.Session(graph=tf.Graph(), config=config) as sess:
    base_model = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], input_model_dir)

    metagraph = tf.MetaGraphDef()
    metagraph.graph_def.CopyFrom(trt_int8_calibrated_graph)
    for key in base_model.collection_def:
      if key not in ['variables', 'local_variables', 'model_variables',
                     'trainable_variables', 'train_op', 'table_initializer']:
        metagraph.collection_def[key].CopyFrom(base_model.collection_def[key])

    metagraph.meta_info_def.CopyFrom(base_model.meta_info_def)
    for key in base_model.signature_def:
      metagraph.signature_def[key].CopyFrom(base_model.signature_def[key])
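  # `metagraph` now pairs the TF-TRT GraphDef with the original model's
  # signatures, so serving clients keep the same input and output names.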
  saved_model_builder = (
      tf.saved_model.builder.SavedModelBuilder(output_model_dir))

  # Write the SavedModel with INT8 precision.
  with tf.Graph().as_default():
    tf.import_graph_def(
        trt_int8_calibrated_graph, return_elements=[output_tensor], name='')
    with tf.Session(config=config) as sess:
      saved_model_builder.add_meta_graph_and_variables(
          sess, ('serve',), signature_def_map=metagraph.signature_def)

  # Other meta graphs from the input SavedModel are intentionally ignored.
  saved_model_builder.save()
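
# Illustrative invocation, as a minimal sketch: the directory paths, tensor
# names, batch size, and epoch count below are assumptions for the example,
# not values shipped with this script.
if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  convert_int8(
      input_model_dir='models/resnet/original/00001',
      output_model_dir='models/resnet/int8/00001',
      batch_size=64,
      precision_mode='INT8',
      calib_image_dir='/data/imagenet/validation',
      input_tensor='input_tensor:0',
      output_tensor='softmax_tensor:0',
      epochs=10)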