in scripts/model_convert.py [0:0]
def convert_retinanet() -> None:
    """Convert torchvision's pretrained RetinaNet (ResNet50-FPN v2) backbone
    weights into the project's YOLO-style ``Model`` and save the result.

    The conversion works by string-rewriting the torchvision state-dict keys
    into the key layout expected by ``Model('models/cfg/retina.yaml')``, then
    loading with ``strict=True`` and saving to ``weights/retinanet.pt``.

    NOTE(review): everything after the first bare ``return`` below is
    unreachable debug/verification code kept for manual re-enabling
    (feature-map equality check, then a matplotlib visualization pass).
    Requires a CUDA device ('cuda:0') and network access for the pretrained
    download. Relies on module-level names defined elsewhere in this file:
    ``torch``, ``Model``, ``img_org``, ``plt``, ``patches``.
    """
    import torchvision.transforms as T
    import torchvision.models.detection as detection
    device = torch.device('cuda:0')
    # pretrained=True: load the COCO-pretrained reference weights to copy from.
    model = detection.retinanet_resnet50_fpn_v2(pretrained=True).to(device)
    model.eval()
    # export_onnx(model, 'weights/retinanet_resnet50_fpn_v2.onnx')
    model_yolo = Model('models/cfg/retina.yaml', ch=3, nc=80).to(device)  # create
    model_yolo.eval()
    # model_yolo.info()
    # print(model_yolo)
    # return
    model_state_dict = model.state_dict()
    yolo_state_dict = model_yolo.state_dict()
    # for k, v in yolo_state_dict.items():
    #     print(k, v.shape)
    # return
    # Remap torchvision key names -> YOLO Model key names. Keys that have no
    # counterpart in the YOLO layout (detection head, FPN extra 'p' blocks)
    # are skipped and later backfilled from the YOLO model's own weights.
    state_dict = {}
    for k, v in model_state_dict.items():
        # print(k, v.shape)
        # continue
        if '.body.' in k:
            # ResNet trunk ('backbone.body.*').
            if '.layer' not in k:
                # Stem (conv1/bn1): becomes YOLO layer 0. The second replace
                # drops the '1' suffix, e.g. 'conv1.' -> 'conv.', 'bn1.' -> 'bn.'.
                k = k.replace('backbone.body.', 'model.0.')
                k = k.replace('1.', '.')
            else:
                # Residual stage layer{li} -> YOLO layer {li+1} submodule 'm'.
                # Assumes retina.yaml places layer1..layer4 at model.2..model.5
                # — TODO confirm against models/cfg/retina.yaml.
                li = int(k.split('.layer')[1].split('.')[0])
                k = k.replace(f'backbone.body.layer{li}.', f'model.{li+1}.m.')
        elif '.fpn.' in k:
            if 'blocks.p' in k:
                # Extra P6/P7-style blocks: no YOLO counterpart, skip.
                continue
            li = int(k.split('_blocks.')[1].split('.')[0])
            if '.inner_' in k:
                # FPN lateral 1x1 convs. The 13-li*4 arithmetic presumably
                # mirrors the YAML layer ordering (higher FPN level -> earlier
                # YOLO index); the follow-up 5->6 shift patches one offset —
                # verify both against retina.yaml.
                k = k.replace(f'backbone.fpn.inner_blocks.{li}.0.', f'model.{13-li*4}.')
                k = k.replace('model.5.', 'model.6.')
            elif '.layer_' in k:
                # FPN output 3x3 convs, same index scheme as above.
                k = k.replace(f'backbone.fpn.layer_blocks.{li}.0.', f'model.{15-li*4}.')
            else:
                raise ValueError(k)
        else:
            # Anything else (e.g. detection head) is not transferred.
            continue
        state_dict[k] = v
        # print(k, v.shape)
        # continue
    # Backfill keys the remap did not produce with the YOLO model's own
    # (randomly initialized) weights so strict loading succeeds.
    for k, v in yolo_state_dict.items():
        if k not in state_dict:
            state_dict[k] = v
    model_yolo.load_state_dict(state_dict, strict=True)
    model_yolo.eval()
    torch.save({'model': model_yolo}, 'weights/retinanet.pt')
    return
    # ---- unreachable: backbone feature-equivalence sanity check ----
    img = torch.rand((1, 3, 320, 320)).to(device)
    features_retina = model.backbone(img)
    # for k, v in features_retina.items():
    #     print(k, type(k), type(v), v.shape)
    features_retina = [features_retina[str(i)] for i in [2, 1, 0]]  # large, medium, small
    features_yolo = model_yolo(img)
    # print([x.shape for x in features_yolo])
    for fi, (f1, f2) in enumerate(zip(features_retina, features_yolo)):
        # Converted model must reproduce torchvision's FPN outputs exactly.
        assert (f1 == f2).all(), f'{f1 - f2}\n{fi}'
    # e1, e2 = None, None
    # f1, f2 = None, None
    # try:
    #     f1 = model.backbone.body['conv1'](img)
    #     f1 = model.backbone.body['bn1'](f1)
    #     f1 = model.backbone.body['relu'](f1)
    #     f1 = model.backbone.body['maxpool'](f1)
    #     f1 = model.backbone.body['layer1'](f1)
    # except Exception as e:
    #     print(e)
    #     e1 = str(e)
    # try:
    #     f2 = model_yolo(img)
    # except Exception as e:
    #     print(e)
    #     e2 = str(e)
    # # print(e1 == e2)
    # print(f1.shape, f2.shape, (f1 == f2).all().item())
    # # print(f1 - f2)
    # print((model.backbone.body['conv1'].weight == model_yolo.model[0].conv.weight).all().item(), end=" ")
    # print((model.backbone.body['bn1'].weight == model_yolo.model[0].bn.weight).all().item(), end=" ")
    # print((model.backbone.body['bn1'].bias == model_yolo.model[0].bn.bias).all().item(), end=" ")
    # print((model.backbone.body['bn1'].running_mean == model_yolo.model[0].bn.running_mean).all().item(), end=" ")
    # print((model.backbone.body['bn1'].running_var == model_yolo.model[0].bn.running_var).all().item(), end=" ")
    # print((model.backbone.body['bn1'].num_batches_tracked == model_yolo.model[0].bn.num_batches_tracked).all().item(), end="\n")
    # print(model.backbone.body['bn1'].training, model_yolo.model[0].bn.training)
    # print(model.backbone.body['bn1'])
    # print(model_yolo.model[0].bn)
    # print(model.backbone.body['layer1'][0].bn1 == model_yolo.model[2].m[0].bn1)
    # print(model_yolo.model[2].m[0].bn1)
    return
    # ---- unreachable: run the reference model on an image and plot boxes ----
    transform = T.Compose([
        T.ToTensor(),
        # ImageNet normalization constants.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # img_org: presumably a PIL image defined at module level — verify.
    img = transform(img_org.convert("RGB")).unsqueeze(0).to(device)
    predictions = model(img)
    pred_boxes = predictions[0]['boxes'].cpu().detach().numpy()  # bounding boxes
    pred_scores = predictions[0]['scores'].cpu().detach().numpy()  # confidence scores
    pred_labels = predictions[0]['labels'].cpu().detach().numpy()  # labels (usually COCO dataset class indices)
    # print(pred_scores)
    score_threshold = 0.1
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img_org)
    for box, score, label in zip(pred_boxes, pred_scores, pred_labels):
        if score > score_threshold:
            x1, y1, x2, y2 = box
            patch = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(patch)
            plt.text(x1, y1, f"{label}: {score:.2f}", fontsize=12, color='white',
                     bbox=dict(facecolor='red', alpha=0.5))
    plt.axis('off')
    plt.savefig('tmp.png', dpi=400)