in covidprognosis/plmodules/xray_datamodule.py [0:0]
def __dataloader(self, split: str) -> torch.utils.data.DataLoader:
    """Build the ``DataLoader`` for one of the data splits.

    Args:
        split: Which split to load; must be ``"train"``, ``"val"`` or
            ``"test"``.

    Returns:
        A loader over ``self.train_dataset``, ``self.val_dataset`` or
        ``self.test_dataset`` configured with ``self.batch_size`` and
        ``self.num_workers``. Only the training split is shuffled.

    Raises:
        AssertionError: If ``split`` is not one of the three known names
            (only when assertions are enabled).
    """
    assert split in ("train", "val", "test")

    # Shuffling is enabled for training only; val/test keep dataset order.
    is_train = split == "train"

    # Only the attribute for the requested split is touched, so the other
    # datasets do not need to exist when this is called.
    if is_train:
        source = self.train_dataset
    else:
        source = self.val_dataset if split == "val" else self.test_dataset

    # NOTE(review): drop_last=True is applied to every split, so a trailing
    # incomplete batch is discarded for val/test as well -- confirm that
    # skipping those samples during evaluation is intended.
    return torch.utils.data.DataLoader(
        dataset=source,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        drop_last=True,
        shuffle=is_train,
        worker_init_fn=worker_init_fn,
    )