in easycv/apis/export.py [0:0]
def _export_yolox(model, cfg, filename):
""" export cls (cls & metric learning)model and preprocess config
Args:
model (nn.Module): model to be exported
cfg: Config object
filename (str): filename to save exported models
"""
if hasattr(cfg, 'export'):
export_type = getattr(cfg.export, 'export_type', 'raw')
default_export_type_list = ['raw', 'jit', 'blade', 'onnx']
if export_type not in default_export_type_list:
logging.warning(
'YOLOX-PAI only supports the export type as [raw,jit,blade,onnx], otherwise we use raw as default'
)
export_type = 'raw'
model.export_type = export_type
if export_type != 'raw':
from easycv.utils.misc import reparameterize_models
# only when we use jit or blade, we need to reparameterize_models before export
model = reparameterize_models(model)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = copy.deepcopy(model)
preprocess_jit = cfg.export.get('preprocess_jit', False)
batch_size = cfg.export.get('batch_size', 1)
static_opt = cfg.export.get('static_opt', True)
use_trt_efficientnms = cfg.export.get('use_trt_efficientnms',
False)
# assert image scale and assgin input
img_scale = cfg.get('img_scale', (640, 640))
assert (
len(img_scale) == 2
), 'Export YoloX predictor config contains img_scale must be (int, int) tuple!'
input = 255 * torch.rand((batch_size, 3) + tuple(img_scale))
# assert use_trt_efficientnms only happens when static_opt=True
if static_opt is not True:
assert (
use_trt_efficientnms == False
), 'Export YoloX predictor use_trt_efficientnms=True only when use static_opt=True!'
# allow to save a preprocess jit model with exported model
save_preprocess_jit = False
if preprocess_jit:
save_preprocess_jit = True
# set model use_trt_efficientnms
if use_trt_efficientnms:
from easycv.toolkit.blade import create_tensorrt_efficientnms
if hasattr(model, 'get_nmsboxes_num'):
nmsbox_num = int(model.get_nmsboxes_num(img_scale))
else:
logging.warning(
'PAI-YOLOX: use_trt_efficientnms encounter model has no attr named get_nmsboxes_num, use 8400 (80*80+40*40+20*20)cas default!'
)
nmsbox_num = 8400
tmp_example_scores = torch.randn(
[batch_size, nmsbox_num, 4 + 1 + len(cfg.CLASSES)],
dtype=torch.float32)
logging.warning(
'PAI-YOLOX: use_trt_efficientnms with staic shape [{}, {}, {}]'
.format(batch_size, nmsbox_num, 4 + 1 + len(cfg.CLASSES)))
model.trt_efficientnms = create_tensorrt_efficientnms(
tmp_example_scores,
iou_thres=model.nms_thre,
score_thres=model.test_conf)
model.use_trt_efficientnms = True
model.eval()
model.to(device)
model_export = ModelExportWrapper(
model,
input.to(device),
trace_model=True,
)
model_export.eval().to(device)
# trace model
yolox_trace = torch.jit.trace(model_export, input.to(device))
# save export model
if export_type == 'blade':
blade_config = cfg.export.get(
'blade_config',
dict(enable_fp16=True, fp16_fallback_op_ratio=0.3))
from easycv.toolkit.blade import blade_env_assert, blade_optimize
assert blade_env_assert()
# optimize model with blade
yolox_blade = blade_optimize(
speed_test_model=model,
model=yolox_trace,
inputs=(input.to(device), ),
blade_config=blade_config,
static_opt=static_opt)
with io.open(filename + '.blade', 'wb') as ofile:
torch.jit.save(yolox_blade, ofile)
with io.open(filename + '.blade.config.json', 'w') as ofile:
config = dict(
model=cfg.model,
export=cfg.export,
test_pipeline=cfg.test_pipeline,
classes=cfg.CLASSES)
json.dump(config, ofile)
if export_type == 'onnx':
with io.open(
filename + '.config.json' if filename.endswith('onnx')
else filename + '.onnx.config.json', 'w') as ofile:
config = dict(
model=cfg.model,
export=cfg.export,
test_pipeline=cfg.test_pipeline,
classes=cfg.CLASSES)
json.dump(config, ofile)
torch.onnx.export(
model,
input.to(device),
filename if filename.endswith('onnx') else filename +
'.onnx',
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
)
if export_type == 'jit':
with io.open(filename + '.jit', 'wb') as ofile:
torch.jit.save(yolox_trace, ofile)
with io.open(filename + '.jit.config.json', 'w') as ofile:
config = dict(
model=cfg.model,
export=cfg.export,
test_pipeline=cfg.test_pipeline,
classes=cfg.CLASSES)
json.dump(config, ofile)
# save export preprocess/postprocess
if save_preprocess_jit:
tpre_input = 255 * torch.rand((batch_size, ) + img_scale +
(3, ))
tpre = ProcessExportWrapper(
example_inputs=tpre_input.to(device),
process_fn=PreProcess(
target_size=img_scale, keep_ratio=True))
tpre.eval().to(device)
preprocess = torch.jit.script(tpre)
with io.open(filename + '.preprocess', 'wb') as prefile:
torch.jit.save(preprocess, prefile)
else:
if hasattr(cfg, 'test_pipeline'):
# with last pipeline Collect
test_pipeline = cfg.test_pipeline
print(test_pipeline)
else:
print('test_pipeline not found, using default preprocessing!')
raise ValueError('export model config without test_pipeline')
config = dict(
model=cfg.model,
test_pipeline=test_pipeline,
CLASSES=cfg.CLASSES,
)
meta = dict(config=json.dumps(config))
checkpoint = dict(
state_dict=model.state_dict(), meta=meta, author='EasyCV')
with io.open(filename, 'wb') as ofile:
torch.save(checkpoint, ofile)