def __dataloader()

in covidprognosis/plmodules/xray_datamodule.py [0:0]


    def __dataloader(
        self, split: str, *, drop_last: bool = True
    ) -> torch.utils.data.DataLoader:
        """Build the ``DataLoader`` for one data split.

        Args:
            split: Which split to load; one of ``"train"``, ``"val"`` or
                ``"test"``. Only the ``"train"`` split is shuffled.
            drop_last: Forwarded to ``torch.utils.data.DataLoader``. Defaults
                to ``True`` for every split, matching the previous
                hard-coded behavior. With ``True`` the final incomplete
                batch is skipped, so ``"val"``/``"test"`` evaluation may not
                see every sample, and a dataset smaller than
                ``self.batch_size`` yields no batches at all. Pass ``False``
                to iterate over every sample.

        Returns:
            A ``DataLoader`` over the requested split's dataset, configured
            with ``self.batch_size``, ``self.num_workers`` and the
            module-level ``worker_init_fn``.

        Raises:
            ValueError: If ``split`` is not one of the supported names.
        """
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O``, which would silently route an unknown split to the
        # test dataset.
        dataset_attrs = {
            "train": "train_dataset",
            "val": "val_dataset",
            "test": "test_dataset",
        }
        if split not in dataset_attrs:
            raise ValueError(
                f"split must be one of {tuple(dataset_attrs)}, got {split!r}"
            )
        # Read only the requested split's attribute; the other splits'
        # dataset attributes are never accessed.
        dataset = getattr(self, dataset_attrs[split])

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=drop_last,
            # Only the training split is shuffled; evaluation order is fixed.
            shuffle=split == "train",
            worker_init_fn=worker_init_fn,
        )