in jat/processing_jat.py
import torch
from transformers import BatchEncoding


def pad(self, *args, **kwargs):
    inputs = args[0]
    # Keep only the feature keys whose values are present in the first sample.
    keys = [key for key in inputs[0].keys() if inputs[0][key] is not None]
    # Transpose the batch: list of sample dicts -> dict of feature lists.
    inputs = {key: [arg[key] for arg in inputs] for key in keys}
    elmt = next(iter(inputs.values()))
    if isinstance(elmt[0], torch.Tensor) and not isinstance(elmt, torch.Tensor):
        # Every sample already holds equal-shape tensors: stack them into a batch.
        encoding = {key: torch.stack(inputs[key]) for key in inputs.keys()}
    else:
        # Otherwise pad the sequences to a common length; truncation is never applied.
        encoding = self._truncate_and_pad(
            inputs, padding=kwargs.get("padding", False), truncation=False, max_length=kwargs.get("max_length")
        )
    return BatchEncoding(encoding, tensor_type=kwargs.get("return_tensors"))
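A minimal usage sketch, not from the source: it assumes the surrounding JatProcessor class exposes this pad method and can be loaded via from_pretrained (the checkpoint name below is an assumption), and the feature keys and tensor shapes are illustrative only. It exercises the torch.stack branch, where every sample already holds equal-shape tensors, and shows that keys mapped to None in the first sample are dropped.

import torch
from jat.processing_jat import JatProcessor

# Checkpoint name is an assumption, not confirmed by the source.
processor = JatProcessor.from_pretrained("jat-project/jat")

# Two samples with equal-shape tensor features; "attention_mask" is None in
# the first sample, so pad() filters that key out of the batch entirely.
samples = [
    {"continuous_observations": torch.zeros(8, 4), "attention_mask": None},
    {"continuous_observations": torch.zeros(8, 4), "attention_mask": None},
]
batch = processor.pad(samples, return_tensors="pt")
print(batch["continuous_observations"].shape)  # torch.Size([2, 8, 4])

With variable-length (non-tensor) features, the same call would instead route through self._truncate_and_pad, honoring the padding and max_length keyword arguments.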