def get_train_dataloader()

in optimum/graphcore/trainer.py


    def get_train_dataloader(self) -> poptorch.DataLoader:
        """
        Returns the training `poptorch.DataLoader`.

        Will use no sampler if `train_dataset` does not implement `__len__`, and a random sampler (adapted to
        distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

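        # Options specific to poptorch.DataLoader: automatic partitioning of the dataset for
        # distributed execution (disabled for iterable datasets, which cannot be partitioned by
        # length), the host-side loading mode taken from the training arguments, and a worker
        # init function that seeds each worker process.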
        poptorch_specific_kwargs = {
            "auto_distributed_partitioning": not isinstance(train_dataset, torch.utils.data.IterableDataset),
            "mode": self.args.dataloader_mode,
            "worker_init_fn": _WorkerInit(123),
        }

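        # Iterable datasets have no known length, so no sampler (and no rebatching) is used
        # and the dataset is handed to poptorch.DataLoader directly.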
        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            return poptorch.DataLoader(
                self.opts,
                train_dataset,
                batch_size=self.args.train_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                drop_last=self.args.dataloader_drop_last,
                pin_memory=self.args.dataloader_pin_memory,
                **poptorch_specific_kwargs,
            )

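        # Map-style datasets: build a sampler and derive how many samples each worker should
        # rebatch from the combined batch size, i.e. the per-device batch size scaled by the
        # IPU batch size factor (device iterations, replication and gradient accumulation).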
        train_sampler = self._get_train_sampler()
        combined_batch_size = self.args.per_device_train_batch_size * self.ipu_config.batch_size_factor()
        rebatched_worker_size = (
            2 * (combined_batch_size // self.args.dataloader_num_workers)
            if self.args.dataloader_num_workers
            else combined_batch_size
        )
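        # For example, per_device_train_batch_size=2 with a batch size factor of 32 gives a
        # combined batch size of 64; with 8 dataloader workers each worker rebatches to
        # 2 * (64 // 8) = 16 samples, and with 0 workers the full 64 is used.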

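        # Sanity check that the dataset holds enough samples to fill a combined batch.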
        self._check_dataset_can_fill_batch(train_dataset, for_inference=False)

        return poptorch.DataLoader(
            self.opts,
            train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            rebatched_worker_size=rebatched_worker_size,
            **poptorch_specific_kwargs,
        )
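
The docstring invites subclassing to inject custom behavior. Below is a minimal sketch of that pattern, assuming the surrounding class is `IPUTrainer` from `optimum.graphcore`; the subclass name `LoggingIPUTrainer` and the logging it performs are purely illustrative.

    import poptorch
    from optimum.graphcore import IPUTrainer


    class LoggingIPUTrainer(IPUTrainer):
        def get_train_dataloader(self) -> poptorch.DataLoader:
            # Reuse the default construction, then add custom behavior around it,
            # here just logging the effective per-device batch size and worker count.
            dataloader = super().get_train_dataloader()
            print(
                f"train dataloader: per-device batch size {self.args.per_device_train_batch_size}, "
                f"{self.args.dataloader_num_workers} workers"
            )
            return dataloader

Because the override delegates to the parent method, the sampler selection, worker rebatching and poptorch-specific options shown above are preserved unchanged.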