in easycv/apis/export.py [0:0]
def _export_cls(model, cfg, filename):
    """Export a cls (cls & metric learning) model and its preprocess config.

    Builds a serializable inference config (model, test pipeline and class
    list), keeps only the relevant weights of ``model`` and writes either a
    raw checkpoint to ``filename`` or delegates to ``_export_onnx_cls``.

    Args:
        model (nn.Module): model to be exported
        cfg: Config object. The optional ``cfg.export`` section is read for
            ``export_type`` ('raw' by default, or 'onnx') and ``export_neck``.
        filename (str): filename to save exported models

    Raises:
        ValueError: if ``export_type`` is neither 'raw' nor 'onnx'.

    Note:
        When ``cfg.test_pipeline`` exists, the ``keys`` of its 'Collect'
        steps are overwritten in place, i.e. ``cfg`` itself is modified.
    """
    if hasattr(cfg, 'export'):
        export_cfg = cfg.export
    else:
        # No export section at all: export the backbone only.
        export_cfg = dict(export_neck=False)

    export_type = export_cfg.get('export_type', 'raw')
    # An export section that omits `export_neck` keeps neck and head.
    export_neck = export_cfg.get('export_neck', True)

    # Class names come from the label map file first, then from the config.
    label_map_path = cfg.get('label_map_path', None)
    class_list = None
    if label_map_path is not None:
        # Read inside a context manager so the handle is always closed.
        # Lines are kept exactly as read (including trailing newlines).
        with io.open(label_map_path) as label_file:
            class_list = label_file.readlines()
    elif hasattr(cfg, 'class_list'):
        class_list = cfg.class_list
    elif hasattr(cfg, 'CLASSES'):
        class_list = cfg.CLASSES

    model_config = dict(
        type='Classification',
        backbone=replace_syncbn(cfg.model.backbone),
    )

    # avoid load pretrained model
    model_config['pretrained'] = False

    if export_neck:
        if hasattr(cfg.model, 'neck'):
            model_config['neck'] = cfg.model.neck
        if hasattr(cfg.model, 'head'):
            model_config['head'] = cfg.model.head
    else:
        # Only the config of the dummy head is exported; no `head.*` weights
        # are stored because the state-dict filter below requires export_neck.
        print("this cls model doesn't contain cls head, we add a dummy head!")
        model_config['head'] = dict(
            type='ClsHead',
            with_avg_pool=True,
            in_channels=model_config['backbone'].get('num_classes', 2048),
            num_classes=1000,
        )

    img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if hasattr(cfg, 'test_pipeline'):
        test_pipeline = cfg.test_pipeline
        for pipe in test_pipeline:
            # Inference only feeds the image, so collect nothing else.
            if pipe['type'] == 'Collect':
                pipe['keys'] = ['img']
    else:
        # Fallback preprocessing when the config defines no test pipeline.
        test_pipeline = [
            dict(type='Resize', size=[224, 224]),
            dict(type='ToTensor'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Collect', keys=['img'])
        ]

    config = dict(
        model=model_config,
        test_pipeline=test_pipeline,
        class_list=class_list,
    )
    # The config travels with the checkpoint as a JSON string.
    meta = dict(config=json.dumps(config))

    # Keep backbone weights always; neck/head weights only on request.
    state_dict = OrderedDict()
    for k, v in model.state_dict().items():
        if k.startswith('backbone') or (export_neck and
                                        k.startswith(('neck', 'head'))):
            state_dict[k] = v

    if export_type == 'raw':
        checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
        with io.open(filename, 'wb') as ofile:
            torch.save(checkpoint, ofile)
    elif export_type == 'onnx':
        _export_onnx_cls(model, model_config, cfg, filename, config)
    else:
        raise ValueError('Only support export onnx/raw model!')