in esm/data.py [0:0]
def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
    """Tokenize a single MSA or a batch of MSAs into one padded token tensor.

    Args:
        inputs: Either one MSA (a sequence of ``(label, sequence)`` pairs) or a
            sequence of such MSAs. A lone MSA is recognized by the first element
            of its first entry being a ``str``; it is then treated as a batch of
            one.

    Returns:
        A tuple ``(labels, strs, tokens)``:
            labels: list with one entry per MSA, the labels returned by the
                parent converter's ``__call__`` for that MSA.
            strs: list with one entry per MSA, the strings returned by the
                parent converter's ``__call__`` for that MSA.
            tokens: ``torch.int64`` tensor of shape
                ``(num_msas, max_alignments, max_seqlen + bos + eos)``, where
                ``bos``/``eos`` are 1 if ``self.alphabet.prepend_bos`` /
                ``append_eos`` are set. Cells not covered by an MSA's tokens
                hold ``self.alphabet.padding_idx``.

    Raises:
        RuntimeError: if the sequences inside any one MSA differ in length.
    """
    is_single = isinstance(inputs[0][0], str)
    # Normalize to a batch so the rest of the method handles one shape only.
    msas: Sequence[RawMSA] = [inputs] if is_single else inputs  # type: ignore

    num_msas = len(msas)
    depth = max(len(msa) for msa in msas)
    # The width is read from each MSA's first sequence; the per-MSA alignment
    # check in the loop below rejects MSAs whose sequences differ in length.
    width = max(len(msa[0][1]) for msa in msas)
    extra = int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos)

    # Start fully padded; each MSA overwrites only the region it occupies.
    tokens = torch.full(
        (num_msas, depth, width + extra),
        self.alphabet.padding_idx,
        dtype=torch.int64,
    )
    labels = []
    strs = []

    for slot, msa in enumerate(msas):
        # NOTE(review): lengths are compared on the raw strings (characters).
        # If the tokenizer can map a multi-character span to one token, equal
        # character lengths need not mean equal token lengths -- confirm.
        if len({len(seq) for _, seq in msa}) != 1:
            raise RuntimeError(
                "Received unaligned sequences for input to MSA, all sequence "
                "lengths must be equal."
            )
        # Zero-argument super() must stay directly in the method body.
        msa_labels, msa_strs, msa_tokens = super().__call__(msa)
        labels.append(msa_labels)
        strs.append(msa_strs)
        n_seqs, n_cols = msa_tokens.size(0), msa_tokens.size(1)
        tokens[slot, :n_seqs, :n_cols] = msa_tokens

    return labels, strs, tokens