in empchat/datasets/loader.py [0:0]
def pad(tensors, padding_value=-1):
    """
    Stack 1D or 2D tensors of varying sizes into a single padded LongTensor.

    Each input is copied into the leading corner of its slot along a new
    leading (batch) dimension. Every position not covered by an input is set
    to ``padding_value``.

    Args:
        tensors: non-empty iterable of tensors that are all 1D or all 2D.
            Their sizes may differ along every dimension.
        padding_value: value written into the padded positions (default -1).

    Returns:
        A ``torch.long`` tensor. For 1D inputs its shape is
        ``(len(tensors), max_len)``; for 2D inputs it is
        ``(len(tensors), max_rows, max_cols)``. Input values are cast to
        ``torch.long``.

    Raises:
        ValueError: if ``tensors`` is empty, if the first tensor is not 1D or
            2D, or if the inputs do not all have the same number of
            dimensions.
    """
    # Materialize once: the sequence is traversed several times below, which
    # would silently break for generators.
    tensors = list(tensors)
    if not tensors:
        raise ValueError("pad() requires at least one tensor")
    ndim = tensors[0].dim()
    if ndim not in (1, 2):
        raise ValueError("Input tensors must be either 1D or 2D!")
    if any(t.dim() != ndim for t in tensors):
        raise ValueError(
            "Input tensors must all have the same number of dimensions"
        )
    # Per-dimension maximum size across all inputs: (max_len,) for 1D inputs,
    # (max_rows, max_cols) for 2D inputs.
    max_shape = [max(t.size(d) for t in tensors) for d in range(ndim)]
    out = torch.full((len(tensors), *max_shape), padding_value, dtype=torch.long)
    for i, t in enumerate(tensors):
        # Copy t into the leading corner of slot i; the rest keeps the padding.
        out[(i, *(slice(0, s) for s in t.shape))] = t
    return out