in mico/dataloader/query_doc_pair.py [0:0]
def get_loaders(self, batch_size, num_workers, is_shuffle_train=True, is_get_test=True, prefetch_factor=2, pin_memory=False):
    """Get train/val/test loaders for training and testing.

    Note
    ----
    Setting `pin_memory=True` or larger `prefetch_factor` may increase the speed a little,
    but it costs much more memory.

    Parameters
    ----------
    batch_size : int
        The batch_size is for each process on each GPU.
    num_workers : int
        Number of worker subprocesses used for loading the data; `0` loads
        everything in the calling process.
        A larger value may increase the speed a little but cost much more memory.
    is_shuffle_train : bool
        Whether we shuffle the training data.
    is_get_test : bool
        If you want to make sure the test set is not used, set it to be `False`.
    prefetch_factor : int
        Number of batches loaded in advance by each worker.
        It is only forwarded to the loaders when `num_workers > 0`.
    pin_memory : bool
        Passed through to every `DataLoader` as its `pin_memory` option.

    Returns
    -------
    (train_loader, val_loader, test_loader) :
        The three dataloaders are PyTorch Dataloader objects.
        `test_loader` is `None` when `is_get_test` is `False`.
    """
    # Options shared by all three loaders.
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                         pin_memory=pin_memory)
    # DataLoader raises ValueError when `prefetch_factor` is given without
    # worker processes (PyTorch >= 2.0), so only forward it when it applies.
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = prefetch_factor

    train_loader = DataLoader(self.train_dataset, shuffle=is_shuffle_train,
                              **loader_kwargs)
    # NOTE(review): the val/test loaders are always shuffled, as in the
    # original implementation -- confirm this is intentional.
    val_loader = DataLoader(self.val_dataset, shuffle=True, **loader_kwargs)
    # The test loader is not even constructed unless explicitly requested.
    test_loader = (DataLoader(self.test_dataset, shuffle=True, **loader_kwargs)
                   if is_get_test else None)
    return train_loader, val_loader, test_loader