in scripts/model_convert.py [0:0]
def remap_gpvit_key(key):
    """Map one mmdet GPViT-RetinaNet checkpoint key to its YOLO-config key.

    Layout of the target YOLO model (per ``models/cfg/gpvit_l2.yaml``):
      * ``model.0``  — the SPM stem (also holds ``level_embed``),
      * ``model.1``  — the GPViT backbone proper,
      * ``model.5 / 7 / 11`` — FPN lateral convs for levels 2 / 1 / 0,
      * ``model.6 / 10 / 14`` — FPN output convs for levels 2 / 1 / 0.

    Args:
        key: a flat ``state_dict`` key from the mmdet checkpoint.

    Returns:
        The remapped key, or ``None`` when the weight has no counterpart in
        the YOLO model (detection heads, pyramid levels above index 2, ...).

    Raises:
        ValueError: for a ``neck.`` key that is neither a lateral nor an
            fpn conv — i.e. an unexpected checkpoint layout.
    """
    if key.startswith('backbone.'):
        if 'backbone.level_embed' in key:
            return 'model.0.level_embed'
        if 'backbone.spm' in key:
            return key.replace('backbone.spm.', 'model.0.')
        return key.replace('backbone.', 'model.1.')
    if key.startswith('neck.'):
        # Keys look like 'neck.<conv group>.<level index>.<param path>'.
        li = int(key.split('.')[2])
        if li > 2:
            # The YOLO config only keeps the first three pyramid levels.
            return None
        if 'lateral_convs' in key:
            if li == 2:
                return key.replace(f'neck.lateral_convs.{li}.', 'model.5.')
            return key.replace(f'neck.lateral_convs.{li}.', f'model.{11 - li * 4}.')
        if 'fpn_convs' in key:
            return key.replace(f'neck.fpn_convs.{li}.', f'model.{14 - li * 4}.')
        raise ValueError(key)
    # Anything else (heads, etc.) is dropped.
    return None


def convert_gpvit(*, verify=False):
    """Convert a GPViT-L2 RetinaNet mmdet checkpoint to a YOLO checkpoint.

    Builds the reference mmdet detector and the target YOLO ``Model``,
    remaps the backbone/neck weights via :func:`remap_gpvit_key`, fills the
    remaining (YOLO-only) parameters from the freshly built model, and saves
    the result to ``weights/gpvit_l2.pt``.

    Args:
        verify: when ``True``, additionally run both models on a random
            input and assert their first three feature maps (and any raised
            error messages) match exactly.  Default ``False`` keeps the
            original behavior, where this check was disabled.

    Side effects:
        Reads the config/checkpoint paths hard-coded below, writes
        ``weights/gpvit_l2.pt``, and requires a CUDA device.
    """
    from mmcv import Config
    from mmcv.runner import load_checkpoint
    from mmdet.models import build_detector

    device = torch.device('cuda:0')
    config_path = 'models/gpvit/configs/gpvit/retinanet/gpvit_l2_retinanet_1x.py'
    ckpt_path = 'weights/gpvit_l2_retinanet_1x.pth'

    cfg = Config.fromfile(config_path)
    cfg.model.pretrained = None
    # Drop any recursive-FPN pretrained paths so build_detector does not try
    # to fetch/load them (standard mmdet boilerplate for test-time builds).
    if cfg.model.get('neck'):
        if isinstance(cfg.model.neck, list):
            for neck_cfg in cfg.model.neck:
                if neck_cfg.get('rfp_backbone') and neck_cfg.rfp_backbone.get('pretrained'):
                    neck_cfg.rfp_backbone.pretrained = None
        elif (cfg.model.neck.get('rfp_backbone')
              and cfg.model.neck.rfp_backbone.get('pretrained')):
            cfg.model.neck.rfp_backbone.pretrained = None
    cfg.gpu_ids = ['0']
    cfg.model.train_cfg = None

    # Reference model: the original mmdet detector with its trained weights.
    model_mmcv = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')).to(device)
    load_checkpoint(model_mmcv, ckpt_path, map_location='cpu')
    model_mmcv.eval()

    # Load on CPU (consistent with load_checkpoint above) to avoid a second
    # full GPU copy of the checkpoint.
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']

    # Target model: the YOLO-style wrapper defined by the project config.
    model_yolo = Model('models/cfg/gpvit_l2.yaml', ch=3, nc=10).to(device)  # create
    model_yolo.eval()
    model_dict = model_yolo.state_dict()

    updated_state_dict = {}
    for k, v in state_dict.items():
        new_key = remap_gpvit_key(k)
        if new_key is not None:
            updated_state_dict[new_key] = v
    # Keep every YOLO-only parameter (e.g. the detection head) at its
    # freshly initialized value.
    updated_state_dict.update(
        {k: v for k, v in model_dict.items() if k not in updated_state_dict})
    model_yolo.load_state_dict(updated_state_dict, strict=True)
    torch.save({'model': model_yolo}, 'weights/gpvit_l2.pt')

    if not verify:
        return

    # Sanity check: both pipelines must produce identical feature maps on
    # the same random input (or fail with the identical error message).
    with torch.no_grad():
        img = torch.rand((1, 3, 640, 640), device=device)
        feats_mmcv = model_mmcv.backbone(img)
        e1, e2 = None, None
        try:
            feats_mmcv = model_mmcv.neck(feats_mmcv)[:3]
        except Exception as e:
            e1 = str(e)
        feats_yolo = None
        try:
            feats_yolo = model_yolo(img)[0]
        except Exception as e:
            e2 = str(e)
    if e1 != e2:
        raise AssertionError(f'\n{e1}\n{e2}')
    for i, (f1, f2) in enumerate(zip(feats_mmcv, feats_yolo)):
        # Exact equality is intentional: the copied weights and identical
        # ops should be bit-for-bit reproducible on the same device.
        if not (f1 == f2).all():
            raise AssertionError(f'\nlayer{i}\n{f1}\n{f2}')
    print('Done.')