in src/flint/data_utils/fields.py [0:0]
def collate_tokens(self, values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
    """Convert a list of 1d tensors into a padded 2d tensor.

    Args:
        values: Non-empty sequence of 1-D tensors, or of plain sequences
            that ``torch.tensor`` can convert. Tensor and non-tensor
            elements may be mixed; each non-tensor element is converted
            individually.
        pad_idx: Value used to fill positions not covered by a sequence.
        eos_idx: End-of-sequence value. Must be provided when
            ``move_eos_to_beginning`` is True, since every sequence is
            checked against it.
        left_pad: If True, sequences are right-aligned (padding goes on
            the left); otherwise they are left-aligned.
        move_eos_to_beginning: If True, each sequence must end with
            ``eos_idx``; that trailing element is dropped and ``eos_idx``
            is written at the first position of the sequence instead.

    Returns:
        A tensor of shape ``(len(values), max_len)`` created from the first
        element via ``new_full`` (so it shares that element's dtype and
        device), where ``max_len`` is the largest ``size(0)`` in ``values``.

    Raises:
        IndexError: If ``values`` is empty.
        AssertionError: If an element's total number of entries differs
            from its first-dimension length, or if
            ``move_eos_to_beginning`` is True and a sequence does not end
            with ``eos_idx``.
    """
    # Convert element by element: checking only values[0] would let a
    # batch that mixes tensors and plain lists crash further down.
    values = [v if torch.is_tensor(v) else torch.tensor(v) for v in values]

    # Touch values[0] before max() so an empty batch keeps raising
    # IndexError rather than max()'s ValueError.
    template = values[0]
    size = max(v.size(0) for v in values)
    res = template.new_full((len(values), size), pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel(), "source and destination sizes differ"
        if move_eos_to_beginning:
            assert src[-1] == eos_idx, "sequence does not end with eos_idx"
            # Rotate right by one: EOS goes first, the rest shifts back.
            dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        # Slices of a row are views, so writing to them fills `res` in place.
        row = res[i]
        copy_tensor(v, row[size - len(v):] if left_pad else row[:len(v)])
    return res