def main()

in machine_learning/ml_infrastructure/inference-server-performance/server/scripts/tensorrt-optimization.py [0:0]


def main(argv=None):
  """Parse command-line options and convert a SavedModel with TF-TRT.

  Validates --precision-mode, deletes any existing directory at
  <output-dir>/<precision-mode>/00001, then dispatches to
  convert_fp32_or_fp16 (FP32 / FP16) or convert_int8 (INT8).

  Args:
    argv: Optional list of command-line arguments, excluding the program
      name. When None (the default), argparse reads sys.argv[1:], so
      running the script from the shell behaves exactly as before.

  Raises:
    ValueError: If --precision-mode is not exactly one of 'FP32', 'FP16'
      or 'INT8' (the check is case-sensitive).
  """
  parser = argparse.ArgumentParser(description='A TensorRT example')

  parser.add_argument(
    '--input-model-dir',
    default='models/resnet/original/00001',
    help='input directory of original model'
  )
  parser.add_argument(
    '--output-dir',
    default='models/resnet/',
    help='output directory for converted models'
  )
  parser.add_argument(
    '--input-tensor',
    default='input_tensor:0',
    help='a name of TF input ops used in specified SavedModel file'
  )
  parser.add_argument(
    '--output-tensor',
    default='softmax_tensor:0',
    help='a name of TF output ops used in specified SavedModel file'
  )
  parser.add_argument(
    '--precision-mode',
    default='FP32',
    help='target precision for TF-TRT conversion (default: FP32)'
  )
  parser.add_argument(
    '--calib-image-dir',
    default='gs://path-to-imagenet-dataset',
    help='path to image dataset used for calibration for an INT8 model.')
  parser.add_argument(
    '--batch-size',
    type=int,
    default=64,
    help='batch size for output model.')
  parser.add_argument(
    '--calibration-epochs',
    type=int,
    default=10,
    help='number of epochs for INT8 calibration')

  # argv=None makes argparse fall back to sys.argv[1:].
  args = parser.parse_args(argv)

  # This program only supports FP32, FP16 and INT8 as a precision-mode.
  if args.precision_mode not in ('FP32', 'FP16', 'INT8'):
    # Must reference args.precision_mode; a bare `precision_mode` is not
    # defined in this scope and would raise NameError instead.
    raise ValueError(
        '{} is not a valid precision-mode.'.format(args.precision_mode))

  output_model_dir = os.path.join(
      args.output_dir, args.precision_mode, '00001')

  # Start from a clean target: any previous export at this path is removed.
  if os.path.exists(output_model_dir):
    shutil.rmtree(output_model_dir)

  if args.precision_mode in ('FP32', 'FP16'):
    convert_fp32_or_fp16(
        input_model_dir=args.input_model_dir,
        output_model_dir=output_model_dir,
        batch_size=args.batch_size,
        precision_mode=args.precision_mode)
  else:
    # INT8 additionally needs calibration images, tensor names and epochs.
    convert_int8(
        input_model_dir=args.input_model_dir,
        output_model_dir=output_model_dir,
        batch_size=args.batch_size,
        precision_mode=args.precision_mode,
        calib_image_dir=args.calib_image_dir,
        input_tensor=args.input_tensor,
        output_tensor=args.output_tensor,
        epochs=args.calibration_epochs)