# From cp_examples/sip_finetune/sip_finetune.py
def validate_pretrained_model(state_dict, pretrained_file):
    """Sanity-check that the pretrained encoder weights were not altered.

    Compares the encoder weights in ``state_dict`` (the model after
    linear-classifier training) against the ``model.encoder_q`` weights
    stored in ``pretrained_file``. Only the classifier/fc head — the part
    that is actually trained — is allowed to differ.

    Args:
        state_dict: Mapping of parameter name -> tensor from the trained
            model (names without the ``model.encoder_q.`` prefix).
        pretrained_file: Path to a ``torch.save`` checkpoint containing a
            ``"state_dict"`` entry.

    Raises:
        AssertionError: If any non-head encoder weight differs from the
            pretrained checkpoint.
        KeyError: If an expected encoder weight is missing from
            ``state_dict``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only call this
    # on trusted checkpoint files.
    pretrained_dict = torch.load(pretrained_file, map_location="cpu")["state_dict"]

    # Strip the encoder prefix so names line up with `state_dict`.
    # startswith (not a substring test) guarantees the slice below removes
    # exactly the prefix and nothing else.
    prefix = "model.encoder_q."
    model_dict = {
        k[len(prefix):]: v for k, v in pretrained_dict.items() if k.startswith(prefix)
    }

    for k, v in model_dict.items():
        # The classification head is the only part trained here; skip it.
        if "classifier.weight" in k or "classifier.bias" in k:
            continue
        if "fc.weight" in k or "fc.bias" in k:
            continue
        # torch.equal is False (rather than raising) on shape/dtype mismatch,
        # which is the right strictness for a frozen-weights sanity check.
        assert torch.equal(
            state_dict[k].cpu(), v
        ), f"{k} changed in linear classifier training."