def convert_int8()

in machine_learning/ml_infrastructure/inference-server-performance/server/scripts/tensorrt-optimization.py [0:0]

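This excerpt shows only the function body. The surrounding script is assumed to provide roughly the following TF 1.x / TF-TRT imports (inferred from the calls below, so treat them as an assumption), as well as the get_calibration_files and preprocess helpers referenced in the code.

import multiprocessing

import tensorflow as tf
import tensorflow.contrib.tensorrt as trt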

def convert_int8(
    input_model_dir, output_model_dir, batch_size, precision_mode,
    calib_image_dir, input_tensor, output_tensor, epochs):
  """Converts a SavedModel into a TF-TRT INT8-calibrated SavedModel.

  Args:
    input_model_dir: Directory containing the source SavedModel.
    output_model_dir: Directory to write the converted SavedModel to.
    batch_size: Maximum batch size for the TensorRT engines.
    precision_mode: TF-TRT precision mode (expected to be 'INT8' here).
    calib_image_dir: Directory containing the TFRecord calibration data.
    input_tensor: Name of the model's input tensor.
    output_tensor: Name of the model's output tensor.
    epochs: Number of calibration batches to run through the graph.
  """

  # (TODO) Check whether a Tesla T4 is needed for the conversion.
  config = tf.ConfigProto()
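  # allow_growth lets TensorFlow allocate GPU memory on demand instead of
  # reserving all of it up front.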
  config.gpu_options.allow_growth = True
  
  # Get paths to the calibration data.
  calibration_files = get_calibration_files(calib_image_dir, 'validation*')
  
  # Create the dataset and apply preprocessing.
  # (TODO) Get the number of CPUs to set an appropriate num_parallel_calls.
  dataset = tf.data.TFRecordDataset(calibration_files)
  dataset = dataset.apply(
      tf.contrib.data.map_and_batch(
          map_func=preprocess, batch_size=batch_size,
          num_parallel_calls=multiprocessing.cpu_count()))
  
  """
  Step 1: Creating the calibration graph.
  """
  
  # Create TF-TRT INT8 calibration graph.
  trt_int8_calib_graph = trt.create_inference_graph(
      input_graph_def=None,
      outputs=[output_tensor],
      max_batch_size=batch_size,
      input_saved_model_dir=input_model_dir,    
      precision_mode=precision_mode)
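  # With precision_mode='INT8', the returned graph still contains calibration
  # nodes; it must first be run on representative data before it can be
  # converted into the final inference graph.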

  # Calibrate graph.
  with tf.Session(graph=tf.Graph(), config=config) as sess:
    tf.logging.info('Preparing calibration data...')
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    tf.logging.info('Loading INT8 calibration graph...')
    output_node = tf.import_graph_def(
        trt_int8_calib_graph, return_elements=[output_tensor], name='')
    
    tf.logging.info('Calibrate model with calibration data...')
    for _ in range(epochs):
      sess.run(output_node,
               feed_dict={input_tensor: sess.run(next_element)[0]})
  
  """
  Step 2: Converting the calibration graph to inference graph
  """
  tf.logging.info('Creating TF-TRT INT8 inference engine...')
  trt_int8_calibrated_graph = trt.calib_graph_to_infer_graph(
      trt_int8_calib_graph)
  
  # Copy MetaGraph from base model.
  with tf.Session(graph=tf.Graph(), config=config) as sess:
    base_model = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], input_model_dir)
    
    metagraph = tf.MetaGraphDef()
    metagraph.graph_def.CopyFrom(trt_int8_calibrated_graph)
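    # Skip variable- and training-related collections: TF-TRT freezes the
    # model's variables into constants, so those collections would refer to
    # nodes that are no longer present in the converted graph.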
    for key in base_model.collection_def:
      if key not in ['variables', 'local_variables', 'model_variables',
                     'trainable_variables', 'train_op', 'table_initializer']:
        metagraph.collection_def[key].CopyFrom(base_model.collection_def[key])
        
    metagraph.meta_info_def.CopyFrom(base_model.meta_info_def)
    for key in base_model.signature_def:
      metagraph.signature_def[key].CopyFrom(base_model.signature_def[key])
      
  saved_model_builder = (
      tf.saved_model.builder.SavedModelBuilder(output_model_dir))

  # Write SavedModel with INT8 precision.
  with tf.Graph().as_default():
    tf.import_graph_def(
        trt_int8_calibrated_graph, return_elements=[output_tensor], name='')
    with tf.Session(config=config) as sess:
      saved_model_builder.add_meta_graph_and_variables(
          sess, ('serve',), signature_def_map=metagraph.signature_def)

  # Ignore other meta graphs from the input SavedModel.
  saved_model_builder.save()
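
A hypothetical invocation is sketched below; every path and tensor name is an illustrative placeholder, not a value taken from the original script.

convert_int8(
    input_model_dir='models/resnet/original',   # hypothetical path
    output_model_dir='models/resnet/trt_int8',  # hypothetical path
    batch_size=8,
    precision_mode='INT8',
    calib_image_dir='data/calibration',         # hypothetical path
    input_tensor='input_tensor:0',              # hypothetical tensor name
    output_tensor='softmax_tensor:0',           # hypothetical tensor name
    epochs=100)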