def _get_train_sampler()

in optimum/habana/transformers/trainer.py


    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        # Build the sampler.
        if self.args.group_by_length:
            # Group samples of similar length into the same batch to reduce padding.
            # Lengths come from the dataset's length column when available; otherwise
            # LengthGroupedSampler infers them from the model inputs.
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            # Name of the main model input (e.g. "input_ids"), used by LengthGroupedSampler
            # to compute lengths when no length column is provided
            model_input_name = (
                self.processing_class.model_input_names[0] if self.processing_class is not None else None
            )
            return LengthGroupedSampler(
                self.args.train_batch_size * self.args.gradient_accumulation_steps,
                dataset=self.train_dataset,
                lengths=lengths,
                model_input_name=model_input_name,
            )

        else:
            num_samples = len(self.train_dataset)
            if (
                not self.args.dataloader_drop_last
                and num_samples % self.args.per_device_train_batch_size != 0
                and self.args.parallel_mode != ParallelMode.DISTRIBUTED
            ):
                # Make the total number of samples divisible by the batch size in lazy mode if needed,
                # so that the last batch has the same shape as the others
                num_samples += (
                    self.args.per_device_train_batch_size - num_samples % self.args.per_device_train_batch_size
                )
            return RandomSampler(self.train_dataset, num_samples=num_samples)
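
When group_by_length is set, the method delegates batching order to transformers' LengthGroupedSampler. The sketch below is a rough conceptual illustration of length-grouped sampling, not the actual transformers implementation: indices are shuffled, split into large chunks, and each chunk is sorted by length so that batches drawn from it contain samples of similar size and therefore need little padding. The function name and the megabatch factor are illustrative.

import random
from typing import List


def length_grouped_indices(lengths: List[int], batch_size: int, mega_factor: int = 50) -> List[int]:
    """Return indices ordered so that neighbouring batches hold samples of similar length."""
    indices = list(range(len(lengths)))
    random.shuffle(indices)  # keep epoch-to-epoch randomness
    mega = batch_size * mega_factor  # size of the chunk ("megabatch") to sort within
    megabatches = [indices[i : i + mega] for i in range(0, len(indices), mega)]
    # Sorting inside each megabatch keeps similar lengths together without fully ordering the dataset.
    return [i for mb in megabatches for i in sorted(mb, key=lambda j: lengths[j], reverse=True)]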
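
In the default branch, when dataloader_drop_last is False and training is not distributed, the number of samples handed to RandomSampler is rounded up to the next multiple of the per-device batch size so that the last batch is as large as all the others. A minimal sketch of that arithmetic, with an illustrative helper name that is not part of optimum-habana:

def pad_num_samples(num_samples: int, per_device_batch_size: int) -> int:
    """Round num_samples up to the next multiple of per_device_batch_size."""
    remainder = num_samples % per_device_batch_size
    return num_samples if remainder == 0 else num_samples + (per_device_batch_size - remainder)


assert pad_num_samples(10, 4) == 12  # 10 % 4 == 2, so 2 extra indices are requested
assert pad_num_samples(12, 4) == 12  # already divisible, left unchanged

RandomSampler is then asked for that padded number of indices, so in lazy mode the dataloader does not emit a smaller final batch that would otherwise change the shapes seen by the device.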