in esm/model.py [0:0]
def forward(self, tokens, repr_layers=(), need_head_weights=False, return_contacts=False):
    """Run the model over a batch of tokenized alignments.

    Shape letters used below follow the inline comments: B = batch,
    R = alignments (rows), C = sequence length (columns), D = embedding
    width, H = attention heads, L = number of layers.

    Args:
        tokens: Tensor with exactly three dimensions, unpacked as
            (B, R, C). Positions equal to ``self.padding_idx`` are treated
            as padding.
        repr_layers: Iterable of layer indices whose hidden states should be
            returned. Index 0 is the embedding output (after the pre-layer
            norm, dropout and padding masking); index ``i`` is the output of
            the i-th entry of ``self.layers``. The entry for the final layer
            has ``self.emb_layer_norm_after`` applied.
        need_head_weights: When true, each layer is asked for its attention
            weights and they are returned stacked over layers.
        return_contacts: When true, ``self.contact_head`` is evaluated on the
            row attentions; this forces ``need_head_weights`` on.

    Returns:
        A dict with:
            ``"logits"``: output of ``self.lm_head``.
            ``"representations"``: dict mapping requested layer index to a
                B x R x C x D tensor.
            ``"col_attentions"`` / ``"row_attentions"``: only when head
                weights are requested; B x L x H x C x R x R and
                B x L x H x C x C respectively.
            ``"contacts"``: only when ``return_contacts`` is true.

    Raises:
        AssertionError: If ``tokens`` does not have three dimensions.
        RuntimeError: If an MSA position embedding is configured and more
            than 1024 alignments are supplied.
    """
    if return_contacts:
        # The contact head consumes row attentions, so they must be collected.
        need_head_weights = True

    assert tokens.ndim == 3, (
        f"expected tokens with 3 dimensions (batch, alignments, sequence), got {tokens.ndim}"
    )
    batch_size, num_alignments, seqlen = tokens.size()

    padding_mask = tokens.eq(self.padding_idx)  # B, R, C
    if not padding_mask.any():
        # No padding present: hand the layers None instead of an all-False mask.
        padding_mask = None

    x = self.embed_tokens(tokens)
    # Positional embeddings are computed on the flattened (B * R, C) view and
    # reshaped back to match the token embeddings.
    x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
    if self.msa_position_embedding is not None:
        if x.size(1) > 1024:
            raise RuntimeError(
                "Using model with MSA position embedding trained on maximum MSA "
                f"depth of 1024, but received {x.size(1)} alignments."
            )
        x += self.msa_position_embedding[:, :num_alignments]

    x = self.emb_layer_norm_before(x)
    x = self.dropout_module(x)

    if padding_mask is not None:
        # Zero out the embeddings at padded positions.
        x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

    repr_layers = set(repr_layers)
    hidden_representations = {}
    if 0 in repr_layers:
        hidden_representations[0] = x

    # Always defined so later code never depends on a conditionally bound name;
    # they are only filled and read when need_head_weights is true.
    row_attn_weights = []
    col_attn_weights = []

    # B x R x C x D -> R x C x B x D
    x = x.permute(1, 2, 0, 3)

    # 1-based index of the most recently applied layer; stays 0 when
    # self.layers is empty so the post-loop lookup below is always valid.
    layer_num = 0
    for layer_num, layer in enumerate(self.layers, start=1):
        x = layer(
            x,
            self_attn_padding_mask=padding_mask,
            need_head_weights=need_head_weights,
        )
        if need_head_weights:
            x, col_attn, row_attn = x
            # H x C x B x R x R -> B x H x C x R x R
            col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4))
            # H x B x C x C -> B x H x C x C
            row_attn_weights.append(row_attn.permute(1, 0, 2, 3))
        if layer_num in repr_layers:
            hidden_representations[layer_num] = x.permute(2, 0, 1, 3)

    x = self.emb_layer_norm_after(x)
    x = x.permute(2, 0, 1, 3)  # R x C x B x D -> B x R x C x D

    # last hidden representation should have layer norm applied
    if layer_num in repr_layers:
        hidden_representations[layer_num] = x
    x = self.lm_head(x)

    result = {"logits": x, "representations": hidden_representations}
    if need_head_weights:
        # col_attentions: B x L x H x C x R x R
        col_attentions = torch.stack(col_attn_weights, 1)
        # row_attentions: B x L x H x C x C
        row_attentions = torch.stack(row_attn_weights, 1)
        result["col_attentions"] = col_attentions
        result["row_attentions"] = row_attentions
        if return_contacts:
            result["contacts"] = self.contact_head(tokens, row_attentions)

    return result