in machine_learning/ml_infrastructure/inference-server-performance/server/scripts/tensorrt-optimization.py [0:0]
def main(argv=None):
    """Parse command-line options and run the matching TF-TRT conversion.

    The converted model is written to
    ``<output-dir>/<precision-mode>/00001``. If that directory already
    exists it is deleted before the conversion starts.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None``, argparse falls back to ``sys.argv[1:]``.

    Raises:
        ValueError: If ``--precision-mode`` is not one of ``FP32``,
            ``FP16`` or ``INT8``.
    """
    parser = argparse.ArgumentParser(description='A TensorRT example')
    parser.add_argument(
        '--input-model-dir',
        default='models/resnet/original/00001',
        help='input directory of original model'
    )
    parser.add_argument(
        '--output-dir',
        default='models/resnet/',
        help='output directory for converted models'
    )
    parser.add_argument(
        '--input-tensor',
        default='input_tensor:0',
        help='a name of TF input ops used in specified SavedModel file'
    )
    parser.add_argument(
        '--output-tensor',
        default='softmax_tensor:0',
        help='a name of TF output ops used in specified SavedModel file'
    )
    parser.add_argument(
        '--precision-mode',
        default='FP32',
        help='target precision for TF-TRT conversion (default: FP32)'
    )
    parser.add_argument(
        '--calib-image-dir',
        default='gs://path-to-imagenet-dataset',
        help='path to image dataset used for calibration for an INT8 model.')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=64,
        help='batch size for output model.')
    parser.add_argument(
        '--calibration-epochs',
        type=int,
        default=10,
        help='number of epochs for INT8 calibration')
    args = parser.parse_args(argv)

    # This program only supports FP32, FP16 and INT8 as a precision-mode.
    # Validate before touching the filesystem so a typo never deletes an
    # existing output directory.
    if args.precision_mode not in ['FP32', 'FP16', 'INT8']:
        raise ValueError(
            '{} is not a valid precision-mode.'.format(args.precision_mode))

    output_model_dir = os.path.join(
        args.output_dir, args.precision_mode, '00001')

    # Start from a clean directory: drop any previous conversion result.
    if os.path.exists(output_model_dir):
        shutil.rmtree(output_model_dir)

    if args.precision_mode in ['FP32', 'FP16']:
        convert_fp32_or_fp16(
            input_model_dir=args.input_model_dir,
            output_model_dir=output_model_dir,
            batch_size=args.batch_size,
            precision_mode=args.precision_mode)
    else:
        # Only INT8 reaches this branch; it is the only mode that is passed
        # the calibration image directory, tensor names and epoch count.
        convert_int8(
            input_model_dir=args.input_model_dir,
            output_model_dir=output_model_dir,
            batch_size=args.batch_size,
            precision_mode=args.precision_mode,
            calib_image_dir=args.calib_image_dir,
            input_tensor=args.input_tensor,
            output_tensor=args.output_tensor,
            epochs=args.calibration_epochs)