# convert_rtmdet() — defined in scripts/model_convert.py

def convert_rtmdet():
    """Convert an mmyolo RTMDet-m checkpoint into a YOLO-style ``Model``.

    Loads the official RTMDet-m config and weights via mmdet's
    ``init_detector``, renames every backbone/neck parameter key to the
    matching key of a ``Model('models/cfg/rtmdet_m.yaml')`` instance, and
    saves the converted model to ``weights/rtmdet_m.pt``.

    Everything after the first ``return`` is debugging scaffolding
    (feature-map equivalence checks between the two models); it is
    intentionally unreachable and was toggled on during development by
    moving the ``return``.

    Side effects: reads the config/checkpoint under ``../mmyolo`` from disk,
    requires a CUDA device, and writes ``weights/rtmdet_m.pt``.
    """
    # NOTE(review): several of these imports look unused, but importing
    # mmyolo modules registers its models/transforms into the mmengine
    # registries, which init_detector needs to build an mmyolo config —
    # do not remove them.
    import mmcv
    from mmdet.apis import inference_detector, init_detector
    from mmengine.config import Config, ConfigDict
    from mmengine.logging import print_log
    from mmengine.utils import ProgressBar, path

    from mmyolo.registry import VISUALIZERS
    from mmyolo.utils import switch_to_deploy
    from mmyolo.utils.labelme_utils import LabelmeFormat
    from mmyolo.utils.misc import get_file_list, show_data_classes

    rtmdet_root = '../mmyolo'
    cfg = f'{rtmdet_root}/configs/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco.py'
    checkpoint = f'{rtmdet_root}/weights/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952-40af4fe8.pth'
    file = f'{rtmdet_root}/demo/demo.jpg'

    device = torch.device('cuda:0')
    config = Config.fromfile(cfg)
    # Drop any pretrained-weight init config: the checkpoint supplies all
    # weights, so no extra download/initialisation should happen.
    if 'init_cfg' in config.model.backbone:
        config.model.backbone.init_cfg = None

    model = init_detector(config, checkpoint, device=device, cfg_options={})

    # export_onnx(model, f'weights/{osp.basename(checkpoint).split(".")[0]}.onnx')
    # return

    # result = inference_detector(model, file)
    # print(type(result.pred_instances))

    # Target model built from the YOLO-style yaml; assumed to mirror the
    # RTMDet backbone/neck topology so a pure key rename suffices —
    # TODO confirm against models/cfg/rtmdet_m.yaml.
    model_yolo = Model('models/cfg/rtmdet_m.yaml', ch=3, nc=80).to(device)  # create
    model_yolo.eval()

    model_state_dict = model.state_dict()
    yolo_state_dict = model_yolo.state_dict()

    # Rename mmdet keys -> yolo keys.  ``csplayer_cnt`` counts the distinct
    # CSP layers seen so far (identified by their (stage, block) pair in
    # ``csplayer_idx``); the yolo layer index is that running count plus
    # fixed offsets that encode where each backbone/neck stage starts in
    # the yaml layer graph.
    state_dict = {}
    csplayer_cnt, csplayer_idx = 0, ()
    for k, v in model_state_dict.items():
        if 'backbone.stem.' in k:
            k = k.replace('backbone.stem.', 'model.')
        elif 'backbone.stage' in k:
            m, n = list(map(int, k.split('backbone.stage')[1].split('.')[:2]))
            if (m, n) != csplayer_idx:
                csplayer_idx = (m, n)
                csplayer_cnt += 1
            k = k.replace(f'backbone.stage{m}.{n}.', f'model.{2+csplayer_cnt}.')
            if m == 4 and n == 1:
                # The tail block of stage 4 names its convs 'cv*' on the
                # yolo side instead of 'conv*'.
                k = k.replace(f'model.{2+csplayer_cnt}.conv', f'model.{2+csplayer_cnt}.cv')
        elif 'neck.reduce_layers.' in k:
            k = k.replace('neck.reduce_layers.2.', f'model.{2+csplayer_cnt+1}.')
        elif 'neck.top_down_layers.' in k:
            # Keys are either 'top_down_layers.<m>.<n>.' (CSP block) or
            # 'top_down_layers.<m>.' (plain layer); the latter makes the
            # int() conversion / unpack fail, which selects the fallback.
            try:
                m, n = list(map(int, k.split('neck.top_down_layers.')[1].split('.')[:2]))
            except ValueError:
                m = int(k.split('.')[2])
                n = -1
            if (m, n) != csplayer_idx:
                csplayer_idx = (m, n)
                csplayer_cnt += 1
            if n != -1:
                k = k.replace(f'neck.top_down_layers.{m}.{n}.', f'model.{2+3+csplayer_cnt}.')
            else:
                k = k.replace(f'neck.top_down_layers.{m}.', f'model.{2+3+2+csplayer_cnt}.')
        elif 'neck.downsample_layers.' in k:
            li = int(k.split('neck.downsample_layers.')[1].split('.')[0])
            k = k.replace(f'neck.downsample_layers.{li}.', f'model.{2+3+2+1+csplayer_cnt+3*li}.')
        elif 'neck.bottom_up_layers.' in k:
            li = int(k.split('neck.bottom_up_layers.')[1].split('.')[0])
            k = k.replace(f'neck.bottom_up_layers.{li}.', f'model.{2+3+2+1+2+csplayer_cnt+3*li}.')
        else:
            # Head / unmatched parameters are dropped; the yolo defaults
            # fill them in below.
            continue

        state_dict[k] = v

    # Any yolo parameter not produced by the rename (e.g. the head) keeps
    # the freshly initialised value, so strict loading still succeeds.
    for k, v in yolo_state_dict.items():
        if k not in state_dict:
            state_dict[k] = v

    model_yolo.load_state_dict(state_dict, strict=True)
    model_yolo.eval()

    torch.save({'model': model_yolo}, 'weights/rtmdet_m.pt')
    return

    # ------------------------------------------------------------------
    # Unreachable debug scaffolding below: verifies that the converted
    # model reproduces the mmdet neck features.  Enable by moving the
    # ``return`` above past the check you want to run.
    # ------------------------------------------------------------------

    def _forward_neck(self, inputs):
        """Re-implementation of the mmyolo neck forward WITHOUT the out
        convs, so its outputs can be compared against the yolo model's
        intermediate features."""
        assert len(inputs) == len(self.in_channels)

        # top-down path
        inner_outs = [inputs[-1]]
        for idx in range(len(self.in_channels) - 1, 0, -1):
            feat_high = inner_outs[0]
            feat_low = inputs[idx - 1]
            feat_high = self.reduce_layers[idx](feat_high)
            inner_outs[0] = feat_high

            upsample_feat = self.upsample_layers[len(self.in_channels) - 1 - idx](feat_high)

            inner_out = self.top_down_layers[len(self.in_channels) - 1 - idx](
                torch.cat([upsample_feat, feat_low], 1))
            inner_outs.insert(0, inner_out)

        # bottom-up path
        outs = [inner_outs[0]]
        for idx in range(len(self.in_channels) - 1):
            feat_low = outs[-1]
            feat_high = inner_outs[idx + 1]
            downsample_feat = self.downsample_layers[idx](feat_low)
            out = self.bottom_up_layers[idx](
                torch.cat([downsample_feat, feat_high], 1))
            outs.append(out)

        # out convs deliberately skipped for the comparison:
        # for idx, conv in enumerate(self.out_layers):
        #     outs[idx] = conv(outs[idx])

        return tuple(outs)

    img = torch.rand((1, 3, 160, 160)).to(device)
    features_backbone = model.backbone(img)
    # features_rtmdet = model.neck(features_backbone)  # small, medium, large
    # print([x.shape for x in features_rtmdet])
    features_rtmdet2 = _forward_neck(model.neck, features_backbone)
    # print([x.shape for x in features_rtmdet2])

    features_yolo = model_yolo(img)
    # print([x.shape for x in features_yolo])

    # Exact (bitwise) equality is expected because both models run the same
    # weights through the same op sequence on the same device.
    for fi, (f1, f2) in enumerate(zip(features_rtmdet2, features_yolo)):
        assert (f1 == f2).all(), f'{f1 - f2}\n{fi}'
    print('Done.')
    return

    # Even earlier-stage debugging: compare the stem outputs directly.
    e1, e2 = None, None
    f1, f2 = None, None
    try:
        f1 = model.backbone.stem[0](img)
        # f1 = model.backbone.stem[0].conv(img)
        # f1 = model.backbone.stem[0].bn(f1)
    except Exception as e:
        print(e)
        e1 = str(e)

    try:
        f2 = model_yolo(img)
    except Exception as e:
        print(e)
        e2 = str(e)

    # print(e1 == e2)
    print(f1.shape, f2.shape, (f1 == f2).all().item())