def pad()

in jat/processing_jat.py [0:0]


    def pad(self, *args, **kwargs):
        """Collate a list of per-example feature dicts into one batch.

        Args:
            *args: ``args[0]`` is the batch: a non-empty sequence of dicts.
                Every example must contain each key that is kept (see below).
                All further positional arguments are ignored.
            **kwargs: Only these keys are read:
                ``encoded_inputs``: the batch, used only when no positional
                    argument is given.
                ``padding``: forwarded to ``self._truncate_and_pad``
                    (default ``False``).
                ``max_length``: forwarded to ``self._truncate_and_pad``
                    (default ``None``).
                ``return_tensors``: passed as ``tensor_type`` to
                    ``BatchEncoding`` (default ``None``).

        Returns:
            BatchEncoding with one entry per kept key. A key is kept when its
            value in the *first* example is not ``None``. If the first
            example's value for the first kept key is a ``torch.Tensor``,
            every kept key is combined with ``torch.stack``. Otherwise the
            per-key lists are handed to ``self._truncate_and_pad`` with
            truncation disabled.
        """
        # Accept the tokenizer-style keyword as well. A call with neither a
        # positional batch nor ``encoded_inputs`` still fails on ``args[0]``.
        if not args and "encoded_inputs" in kwargs:
            batch = kwargs["encoded_inputs"]
        else:
            batch = args[0]

        first = batch[0]
        # Only the first example decides which keys are dropped for being None;
        # a None appearing in a later example is not filtered here.
        keys = [key for key in first.keys() if first[key] is not None]
        # Transpose list-of-dicts into dict-of-lists (one list per key).
        columns = {key: [example[key] for example in batch] for key in keys}

        # NOTE(review): if every key of the first example is None, ``columns`` is
        # empty and this raises StopIteration; behavior kept as-is.
        sample_column = next(iter(columns.values()))
        # Each column is a Python list built above, so only its elements need a
        # type check: pre-tensorized features are stacked directly.
        if isinstance(sample_column[0], torch.Tensor):
            encoding = {key: torch.stack(values) for key, values in columns.items()}
        else:
            encoding = self._truncate_and_pad(
                columns,
                padding=kwargs.get("padding", False),
                truncation=False,
                max_length=kwargs.get("max_length"),
            )

        return BatchEncoding(encoding, tensor_type=kwargs.get("return_tensors"))