def make_dataloader()

in gossip_sgd_adpsgd.py


# imports required by this excerpt (standard library, torch, torchvision)
import logging
import os

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

log = logging.getLogger(__name__)  # module-level logger assumed by the excerpt


def make_dataloader(args, train=True):
    """ Returns a distributed train dataloader (with its sampler) or a plain
    val dataloader (cf. ImageNet in 1hr) """

    data_dir = args.dataset_dir
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if train:
        log.debug('fpaths train {}'.format(train_dir))
        train_dataset = datasets.ImageFolder(train_dir, transforms.Compose([
                            transforms.RandomResizedCrop(224),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normalize]))

        # sampler produces indices used to assign data samples to each agent
        train_sampler = torch.utils.data.distributed.DistributedSampler(
                            dataset=train_dataset,
                            num_replicas=args.world_size,
                            rank=args.rank)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_dataloader_workers,
            pin_memory=True, sampler=train_sampler)

        return train_loader, train_sampler

    else:
        log.debug('fpaths val {}'.format(val_dir))
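        # no DistributedSampler here: every process evaluates the full
        # validation set in order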
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(val_dir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_dataloader_workers, pin_memory=True)

        return val_loader
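

Usage sketch (illustrative only, not code from the file): `args` just needs the
attributes the function reads (dataset_dir, batch_size, world_size, rank,
num_dataloader_workers); the dataset path and the values below are placeholders.
Calling set_epoch on the returned sampler each epoch is the usual
DistributedSampler idiom for getting a fresh per-rank shuffle.

import argparse

args = argparse.Namespace(
    dataset_dir='/path/to/imagenet',   # placeholder ImageNet root
    batch_size=256, world_size=4, rank=0,
    num_dataloader_workers=8)

train_loader, train_sampler = make_dataloader(args, train=True)
val_loader = make_dataloader(args, train=False)

for epoch in range(90):
    # re-seed the sampler so each epoch yields a different per-rank shuffle
    train_sampler.set_epoch(epoch)
    for images, targets in train_loader:
        pass  # forward / backward / optimizer step for this agent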