def _prepare_input()

in optimum/habana/transformers/trainer.py


    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` item before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
        Compared to Transformers, inputs are also sliced across sequence-parallel ranks when sequence parallelism is
        enabled, and non-blocking data copies can be turned on.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
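            # Under sequence parallelism, keep only this rank's slice of the sequence dimension (dim 1)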
            if (
                self.accelerator.mpu.sequence_parallel_is_initialized()
                and self.accelerator.mpu.get_sequence_parallel_world_size() > 1
            ):
                seq_parallel_world_rank = self.accelerator.mpu.get_sequence_parallel_rank()
                sub_seq_length = int(data.size()[1] / self.accelerator.mpu.get_sequence_parallel_world_size())
                data = data[
                    :, seq_parallel_world_rank * sub_seq_length : (seq_parallel_world_rank + 1) * sub_seq_length
                ]
            kwargs = {"device": self.args.device}
            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
            if self.args.non_blocking_data_copy:
                return data.to(**kwargs, non_blocking=True)
            else:
                return data.to(**kwargs)
        return data
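
Non-blocking copies are toggled through the training arguments (self.args.non_blocking_data_copy above). Below is a minimal sketch of enabling them, assuming a Gaudi environment and that the flag is exposed as a GaudiTrainingArguments field of the same name; the other values are purely illustrative.

    from optimum.habana import GaudiTrainingArguments

    # Illustrative sketch: with non_blocking_data_copy set, _prepare_input() calls
    # data.to(device, non_blocking=True) instead of a blocking copy.
    args = GaudiTrainingArguments(
        output_dir="./out",
        use_habana=True,
        use_lazy_mode=True,
        non_blocking_data_copy=True,
        dataloader_pin_memory=True,  # non-blocking host-to-device copies generally need pinned memory to overlap
    )

Without the flag, the same .to(**kwargs) call runs as a regular blocking copy.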