# convert_retinanet() — extracted from scripts/model_convert.py

def convert_retinanet(device: str = 'cuda:0', verify: bool = False):
    """Port torchvision's RetinaNet (ResNet50-FPN v2) backbone weights into the
    YOLO-style model defined by 'models/cfg/retina.yaml' and save the result.

    The torchvision state-dict keys are renamed to the YOLO layout:
      - backbone stem (backbone.body.conv1 / bn1)        -> model.0.*
      - ResNet stages (backbone.body.layerN.*)           -> model.{N+1}.m.*
      - FPN 1x1 lateral convs (fpn.inner_blocks.i.0)     -> model.13 / model.9 / model.6
      - FPN 3x3 output convs  (fpn.layer_blocks.i.0)     -> model.15 / model.11 / model.7
    Detection-head weights and the FPN extra P6/P7 blocks are not transferred;
    those parameters keep the YOLO model's own initialisation.

    Args:
        device: torch device string on which to build both models
            (default 'cuda:0', the original script's behaviour).
        verify: when True, push a random image through both backbones and
            assert the three FPN feature maps are bit-identical.

    Side effects: downloads the pretrained torchvision weights if not cached
    and writes the converted checkpoint to 'weights/retinanet.pt'.
    """
    import torchvision.models.detection as detection

    dev = torch.device(device)

    # NOTE: retinanet_resnet50_fpn_v2 requires torchvision >= 0.13, where the
    # `weights=` API supersedes the deprecated `pretrained=` flag.
    model = detection.retinanet_resnet50_fpn_v2(weights='DEFAULT').to(dev)
    model.eval()

    model_yolo = Model('models/cfg/retina.yaml', ch=3, nc=80).to(dev)  # create
    model_yolo.eval()

    state_dict = {}
    for k, v in model.state_dict().items():
        if '.body.' in k:
            if '.layer' not in k:
                # Stem: backbone.body.conv1/bn1 -> model.0.conv/model.0.bn
                # (the '1.' -> '.' rewrite strips the trailing digit).
                k = k.replace('backbone.body.', 'model.0.')
                k = k.replace('1.', '.')
            else:
                # ResNet stage N maps to the YOLO module at index N+1.
                li = int(k.split('.layer')[1].split('.')[0])
                k = k.replace(f'backbone.body.layer{li}.', f'model.{li+1}.m.')
        elif '.fpn.' in k:
            if 'blocks.p' in k:
                # Extra P6/P7 blocks have no counterpart in the YOLO cfg.
                continue
            li = int(k.split('_blocks.')[1].split('.')[0])
            if '.inner_' in k:
                # 1x1 lateral convs: i=0,1,2 -> model.13, model.9, model.5;
                # slot 5 is then relocated to slot 6 in the YOLO graph.
                k = k.replace(f'backbone.fpn.inner_blocks.{li}.0.', f'model.{13-li*4}.')
                k = k.replace('model.5.', 'model.6.')
            elif '.layer_' in k:
                # 3x3 output convs: i=0,1,2 -> model.15, model.11, model.7.
                k = k.replace(f'backbone.fpn.layer_blocks.{li}.0.', f'model.{15-li*4}.')
            else:
                raise ValueError(k)
        else:
            # Detection-head weights are intentionally not transferred.
            continue

        state_dict[k] = v

    # Every key not covered above keeps the YOLO model's own initialisation,
    # so load_state_dict can run with strict=True.
    for k, v in model_yolo.state_dict().items():
        if k not in state_dict:
            state_dict[k] = v

    model_yolo.load_state_dict(state_dict, strict=True)
    model_yolo.eval()

    torch.save({'model': model_yolo}, 'weights/retinanet.pt')

    if verify:
        # Sanity check: both backbones must yield identical FPN feature maps.
        img = torch.rand((1, 3, 320, 320)).to(dev)
        features_retina = model.backbone(img)
        # Reorder torchvision's {'0','1','2',...} dict to large, medium, small
        # so it lines up with the YOLO model's output order.
        features_retina = [features_retina[str(i)] for i in [2, 1, 0]]
        features_yolo = model_yolo(img)
        for fi, (f1, f2) in enumerate(zip(features_retina, features_yolo)):
            assert (f1 == f2).all(), f'{f1 - f2}\n{fi}'