in esm/model.py [0:0]
def __init__(self, args, alphabet):
    """Construct token/position embeddings, the axial layer stack, and output heads.

    Args:
        args: Model configuration namespace. Required attributes:
            ``embed_dim``, ``ffn_embed_dim``, ``attention_heads``, ``layers``,
            ``dropout``, ``attention_dropout``, ``activation_dropout`` and
            ``max_positions``. ``max_tokens`` is required only when
            ``max_tokens_per_msa`` is absent. Optional attributes:
            ``embed_positions_msa`` (default False),
            ``embed_positions_msa_dim`` (default ``embed_dim``) and
            ``max_tokens_per_msa`` (default ``max_tokens``).
        alphabet: Token vocabulary. ``len(alphabet)`` gives the vocabulary
            size; ``padding_idx``, ``mask_idx``, ``cls_idx``, ``eos_idx``,
            ``prepend_bos`` and ``append_eos`` are copied onto the module.
    """
    super().__init__()
    self.args = args
    self.alphabet_size = len(alphabet)
    self.padding_idx = alphabet.padding_idx
    self.mask_idx = alphabet.mask_idx
    self.cls_idx = alphabet.cls_idx
    self.eos_idx = alphabet.eos_idx
    self.prepend_bos = alphabet.prepend_bos
    self.append_eos = alphabet.append_eos

    self.embed_tokens = nn.Embedding(
        self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
    )

    if getattr(self.args, "embed_positions_msa", False):
        emb_dim = getattr(self.args, "embed_positions_msa_dim", self.args.embed_dim)
        # Learned table of shape (1, 1024, 1, emb_dim), initialised with small
        # Gaussian noise. NOTE(review): 1024 presumably caps the number of MSA
        # rows this table can index — confirm against forward().
        self.msa_position_embedding = nn.Parameter(
            0.01 * torch.randn(1, 1024, 1, emb_dim),
            requires_grad=True,
        )
    else:
        # Register an explicit None so the attribute always exists on the
        # module, even when MSA row embeddings are disabled.
        self.register_parameter("msa_position_embedding", None)

    self.dropout_module = nn.Dropout(self.args.dropout)
    self.layers = nn.ModuleList(
        [
            AxialTransformerLayer(
                self.args.embed_dim,
                self.args.ffn_embed_dim,
                self.args.attention_heads,
                self.args.dropout,
                self.args.attention_dropout,
                self.args.activation_dropout,
                # Only fall back to ``max_tokens`` when ``max_tokens_per_msa``
                # is absent. Passing ``self.args.max_tokens`` as a getattr()
                # default would evaluate it eagerly and raise AttributeError
                # for configs that define only ``max_tokens_per_msa``.
                (
                    self.args.max_tokens_per_msa
                    if hasattr(self.args, "max_tokens_per_msa")
                    else self.args.max_tokens
                ),
            )
            for _ in range(self.args.layers)
        ]
    )

    # Input width is layers * attention_heads; presumably one attention map per
    # head per layer feeds the contact predictor — TODO confirm.
    self.contact_head = ContactPredictionHead(
        self.args.layers * self.args.attention_heads,
        self.prepend_bos,
        self.append_eos,
        eos_idx=self.eos_idx,
    )
    self.embed_positions = LearnedPositionalEmbedding(
        self.args.max_positions,
        self.args.embed_dim,
        self.padding_idx,
    )
    self.emb_layer_norm_before = ESM1bLayerNorm(self.args.embed_dim)
    self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
    # The LM head is handed the token-embedding weight Parameter itself,
    # presumably so input embedding and output projection are tied.
    self.lm_head = RobertaLMHead(
        embed_dim=self.args.embed_dim,
        output_dim=self.alphabet_size,
        weight=self.embed_tokens.weight,
    )