in mobile_cv/model_zoo/tools/create_model.py:
# Context assumed from the rest of create_model.py: the `model_utils` and
# `qu` helper modules and the USE_GRAPH_MODE_QUANT flag are defined at
# module level; only the standard imports are shown here.
import os

import torch


def convert_int8_jit(args, model, data, folder_name="int8_jit"):
    """Quantize `model` to int8 and save it as a traced TorchScript model.

    Returns (traced_model, traced_output, output_dir), or (None, None, None)
    if int8 conversion is disabled or fails.
    """
    if not args.convert_int8:
        return None, None, None
    try:
        print("Converting to int8 jit...")
        if args.int8_backend is not None:
            # e.g. "fbgemm" (x86) or "qnnpack" (ARM)
            torch.backends.quantized.engine = args.int8_backend
        if not USE_GRAPH_MODE_QUANT:
            # Eager-mode path: quantize and trace in a single helper call.
            traced_model, traced_output = model_utils.convert_int8_jit(
                model, data, int8_backend=args.int8_backend, add_quant_stub=False
            )
        else:
            # FX graph-mode path: prepare -> calibrate -> convert, then trace
            # the quantized model to TorchScript (a sketch of the underlying
            # torch.ao flow follows this function).
            quant = qu.PostQuantizationFX(model)
            quant_model = (
                quant.set_quant_backend("default")
                .prepare()
                .calibrate_model([data], 1)
                .convert_model()
            )
            traced_model, traced_output = model_utils.convert_torch_script(
                quant_model,
                data,
                fuse_bn=args.fuse_bn,
                use_get_traceable=bool(args.use_get_traceable),
            )
        print(traced_model)
        output_dir = os.path.join(args.output_dir, folder_name)
        model_utils.save_model(output_dir, traced_model, data)
        return traced_model, traced_output, output_dir
    except Exception as e:
        # Conversion is best-effort; report the failure and continue.
        print(f"Converting to int8 jit failed: {e}")
        return None, None, None
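
A rough sketch of what the FX graph-mode path above does under the hood, using
core torch.ao.quantization APIs (assumes torch >= 1.13; `post_quantize_fx` is
a hypothetical helper, not part of mobile_cv -- qu.PostQuantizationFX wraps an
equivalent prepare -> calibrate -> convert flow):

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


def post_quantize_fx(model, calibration_batches, backend="fbgemm"):
    """Post-training static int8 quantization via FX graph mode."""
    torch.backends.quantized.engine = backend
    model = model.eval()
    example_inputs = tuple(calibration_batches[0])
    prepared = prepare_fx(
        model, get_default_qconfig_mapping(backend), example_inputs
    )
    with torch.no_grad():
        # Let the inserted observers record activation ranges.
        for batch in calibration_batches:
            prepared(*batch)
    return convert_fx(prepared)


# Example: quantize a toy conv net, then trace it to TorchScript roughly as
# model_utils.convert_torch_script does for the quantized model.
net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.ReLU()
)
batches = [(torch.randn(1, 3, 32, 32),) for _ in range(4)]
scripted = torch.jit.trace(post_quantize_fx(net, batches), batches[0])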
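
A hypothetical call site for convert_int8_jit (not from the repo): the
function only reads a few attributes from `args`, so an argparse-style
SimpleNamespace stands in; the model is a placeholder, and `data` is assumed
to be a list of input tensors for one batch (it is wrapped as `[data]` for
calibration above):

from types import SimpleNamespace

import torch

args = SimpleNamespace(
    convert_int8=True,
    int8_backend="fbgemm",  # "qnnpack" for ARM targets
    fuse_bn=True,
    use_get_traceable=0,
    output_dir="/tmp/create_model_out",
)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.ReLU()
).eval()
data = [torch.randn(1, 3, 32, 32)]

traced_model, traced_output, out_dir = convert_int8_jit(args, model, data)
if traced_model is not None:
    print(f"int8 TorchScript model saved under {out_dir}")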