def pad()

in empchat/datasets/loader.py [0:0]


def pad(tensors, padding_value=-1):
    """
    Stack a sequence of 1D or 2D tensors into one padded LongTensor.

    All inputs must have the same number of dimensions.

    - 1D inputs of lengths ``l_i`` produce a tensor of shape
      ``(len(tensors), max(l_i))``.
    - 2D inputs of shapes ``(r_i, c_i)`` produce a tensor of shape
      ``(len(tensors), max(r_i), max(c_i))``.

    Each input is copied into the leading corner of its slot, and every other
    position is filled with ``padding_value``. The output dtype is always
    ``torch.long``, so non-integer inputs are cast when copied.

    Args:
        tensors: non-empty indexable sequence of tensors, all 1D or all 2D.
        padding_value: value written to the padded positions.

    Returns:
        A new ``torch.long`` tensor with a leading batch dimension of size
        ``len(tensors)``.

    Raises:
        ValueError: if ``tensors`` is empty, if the inputs do not all have
            the same number of dimensions, or if that number is not 1 or 2.
    """
    # len() rather than truthiness: a torch.Tensor passed as the batch would
    # make `not tensors` ambiguous.
    if len(tensors) == 0:
        raise ValueError("pad() requires at least one tensor")
    ndim = tensors[0].dim()
    # Validate rank up front so a stray tensor of another rank gives a clear
    # error instead of a shape mismatch during the copy below.
    if any(t.dim() != ndim for t in tensors):
        raise ValueError("Input tensors must all have the same number of dimensions!")
    if ndim not in (1, 2):
        raise ValueError("Input tensors must be either 1D or 2D!")

    # Pad every dimension up to the largest size observed along it.
    max_sizes = [max(t.size(d) for t in tensors) for d in range(ndim)]
    out = torch.full((len(tensors), *max_sizes), padding_value, dtype=torch.long)
    for i, t in enumerate(tensors):
        # Index slot i with a leading slice along each of t's dimensions.
        out[(i, *(slice(0, s) for s in t.shape))] = t
    return out