in esm/data.py
def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
    # RoBERTa uses an eos token, while ESM-1 does not; the alphabet's
    # prepend_bos/append_eos flags below account for this difference.
    batch_size = len(raw_batch)
    batch_labels, seq_str_list = zip(*raw_batch)
    seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
    max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
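    # Allocate one row per sequence, reserving room for the optional
    # BOS/EOS tokens, and pre-fill with padding so shorter sequences
    # come out right-padded.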
    tokens = torch.empty(
        (
            batch_size,
            max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
        ),
        dtype=torch.int64,
    )
    tokens.fill_(self.alphabet.padding_idx)
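    # Copy each encoded sequence into its row; ESM uses the CLS token
    # as the BOS marker.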
    labels = []
    strs = []
    for i, (label, seq_str, seq_encoded) in enumerate(
        zip(batch_labels, seq_str_list, seq_encoded_list)
    ):
        labels.append(label)
        strs.append(seq_str)
        if self.alphabet.prepend_bos:
            tokens[i, 0] = self.alphabet.cls_idx
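        # Write the encoded sequence immediately after the optional BOS,
        # then append EOS if the alphabet requires it.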
        seq = torch.tensor(seq_encoded, dtype=torch.int64)
        tokens[
            i,
            int(self.alphabet.prepend_bos) : len(seq_encoded)
            + int(self.alphabet.prepend_bos),
        ] = seq
        if self.alphabet.append_eos:
            tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
    return labels, strs, tokens
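
In practice this __call__ is reached through the batch converter returned by
Alphabet.get_batch_converter(). A minimal usage sketch, assuming the fair-esm
package is installed; the labels and sequences below are illustrative:

import esm

# Load a pretrained model together with its alphabet (ESM-1b here, which
# both prepends BOS and appends EOS).
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()

data = [
    ("protein1", "MKTAYIAKQRQISFVK"),  # illustrative sequence, length 16
    ("protein2", "KALTARQQEVF"),       # illustrative sequence, length 11
]
labels, strs, tokens = batch_converter(data)
# tokens has shape (2, 16 + 2): BOS + longest sequence + EOS, with the
# shorter row right-padded with alphabet.padding_idx.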