in esm/data.py [0:0]
def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
    """Group sequence indices into batches under a padded-token budget.

    Sequences in ``self.sequence_strs`` are visited in ascending order of
    length (ties broken by index) and packed greedily. The cost of a batch
    is its widest member (sequence length plus ``extra_toks_per_seq``)
    multiplied by the number of members. A batch is closed as soon as
    adding the next sequence would push that cost above ``toks_per_batch``.

    A single sequence whose own padded length exceeds ``toks_per_batch``
    is not dropped; it is emitted as a batch of its own.

    Args:
        toks_per_batch: Upper bound on ``widest_member * batch_size``.
        extra_toks_per_seq: Amount added to every sequence length before
            it is measured against the budget.

    Returns:
        A list of batches, each a list of indices into
        ``self.sequence_strs``.
    """
    by_length = sorted(
        (len(seq), idx) for idx, seq in enumerate(self.sequence_strs)
    )

    batches = []
    current = []
    widest = 0

    for length, idx in by_length:
        padded = length + extra_toks_per_seq
        # Cost of the batch if this sequence were added to it.
        projected = max(padded, widest) * (len(current) + 1)
        # The budget comparison is evaluated first on purpose; an empty
        # batch is never emitted even when a lone sequence is over budget.
        if projected > toks_per_batch and current:
            batches.append(current)
            current = []
            widest = 0
        widest = max(widest, padded)
        current.append(idx)

    # Emit whatever is left over after the last sequence.
    if current:
        batches.append(current)
    return batches