in esm/pretrained.py [0:0]
# Imports used by this function; in the full esm/pretrained.py they sit at
# module level. has_emb_layer_norm_before() is a same-file helper (a sketch
# of it follows the function below).
import warnings
from argparse import Namespace

import esm


def load_model_and_alphabet_core(model_data, regression_data=None):
    if regression_data is not None:
        # The contact-head regression weights are distributed as a separate
        # checkpoint; merge them into the main state dict before loading.
        model_data["model"].update(regression_data["model"])
    alphabet = esm.Alphabet.from_architecture(model_data["args"].arch)
    if model_data["args"].arch == "roberta_large":
        # upgrade state dict
        pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
        prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
        prs2 = lambda s: "".join(
            s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
        )
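        # The lambdas above strip fairseq-era prefixes: pra renames
        # hyperparameter names (e.g. "encoder_embed_dim" -> "embed_dim"),
        # while prs1/prs2 drop "encoder." / "sentence_encoder." from
        # state-dict keys so they match ESM's module hierarchy.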
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs1(prs2(arg[0])): arg[1] for arg in model_data["model"].items()}
        model_state["embed_tokens.weight"][alphabet.mask_idx].zero_()  # For token drop
        model_args["emb_layer_norm_before"] = has_emb_layer_norm_before(model_state)
        model_type = esm.ProteinBertModel
    elif model_data["args"].arch == "protein_bert_base":
        # upgrade state dict
        pra = lambda s: "".join(s.split("decoder_")[1:] if "decoder" in s else s)
        prs = lambda s: "".join(s.split("decoder.")[1:] if "decoder" in s else s)
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}
        model_type = esm.ProteinBertModel
    elif model_data["args"].arch == "msa_transformer":
        # upgrade state dict
        pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
        prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
        prs2 = lambda s: "".join(
            s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
        )
        prs3 = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row")
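        # prs3 swaps "row" and "column" in every key: the checkpoint's
        # row/column attention names are transposed relative to the module
        # definitions, so they must be exchanged on load.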
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs1(prs2(prs3(arg[0]))): arg[1] for arg in model_data["model"].items()}
        if model_args.get("embed_positions_msa", False):
            emb_dim = model_state["msa_position_embedding"].size(-1)
            model_args["embed_positions_msa_dim"] = emb_dim  # initial release, bug: emb_dim==1
        model_type = esm.MSATransformer
    else:
        raise ValueError("Unknown architecture selected")
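    # Rebuild an argparse-style Namespace from the renamed hyperparameters;
    # the ESM model constructors expect an args object, not a plain dict.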
    model = model_type(
        Namespace(**model_args),
        alphabet,
    )
    expected_keys = set(model.state_dict().keys())
    found_keys = set(model_state.keys())

    if regression_data is None:
        expected_missing = {"contact_head.regression.weight", "contact_head.regression.bias"}
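        # Without regression_data, the contact head's two regression
        # parameters are legitimately absent, so they are excluded from the
        # missing-key check below.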
        error_msgs = []
        missing = (expected_keys - found_keys) - expected_missing
        if missing:
            error_msgs.append(f"Missing key(s) in state_dict: {missing}.")
        unexpected = found_keys - expected_keys
        if unexpected:
            error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected}.")

        if error_msgs:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
        if expected_missing - found_keys:
            warnings.warn(
                "Regression weights not found, predicting contacts will not produce correct results."
            )

    # strict loading only when the regression weights were merged in, so a
    # full checkpoint must match exactly; otherwise the known-missing
    # contact-head keys are tolerated.
    model.load_state_dict(model_state, strict=regression_data is not None)

    return model, alphabet
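
# For reference, a minimal sketch of the same-file helper used above,
# assuming it just probes the state dict for the pre-encoder layer-norm
# key prefix:
def has_emb_layer_norm_before(model_state):
    """Determine whether layer norm needs to be applied before the encoder."""
    return any(k.startswith("emb_layer_norm_before") for k in model_state)


# Usage sketch: load a checkpoint pair downloaded from the ESM releases and
# build the model. The file names are hypothetical stand-ins for whichever
# checkpoint you have locally.
if __name__ == "__main__":
    import torch

    model_data = torch.load("esm1b_t33_650M_UR50S.pt", map_location="cpu")
    regression_data = torch.load(
        "esm1b_t33_650M_UR50S-contact-regression.pt", map_location="cpu"
    )
    model, alphabet = load_model_and_alphabet_core(model_data, regression_data)
    model.eval()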