in slowfast/datasets/loader.py [0:0]
def construct_loader(cfg, split, is_precise_bn=False):
    """
    Constructs the data loader for the given dataset.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
        is_precise_bn (bool): whether the loader is built for precise-BN
            statistics. When True, the multigrid short-cycle batch sampler
            is skipped even for the `train` split.
    Returns:
        loader (torch.utils.data.DataLoader): the constructed data loader.
    Raises:
        ValueError: if `split` is not one of `train`, `val`, `test`.
    """
    # Validate explicitly: an `assert` is stripped under `python -O`, which
    # would leave the split-dependent locals below unbound.
    if split not in ["train", "val", "test"]:
        raise ValueError(
            "Unsupported split '{}'. Options are 'train', 'val', 'test'.".format(
                split
            )
        )
    if split == "test":
        dataset_name = cfg.TEST.DATASET
        batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
    else:
        # `val` reuses the training dataset name and per-GPU batch size.
        dataset_name = cfg.TRAIN.DATASET
        batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
    # Only the training split shuffles and drops the trailing partial batch.
    shuffle = split == "train"
    drop_last = split == "train"

    # Construct the dataset
    dataset = build_dataset(dataset_name, cfg, split)
    num_workers = cfg.DATA_LOADER.NUM_WORKERS
    pin_memory = cfg.DATA_LOADER.PIN_MEMORY
    worker_init_fn = utils.loader_worker_init_fn(dataset)

    if isinstance(dataset, torch.utils.data.IterableDataset):
        # DataLoader does not accept a sampler or shuffle for an
        # IterableDataset, so only batching options are passed.
        # NOTE(review): only detection collation is applied here; the
        # multi-sample collate used for map-style datasets is not. Confirm
        # whether iterable datasets can emit multiple samples per item.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
            worker_init_fn=worker_init_fn,
        )

    # Choose the collate function once so the short-cycle and the regular
    # map-style loaders batch detection / multi-sample outputs identically.
    if cfg.DETECTION.ENABLE:
        collate_func = detection_collate
    elif cfg.AUG.NUM_SAMPLE > 1 and split == "train":
        collate_func = partial(
            multiple_samples_collate, fold="imagenet" in dataset_name
        )
    else:
        collate_func = None

    # Create a sampler for multi-process training
    sampler = utils.create_sampler(dataset, shuffle, cfg)

    if cfg.MULTIGRID.SHORT_CYCLE and split == "train" and not is_precise_bn:
        # The batch sampler owns batch size and drop_last; DataLoader treats
        # batch_sampler as mutually exclusive with batch_size, shuffle,
        # sampler and drop_last, so those are not passed again.
        batch_sampler = ShortCycleBatchSampler(
            sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_func,
            worker_init_fn=worker_init_fn,
        )

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler.
        # Compare against None: an object defining __len__ is falsy when its
        # length is 0, which would wrongly re-enable shuffle.
        shuffle=shuffle and sampler is None,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        collate_fn=collate_func,
        worker_init_fn=worker_init_fn,
    )