in problems.py [0:0]
def imagenet(args):
    """Build the ImageNet train/validation loaders and the model.

    Reads ``args.cuda``, ``args.method``, ``args.transform_locking``,
    ``args.opt_vr``, ``args.batch_size`` and ``args.architecture``, and
    stores the number of training batches on ``args.nbatches``.

    Returns:
        A ``(train_loader, test_loader, model, train_dataset)`` tuple. The
        model is wrapped in ``torch.nn.DataParallel`` and has the training
        sampler attached as ``model.sampler``.

    Raises:
        ValueError: if ``args.architecture`` is not a supported name. This
            is raised before any dataset or loader is created.
    """
    # Constructors are wrapped in lambdas so the backing module is only
    # looked up for the architecture that is actually requested.
    model_builders = {
        "resnet18": lambda: torchvision.models.resnet.resnet18(),
        "resnet50": lambda: torchvision.models.resnet.resnet50(),
        "resnext101_32x8d": lambda: resnext.resnext101_32x8d(),
    }
    # Fail fast: reject a bad architecture before the dataset setup below
    # instead of after it, and say which name was rejected.
    if args.architecture not in model_builders:
        raise ValueError(
            "Architecture not supported for imagenet: {!r}".format(
                args.architecture))

    imagenet_root = "/datasets01_101/imagenet_full_size/061417"
    train_dir = imagenet_root + "/train"
    val_dir = imagenet_root + "/val"

    # Extra DataLoader options are only passed when CUDA is requested.
    loader_kwargs = (
        {'num_workers': 32, 'pin_memory': True} if args.cuda else {})
    # Per-channel normalization applied to the validation images.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Passed through to ImagenetWrapper; truthy only for "*svrg" methods
    # with both transform_locking and opt_vr set.
    lock_transforms = (args.method.endswith("svrg")
                       and args.transform_locking
                       and args.opt_vr)

    logging.info("Loading training dataset")
    logging.info("Data ...")
    train_dataset = ImagenetWrapper(train_dir, lock_transforms=lock_transforms)
    logging.info("Imagenet Wrapper created")

    logging.info("VR Sampler with order=perm")
    sampler = VRSampler(order="perm",
                        batch_size=args.batch_size,
                        dataset_size=len(train_dataset))

    # The sampler yields whole batches, so it is passed as batch_sampler.
    train_loader = UpdatedDataLoaderMult.DataLoader(
        train_dataset,
        batch_sampler=sampler,
        worker_init_fn=train_dataset.child_initialize,
        **loader_kwargs)
    num_batches = len(train_loader)
    logging.info("Train Loader created, batches: {}".format(num_batches))

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(val_dir, val_transform),
        batch_size=args.batch_size,
        shuffle=False,
        **loader_kwargs)

    args.nbatches = num_batches

    logging.info("Initializing model")
    model = model_builders[args.architecture]()

    logging.info("Lifting model to DataParallel")
    # NOTE(review): .cuda() is unconditional although the loader options
    # above are gated on args.cuda; presumably a CUDA device is always
    # required here -- confirm whether CPU-only runs must be supported.
    model = torch.nn.DataParallel(model).cuda()  # Use multiple gpus
    model.sampler = sampler

    return train_loader, test_loader, model, train_dataset