def _get_dataloader()

in pytorchvideo_trainer/pytorchvideo_trainer/datamodule/datamodule.py


    def _get_dataloader(self, phase: str) -> DataLoader:
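        """Build and return a DataLoader for ``phase`` (e.g. "train", "val", "test"),
        using the phase-specific settings stored in ``self.config[phase]``."""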
        assert self.datasets[phase] is not None, "Failed to get the {} dataset!".format(
            phase
        )

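        # IterableDatasets handle their own iteration (and any sharding/shuffling),
        # so the DataLoader is built without a sampler or shuffle flag.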
        if isinstance(self.datasets[phase], torch.utils.data.IterableDataset):
            return torch.utils.data.DataLoader(
                self.datasets[phase],
                batch_size=self.config[phase].batch_size,
                num_workers=self.config[phase].num_workers,
                pin_memory=self.config[phase].pin_memory,
                drop_last=self.config[phase].drop_last,
                collate_fn=hydra.utils.instantiate(self.config[phase].collate_fn),
                worker_init_fn=hydra.utils.instantiate(
                    self.config[phase].worker_init_fn
                ),
            )
        else:
            sampler = None
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                logging.info(
                    "Distributed Environmnet detected, using DistributedSampler for dataloader."
                )
                sampler = DistributedSampler(self.datasets[phase])

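            # DataLoader's ``sampler`` and ``shuffle`` arguments are mutually
            # exclusive, so shuffling is disabled whenever a sampler is supplied
            # (DistributedSampler shuffles per epoch on its own by default).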
            return torch.utils.data.DataLoader(
                self.datasets[phase],
                batch_size=self.config[phase].batch_size,
                num_workers=self.config[phase].num_workers,
                pin_memory=self.config[phase].pin_memory,
                drop_last=self.config[phase].drop_last,
                sampler=sampler,
                shuffle=(False if sampler is not None else self.config[phase].shuffle),
                collate_fn=hydra.utils.instantiate(self.config[phase].collate_fn),
                worker_init_fn=hydra.utils.instantiate(
                    self.config[phase].worker_init_fn
                ),
            )
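
For orientation, below is a minimal sketch of the per-phase configuration this method reads through self.config[phase], and how a loader might be requested from it. The use of OmegaConf.create, the concrete values, and the names train_cfg and datamodule are illustrative assumptions; the real per-phase configs ship with pytorchvideo_trainer's Hydra config files and may use different defaults.

    # Sketch (assumptions): a per-phase config exposing exactly the fields that
    # _get_dataloader reads, plus how a train loader might be requested from it.
    from omegaconf import OmegaConf

    train_cfg = OmegaConf.create(
        {
            "batch_size": 8,
            "num_workers": 4,
            "pin_memory": True,
            "drop_last": True,
            "shuffle": True,         # ignored when a DistributedSampler is attached
            "collate_fn": None,      # or a Hydra _target_ spec; hydra.utils.instantiate(None)
            "worker_init_fn": None,  # returns None, so the DataLoader defaults apply
        }
    )

    # Assuming `datamodule` is an instance of the datamodule defined in datamodule.py,
    # with self.config["train"] = train_cfg and self.datasets["train"] already built:
    # train_loader = datamodule._get_dataloader("train")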