in baselines/GeM_baseline.py [0:0]
def load_model(name, checkpoint_file):
    """Build a ResNet-50 model for the GeM baseline and put it in eval mode.

    Args:
        name: Which model to build. Supported values:
            ``"zoo_resnet50"`` -- torchvision's pretrained ResNet-50
            (``checkpoint_file`` is ignored);
            ``"multigrain_resnet50"`` -- a torchvision ResNet-50 built without
            pretrained weights, whose weights are then loaded from
            ``checkpoint_file``.
        checkpoint_file: Anything accepted by ``torch.load`` (e.g. a path).
            Only read for ``"multigrain_resnet50"``.

    Returns:
        The model in eval mode. For ``"multigrain_resnet50"`` the classifier
        attribute ``fc`` is set to ``None``.

    Raises:
        ValueError: if ``name`` is not one of the supported values.
    """
    if name == "zoo_resnet50":
        model = torchvision.models.resnet50(pretrained=True)
    elif name == "multigrain_resnet50":
        model = torchvision.models.resnet50(pretrained=False)
        # NOTE(review): torch.load unpickles arbitrary objects -- only use it
        # with trusted checkpoint files.
        # map_location="cpu" lets checkpoints saved on GPU load on CPU-only
        # hosts; load_state_dict copies the values into the model's own
        # parameters regardless of where they were loaded.
        checkpoint = torch.load(checkpoint_file, map_location="cpu")
        prefix = "features."
        # Keep only entries stored under "features.<param>" and strip the
        # prefix so the keys line up with torchvision's ResNet-50 names.
        state_dict = OrderedDict(
            (key[len(prefix):], value)
            for key, value in checkpoint["model_state"].items()
            if key.startswith(prefix)
        )
        # Remove the classifier so the (strict by default) load_state_dict
        # below does not require fc weights; presumably the "features."
        # entries contain none -- confirm. The stock ResNet forward() calls
        # self.fc, so callers presumably use the backbone layers directly.
        model.fc = None
        model.load_state_dict(state_dict)
    else:
        # Explicit raise instead of `assert False`, which is stripped under -O
        # and would make this function silently return None.
        raise ValueError(f"unknown model name: {name!r}")
    model.eval()
    return model