def imagenet()

in problems.py [0:0]


def imagenet(args):
    """Build the ImageNet train/validation loaders and the model.

    Args:
        args: parsed options namespace. Read here: ``cuda``, ``method``,
            ``transform_locking``, ``opt_vr``, ``batch_size`` and
            ``architecture``. The optional attributes ``imagenet_train_dir`` /
            ``imagenet_val_dir`` override the default dataset locations when
            present and truthy. ``args.nbatches`` is set as a side effect to
            the number of training batches.

    Returns:
        Tuple ``(train_loader, test_loader, model, train_dataset)``. ``model``
        is wrapped in ``torch.nn.DataParallel``, moved with ``.cuda()``, and
        carries the training batch sampler as ``model.sampler``.

    Raises:
        ValueError: if ``args.architecture`` is not supported. This is checked
            before any dataset is loaded, so misconfiguration fails fast.
    """
    # Resolve the architecture up front: building the datasets below scans
    # the full ImageNet directory trees, so a typo should not wait for that.
    model_builders = {
        "resnet18": torchvision.models.resnet.resnet18,
        "resnet50": torchvision.models.resnet.resnet50,
        "resnext101_32x8d": resnext.resnext101_32x8d,
    }
    build_model = model_builders.get(args.architecture)
    if build_model is None:
        raise ValueError("Architecture not supported for imagenet: {}".format(
            args.architecture))

    # Worker/pinned-memory settings only when running with CUDA.
    kwargs = {'num_workers': 32, 'pin_memory': True} if args.cuda else {}

    # Per-channel RGB mean/std normalization for the validation images
    # (the ImageNet statistics documented for torchvision models).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Lock transforms only for "*svrg" methods with both transform_locking
    # and opt_vr enabled. Presumably this keeps augmentations identical across
    # variance-reduction passes — TODO confirm in ImagenetWrapper.
    lock_transforms = (args.method.endswith("svrg")) and args.transform_locking and args.opt_vr

    train_dir = (getattr(args, "imagenet_train_dir", None)
                 or "/datasets01_101/imagenet_full_size/061417/train")
    val_dir = (getattr(args, "imagenet_val_dir", None)
               or "/datasets01_101/imagenet_full_size/061417/val")

    logging.info("Loading training dataset")
    logging.info("Data ...")
    train_dataset = ImagenetWrapper(train_dir, lock_transforms=lock_transforms)
    logging.info("Imagenet Wrapper created")

    logging.info("VR Sampler with order=perm")
    sampler = VRSampler(order="perm",
        batch_size=args.batch_size,
        dataset_size=len(train_dataset))

    # The sampler yields whole batches, so it is passed as batch_sampler;
    # child_initialize runs in each loader worker process.
    train_loader = UpdatedDataLoaderMult.DataLoader(
        train_dataset, batch_sampler=sampler,
        worker_init_fn=train_dataset.child_initialize, **kwargs)
    logging.info("Train Loader created, batches: {}".format(len(train_loader)))

    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(val_dir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)

    args.nbatches = len(train_loader)

    logging.info("Initializing model")
    model = build_model()

    logging.info("Lifting model to DataParallel")
    # NOTE(review): .cuda() is unconditional even though args.cuda gates the
    # loader kwargs above; kept as-is in case the training loop assumes it.
    model = torch.nn.DataParallel(model).cuda()  # Use multiple gpus
    # Expose the batch sampler to code that only holds the model.
    model.sampler = sampler

    return train_loader, test_loader, model, train_dataset