in gossip_sgd_adpsgd.py [0:0]
def make_dataloader(args, train=True):
    """Build the train or validation image loader for this process (cf. ImageNet in 1hr).

    Reads ``args.dataset_dir``, ``args.batch_size`` and
    ``args.num_dataloader_workers``. The training path also reads
    ``args.world_size`` and ``args.rank``.

    Args:
        args: run configuration object, accessed by attribute only.
        train: if true, build the training loader; otherwise build the
            validation loader.

    Returns:
        When ``train`` is true: a ``(train_loader, train_sampler)`` tuple.
        The sampler is a ``DistributedSampler`` over ``<dataset_dir>/train``
        configured with ``args.world_size`` replicas and ``args.rank``.
        Otherwise: a single loader over ``<dataset_dir>/val`` with no
        sampler. NOTE(review): this presumably means every rank evaluates
        the full validation set — confirm with callers.
    """
    root = args.dataset_dir
    train_path = os.path.join(root, 'train')
    val_path = os.path.join(root, 'val')
    # Per-channel mean/std applied after ToTensor (the commonly used
    # ImageNet statistics).
    mean_std = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    if not train:
        log.debug('fpaths val {}'.format(val_path))
        # Deterministic evaluation pipeline: resize, then center crop.
        eval_pipeline = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            mean_std,
        ])
        eval_set = datasets.ImageFolder(val_path, eval_pipeline)
        return torch.utils.data.DataLoader(
            eval_set,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_dataloader_workers,
            pin_memory=True,
        )

    log.debug('fpaths train {}'.format(train_path))
    # Random augmentation pipeline for training.
    augment = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        mean_std,
    ])
    train_set = datasets.ImageFolder(train_path, augment)

    # Each rank draws its own subset of training indices from this sampler.
    shard_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=train_set,
        num_replicas=args.world_size,
        rank=args.rank,
    )
    # A sampler is always supplied, so the DataLoader must not shuffle.
    # PyTorch rejects shuffle=True together with a sampler; ordering is
    # left to the sampler.
    loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_dataloader_workers,
        pin_memory=True,
        sampler=shard_sampler,
    )
    return loader, shard_sampler