in optimum/habana/transformers/trainer.py [0:0]
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
    """
    Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
    Compared to Transformers, it is also possible to enable non-blocking data copies.
    """
    if isinstance(data, Mapping):
        return type(data)({k: self._prepare_input(v) for k, v in data.items()})
    elif isinstance(data, (tuple, list)):
        return type(data)(self._prepare_input(v) for v in data)
    elif isinstance(data, torch.Tensor):
        if (
            self.accelerator.mpu.sequence_parallel_is_initialized()
            and self.accelerator.mpu.get_sequence_parallel_world_size() > 1
        ):
            # With sequence parallelism, each rank keeps only its own contiguous chunk of the
            # sequence dimension (dim 1)
            seq_parallel_world_rank = self.accelerator.mpu.get_sequence_parallel_rank()
            sub_seq_length = data.size(1) // self.accelerator.mpu.get_sequence_parallel_world_size()
            data = data[
                :, seq_parallel_world_rank * sub_seq_length : (seq_parallel_world_rank + 1) * sub_seq_length
            ]
        kwargs = {"device": self.args.device}
        if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
            # NLP model inputs are int/uint and those get adjusted to the right dtype of the
            # embedding. Other models such as wav2vec2's inputs are already float and thus
            # may need special handling to match the dtypes of the model
            kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
        if self.args.non_blocking_data_copy:
            return data.to(**kwargs, non_blocking=True)
        else:
            return data.to(**kwargs)
    return data
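
For illustration, here is a minimal standalone sketch of the sequence-parallel slicing performed above, assuming a hypothetical rank and world size passed as plain integers (the real method obtains both from self.accelerator.mpu and also handles device/dtype placement):

    import torch

    def slice_for_sequence_parallel(data: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # Each rank keeps only its contiguous chunk of the sequence dimension (dim 1),
        # mirroring the slicing done in _prepare_input when sequence parallelism is active.
        sub_seq_length = data.size(1) // world_size
        return data[:, rank * sub_seq_length : (rank + 1) * sub_seq_length]

    # Example: a batch of 2 sequences of length 8 split across 4 sequence-parallel ranks.
    input_ids = torch.arange(16).reshape(2, 8)
    print(slice_for_sequence_parallel(input_ids, rank=1, world_size=4))
    # tensor([[ 2,  3],
    #         [10, 11]])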