def validate_pretrained_model()

in cp_examples/sip_finetune/sip_finetune.py [0:0]


def validate_pretrained_model(state_dict, pretrained_file):
    """Check that fine-tuning left the pretrained encoder weights untouched.

    Loads the checkpoint at ``pretrained_file`` and takes every entry whose key
    starts with ``model.encoder_q.``, with that prefix stripped. Each entry is
    compared against the entry of the same name in ``state_dict``. Keys
    containing ``classifier.weight``, ``classifier.bias``, ``fc.weight`` or
    ``fc.bias`` are skipped; those layers are allowed to change.

    Args:
        state_dict: Mapping from parameter name (without the
            ``model.encoder_q.`` prefix) to tensor for the model being checked.
            Presumably ``model.state_dict()``; confirm at the call site.
        pretrained_file: Path to a checkpoint loadable by ``torch.load`` whose
            top-level dict has a ``"state_dict"`` entry.

    Raises:
        AssertionError: If the checkpoint has no ``model.encoder_q.`` entries,
            or if any checked key is missing from ``state_dict`` or differs
            from the pretrained value in shape or contents.
    """
    # Sanity check to make sure we're not altering weights.
    # NOTE(review): torch.load unpickles the file; only pass trusted checkpoints.
    pretrained_dict = torch.load(pretrained_file, map_location="cpu")["state_dict"]

    prefix = "model.encoder_q."
    # Match on the prefix (not a substring) so the slice strips exactly the
    # prefix. A substring match on e.g. "x.model.encoder_q.w" would yield a
    # garbage key.
    model_dict = {
        k[len(prefix) :]: v for k, v in pretrained_dict.items() if k.startswith(prefix)
    }
    # With nothing selected, the loop below would pass vacuously.
    if not model_dict:
        raise AssertionError(
            f"no '{prefix}' weights found in {pretrained_file}; nothing to validate."
        )

    # Only the final linear/classifier layers are allowed to change.
    ignored = ("classifier.weight", "classifier.bias", "fc.weight", "fc.bias")
    for k, expected in model_dict.items():
        if any(name in k for name in ignored):
            continue
        # Explicit raises instead of `assert`, so the check survives `python -O`.
        if k not in state_dict:
            raise AssertionError(f"{k} missing from model state_dict.")
        current = state_dict[k].cpu()
        # Compare shapes first so a mismatch is reported here, instead of
        # being broadcast or raising a RuntimeError inside `==`.
        if current.shape != expected.shape or not bool((current == expected).all()):
            raise AssertionError(f"{k} changed in linear classifier training.")