in cp_examples/mip_finetune/mip_model.py
import torch

# Assumption: DenseNet comes from torchvision; swap in the project's own class if it differs.
from torchvision.models import DenseNet

def load_pretrained_model(arch, pretrained_file):
    """Load the query-encoder weights from a pretrained checkpoint and return the backbone."""
    pretrained_dict = torch.load(pretrained_file)["state_dict"]

    # Keep only the query-encoder weights, stripping the prefix so the keys
    # match a plain DenseNet state dict.
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k.startswith("model.encoder_q."):
            k = k.replace("model.encoder_q.", "")
            state_dict[k] = v

    if arch.startswith("densenet"):
        # Size the classifier to match the checkpoint so load_state_dict succeeds,
        # then read the feature dimension off the classifier weight matrix.
        num_classes = pretrained_dict["model.encoder_q.classifier.weight"].shape[0]
        model = DenseNet(num_classes=num_classes)
        model.load_state_dict(state_dict)
        feature_dim = pretrained_dict["model.encoder_q.classifier.weight"].shape[1]

        # Drop the classification head; the caller attaches a task-specific one.
        del model.classifier
    else:
        raise ValueError(f"Model architecture {arch} is not supported.")

    return model, feature_dim
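
# A minimal usage sketch, not part of the original file: the checkpoint path
# "moco_pretrained.ckpt" and the 14-label head are hypothetical, chosen only to
# show how the returned backbone and feature_dim might be wired to a new
# classification head, assuming the torchvision DenseNet import above.
if __name__ == "__main__":
    # Hypothetical checkpoint path and label count, for illustration only.
    backbone, feature_dim = load_pretrained_model("densenet121", "moco_pretrained.ckpt")

    # Attach a fresh task-specific head in place of the deleted classifier.
    backbone.classifier = torch.nn.Linear(feature_dim, 14)

    # Forward a dummy batch of 3-channel 224x224 images through the fine-tuning model.
    dummy = torch.randn(2, 3, 224, 224)
    print(backbone(dummy).shape)  # expected: torch.Size([2, 14])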