# src/data_loader.py

import os

import numpy as np
import torch
from torch.utils.data import SequentialSampler

# COCO, VOC, NUSWIDE, ADE20K, Recipe1M, RandomSamplerWithState and collate_fn
# are project-local helpers defined elsewhere in this repo; the import paths
# below are assumptions about where they live.
from dataset import COCO, VOC, NUSWIDE, ADE20K, Recipe1M
from sampler import RandomSamplerWithState
from utils import collate_fn


def get_loader(dataset,
               dataset_root,
               split,
               transform,
               batch_size,
               shuffle,
               num_workers,
               include_eos,
               drop_last=False,
               shuffle_labels=False,
               seed=1234,
               checkpoint=None):
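    """Build the dataset and DataLoader for one split of `dataset`.

    Returns a `(data_loader, dataset)` pair. When `shuffle` is True, a
    resumable random sampler is used; if `checkpoint` is given, the sampler
    is fast-forwarded to the stored epoch/step so training can resume
    mid-epoch. `shuffle_labels` and `include_eos` are forwarded unchanged
    to the dataset wrappers.
    """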
    # Read the file listing the sample ids that belong to this split.
    perm_file = os.path.join('../data/splits/', dataset, split + '.txt')
    with open(perm_file, 'r') as f:
        perm = np.array([int(line.rstrip('\n')) for line in f])

    if dataset == 'coco':
        # Both 'train' and 'val' splits are carved out of the train2014 set;
        # any other split (i.e. 'test') falls back to the val2014 annotations.
        if split in ('train', 'val'):
            annFile = os.path.join(dataset_root, 'annotations', 'instances_train2014.json')
            impath = os.path.join(dataset_root, 'train2014')
        else:
            annFile = os.path.join(dataset_root, 'annotations', 'instances_val2014.json')
            impath = os.path.join(dataset_root, 'val2014')
        dataset = COCO(
            root=impath,
            annFile=annFile,
            transform=transform,
            shuffle=shuffle_labels,
            perm=perm,
            include_eos=include_eos)
    elif dataset == 'voc':
        dataset = VOC(
            root=dataset_root,
            year='2007',
            image_set=split,
            download=False,
            transform=transform,
            shuffle=shuffle_labels,
            perm=perm,
            include_eos=include_eos)
    elif dataset == 'nuswide':
        dataset = NUSWIDE(
            dataset_root,
            split,
            transform=transform,
            shuffle=shuffle_labels,
            perm=perm,
            include_eos=include_eos)
    elif dataset == 'ade20k':
        dataset = ADE20K(
            dataset_root,
            split,
            transform=transform,
            shuffle=shuffle_labels,
            perm=perm,
            include_eos=include_eos)
    elif dataset == 'recipe1m':
        dataset = Recipe1M(
            dataset_root,
            split,
            maxnumims=5,
            shuffle=shuffle_labels,
            transform=transform,
            use_lmdb=False,
            suff='final_',
            perm=perm,
            include_eos=include_eos)
    else:
        # Fail early instead of handing the raw dataset name string to the
        # DataLoader below.
        raise ValueError('Unknown dataset: {}'.format(dataset))

    def worker_init_fn(worker_id):
        # Give each worker a distinct, deterministic numpy seed so that
        # augmentations are reproducible without every worker drawing an
        # identical random stream.
        np.random.seed(seed + worker_id)

    if shuffle:
        # Training: a resumable random sampler, so an interrupted run can
        # pick up from the epoch/step recorded in the checkpoint.
        sampler = RandomSamplerWithState(dataset, batch_size, seed)
        if checkpoint is not None:
            sampler.set_state(checkpoint['args'].current_epoch,
                              checkpoint['current_step'])
    else:
        # Validation and test: deterministic order.
        sampler = SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,  # ordering is fully controlled by `sampler`
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn,
        sampler=sampler)

    return data_loader, dataset
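

# Illustrative usage (a sketch, not part of the original file): the dataset
# root path and the transform below are assumptions for demonstration only;
# the batch structure yielded by the loader is determined by the project's
# collate_fn.
if __name__ == '__main__':
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    # Hypothetical local path to a VOC2007 dataset root.
    loader, ds = get_loader(dataset='voc',
                            dataset_root='/path/to/VOCdevkit',
                            split='train',
                            transform=transform,
                            batch_size=32,
                            shuffle=True,
                            num_workers=4,
                            include_eos=False)

    for batch in loader:
        # Inspect a single batch, then stop.
        break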