# convert_gpvit()
#
# Extracted from scripts/model_convert.py [0:0]

def convert_gpvit():
    """Convert a GPViT-L2 RetinaNet checkpoint (mmdet format) to a YOLO-style
    checkpoint saved at ``weights/gpvit_l2.pt``.

    Steps performed:
      1. Build the reference mmdet detector from its config and load the
         original checkpoint into it (used here as a ground-truth reference).
      2. Build the target YOLO ``Model`` from ``models/cfg/gpvit_l2.yaml``.
      3. Remap backbone/neck parameter names from the mmdet state-dict layout
         to the YOLO layout and load them with ``strict=True``.
      4. Save the converted model and terminate the process via ``exit()``.

    NOTE(review): all paths and the layer-index arithmetic are hard-coded for
    the ``gpvit_l2_retinanet_1x`` checkpoint.  The many commented-out blocks
    are debugging probes left in place.  Everything after the ``exit()`` call
    below is currently dead code — it compares mmdet vs. YOLO forward outputs
    and is presumably re-enabled (by removing ``exit()``) when verifying a
    fresh conversion.

    Requires: ``torch`` and ``Model`` in the enclosing module scope, plus the
    mmcv/mmdet packages imported locally below, the config/weight files on
    disk, and a CUDA device.
    """
    # Local imports keep the heavy mmcv/mmdet dependencies out of module
    # import time; they are only needed when this conversion actually runs.
    from mmcv import Config, DictAction
    from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                            wrap_fp16_model)
    from mmdet.models import build_detector

    device = torch.device('cuda:0')

    # Hard-coded source config/checkpoint for the GPViT-L2 RetinaNet 1x model.
    config_path = 'models/gpvit/configs/gpvit/retinanet/gpvit_l2_retinanet_1x.py'
    ckpt_path = 'weights/gpvit_l2_retinanet_1x.pth'
    cfg = Config.fromfile(config_path)
    # Strip any pretrained-weights references so build_detector does not try
    # to download/load weights before we load the checkpoint explicitly.
    cfg.model.pretrained = None
    if cfg.model.get('neck'):
        if isinstance(cfg.model.neck, list):
            for neck_cfg in cfg.model.neck:
                if neck_cfg.get('rfp_backbone'):
                    if neck_cfg.rfp_backbone.get('pretrained'):
                        neck_cfg.rfp_backbone.pretrained = None
        elif cfg.model.neck.get('rfp_backbone'):
            if cfg.model.neck.rfp_backbone.get('pretrained'):
                cfg.model.neck.rfp_backbone.pretrained = None
    cfg.gpu_ids = ['0']
    cfg.model.train_cfg = None  # build in inference-only mode
    # Reference model: the original mmdet detector with the original weights.
    model_mmcv = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')).to(device)
    load_checkpoint(model_mmcv, ckpt_path, map_location='cpu')
    model_mmcv.eval()
    # model_mmcv = convert_sync_batchnorm_to_batchnorm(model)

    # from models.gpvit import GPViTAdapterSingleStage
    # backbone = GPViTAdapterSingleStage().to(device)
    # backbone.eval()
    # # model_dict = backbone.state_dict()
    # # for i, (k, v) in enumerate(model_dict.items()):
    # #     print(i, k, v.shape)
    # Raw state dict loaded a second time so we can remap keys by name
    # (load_checkpoint above consumed it into the reference model).
    state_dict = torch.load(ckpt_path)['state_dict']
    # for i, (k, v) in enumerate(state_dict.items()):
    #     print(i, k, v.shape)
    # backbone.load_state_dict(
    #     {k.replace('backbone.', ''): v for k, v in state_dict.items() if 'backbone.' in k}
    #     , strict=False)
    # exit()
    
    # print((backbone.ad_norm2.weight == model_mmcv.backbone.ad_norm2.weight).all())
    # Dummy input, used by the (currently dead) verification code below.
    img = torch.rand((1, 3, 640, 640), device=device)
    feats_mmcv = model_mmcv.backbone(img)
    # feats_yolo = backbone(img)

    # Target model built from the YOLO-style YAML definition.
    model_yolo = Model('models/cfg/gpvit_l2.yaml', ch=3, nc=10).to(device)  # create
    # model_yolo = convert_sync_batchnorm_to_batchnorm(model_yolo)
    model_yolo.eval()
    model_dict = model_yolo.state_dict()
    # for i, (k, v) in enumerate(model_dict.items()):
    #     print(i, k, v.shape)
    # exit()
    # Remap mmdet parameter names to the YOLO module-index layout.
    # Mapping assumed from gpvit_l2.yaml (TODO confirm against that file):
    #   backbone.level_embed          -> model.0.level_embed
    #   backbone.spm.*                -> model.0.*        (patch/stem module)
    #   backbone.* (rest)             -> model.1.*
    #   neck.lateral_convs.{0,1,2}.*  -> model.{11,7,5}.*
    #   neck.fpn_convs.{0,1,2}.*      -> model.{14,10,6}.*
    # Neck levels above index 2 and all non-backbone/neck keys (e.g. the
    # RetinaNet head) are dropped.
    updated_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            if 'backbone.level_embed' in k:
                k = 'model.0.level_embed'
            elif 'backbone.spm' in  k:
                k = k.replace('backbone.spm.', 'model.0.')
            else:
                k = k.replace('backbone.', 'model.1.')  # 1 or 6
        elif k.startswith('neck.'):
            # li = feature-pyramid level index parsed from e.g.
            # 'neck.lateral_convs.<li>.conv.weight'
            li = int(k.split('.')[2])
            if li > 2:
                continue  # YOLO target only consumes the first three levels
            if 'lateral_convs' in k:
                if li == 2:
                    k = k.replace(f'neck.lateral_convs.{li}.', 'model.5.')
                else:
                    k = k.replace(f'neck.lateral_convs.{li}.', f'model.{11-li*4}.')
            elif 'fpn_convs' in k:
                k = k.replace(f'neck.fpn_convs.{li}.', f'model.{14-li*4}.')
            else:
                raise ValueError(k)  # unexpected neck parameter — fail loudly
        else:
            continue  # skip head/other keys not present in the YOLO model
        updated_state_dict[k] = v
    # for i, (k, v) in enumerate(updated_state_dict.items()):
    #     print(i, k, v.shape)
    # Fill in any YOLO parameters not covered by the mapping (e.g. the new
    # detection head) with the freshly-initialized values so strict=True holds.
    updated_state_dict.update({k: v for k, v in model_dict.items() if k not in updated_state_dict})

    model_yolo.load_state_dict(updated_state_dict, strict=True)
    # print(model_yolo.model[5].bn.weight)
    # print(model_mmcv.neck.lateral_convs[2].bn.weight)
    # exit()

    # Persist the whole converted model object, then kill the process.
    torch.save({'model': model_yolo}, 'weights/gpvit_l2.pt')
    exit()


    # ---- DEAD CODE from here down (unreachable past exit()) ----
    # Verification pass: run both necks/models on the same input and require
    # that they either raise the identical error or produce identical features.
    e1, e2 = None, None
    try:
        feats_mmcv = model_mmcv.neck(feats_mmcv)[:3]
    except Exception as e:
        e1 = str(e)

    feats_yolo = None
    try:
        feats_yolo = model_yolo(img)[0]
    except Exception as e:
        e2 = str(e)
    
    # Both paths must fail identically (or both succeed: e1 == e2 == None).
    assert e1 == e2, f'\n{e1}\n{e2}'

    # print(type(feats_mmcv), type(feats_yolo))
    # Exact (bitwise) feature equality is demanded per pyramid level.
    for i, (f1, f2) in enumerate(zip(feats_mmcv, feats_yolo)):
        # print(f1.shape, f2.shape)
        assert (f1 == f2).all(), f'\nlayer{i}\n{f1}\n{f2}'

    print('Done.')