def load_pretrained_model()

in cp_examples/mip_finetune/mip_model.py [0:0]


def load_pretrained_model(arch, pretrained_file):
    """Build a backbone from the query-encoder weights of a pretrained checkpoint.

    Reads ``pretrained_file`` with ``torch.load``, keeps only the entries of its
    ``"state_dict"`` whose keys start with ``"model.encoder_q."`` (with that
    leading prefix removed), loads them into a freshly constructed model, and
    then deletes the model's ``classifier`` attribute.

    Args:
        arch: Architecture name. Only names starting with ``"densenet"`` are
            supported.
        pretrained_file: Anything ``torch.load`` accepts (e.g. a path). The
            loaded object must contain a ``"state_dict"`` mapping.

    Returns:
        A ``(model, feature_dim)`` tuple: ``model`` is the ``DenseNet`` with
        its ``classifier`` attribute deleted, and ``feature_dim`` is
        ``shape[1]`` of the pretrained ``classifier.weight`` tensor.

    Raises:
        ValueError: If ``arch`` is not supported (checked before any file I/O).
        KeyError: If the checkpoint has no
            ``"model.encoder_q.classifier.weight"`` entry.
    """
    if not arch.startswith("densenet"):
        # Fail fast: don't read a potentially large checkpoint for an
        # architecture this function cannot build.
        raise ValueError(f"Model architecture {arch} is not supported.")

    prefix = "model.encoder_q."
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict below copies the values into the CPU-built model anyway.
    # NOTE(security): torch.load unpickles the file -- only load trusted files.
    pretrained_dict = torch.load(pretrained_file, map_location="cpu")["state_dict"]

    # Keep only query-encoder weights, stripping just the *leading* prefix
    # (str.replace would also rewrite any later occurrence inside a key).
    state_dict = {
        k[len(prefix):]: v for k, v in pretrained_dict.items() if k.startswith(prefix)
    }

    classifier_key = "classifier.weight"
    if classifier_key not in state_dict:
        raise KeyError(
            f"Checkpoint {pretrained_file!r} has no '{prefix}{classifier_key}' entry."
        )
    classifier_weight = state_dict[classifier_key]
    # Presumably an nn.Linear weight laid out as (out_features, in_features):
    # shape[0] is used as the class count, shape[1] as the feature dimension.
    num_classes = classifier_weight.shape[0]
    feature_dim = classifier_weight.shape[1]

    model = DenseNet(num_classes=num_classes)
    model.load_state_dict(state_dict)
    # Drop the pretrained head; presumably callers attach their own head sized
    # by feature_dim -- confirm against callers.
    del model.classifier

    return model, feature_dim