in pytorchvideo_trainer/pytorchvideo_trainer/datamodule/datamodule.py [0:0]
def _get_dataloader(self, phase: str) -> DataLoader:
    assert self.datasets[phase] is not None, "Failed to get the {} dataset!".format(
        phase
    )
    if isinstance(self.datasets[phase], torch.utils.data.IterableDataset):
        # Iterable-style datasets manage their own ordering and sharding,
        # so no sampler or shuffle flag is passed.
        return torch.utils.data.DataLoader(
            self.datasets[phase],
            batch_size=self.config[phase].batch_size,
            num_workers=self.config[phase].num_workers,
            pin_memory=self.config[phase].pin_memory,
            drop_last=self.config[phase].drop_last,
            collate_fn=hydra.utils.instantiate(self.config[phase].collate_fn),
            worker_init_fn=hydra.utils.instantiate(
                self.config[phase].worker_init_fn
            ),
        )
    else:
        sampler = None
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            logging.info(
                "Distributed environment detected, using DistributedSampler for dataloader."
            )
            sampler = DistributedSampler(self.datasets[phase])
        return torch.utils.data.DataLoader(
            self.datasets[phase],
            batch_size=self.config[phase].batch_size,
            num_workers=self.config[phase].num_workers,
            pin_memory=self.config[phase].pin_memory,
            drop_last=self.config[phase].drop_last,
            sampler=sampler,
            # shuffle must be False when a sampler is supplied;
            # the DistributedSampler shuffles per epoch instead.
            shuffle=(False if sampler else self.config[phase].shuffle),
            collate_fn=hydra.utils.instantiate(self.config[phase].collate_fn),
            worker_init_fn=hydra.utils.instantiate(
                self.config[phase].worker_init_fn
            ),
        )
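For reference, a minimal sketch of the per-phase config this method reads. The field names mirror the attribute accesses above; the concrete values, the _target_ path, and the use of _partial_ are illustrative assumptions (Hydra >= 1.1), not the shipped pytorchvideo_trainer defaults.

# Hypothetical per-phase dataloader config; values and _target_ are assumptions.
from hydra.utils import instantiate
from omegaconf import OmegaConf

phase_conf = OmegaConf.create(
    {
        "batch_size": 8,
        "num_workers": 4,
        "pin_memory": True,
        "drop_last": True,
        "shuffle": True,  # ignored once a DistributedSampler is attached
        # _partial_ makes instantiate return a callable (functools.partial)
        # instead of invoking the target immediately.
        "collate_fn": {
            "_target_": "torch.utils.data.dataloader.default_collate",
            "_partial_": True,
        },
        "worker_init_fn": None,  # instantiate(None) simply returns None
    }
)

collate_fn = instantiate(phase_conf.collate_fn)  # partial(default_collate)
worker_init_fn = instantiate(phase_conf.worker_init_fn)  # None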