# problems.py

import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import RandomSampler

# Project-local helpers. The module paths below are assumptions and may need
# adjusting to this repository's layout: CIFAR10_Wrapper, VRSampler,
# UpdatedDataLoader, resnet and densenet are all defined elsewhere in this
# codebase.
import densenet
import resnet
import UpdatedDataLoader
from cifar_wrapper import CIFAR10_Wrapper
from vr_sampler import VRSampler

def cifar10(args):
    data_dir = os.path.expanduser('~/data')
    kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}

    # Standard CIFAR-10 augmentation (random flips and padded crops),
    # followed by per-channel normalization to roughly [-1, 1].
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # We don't do the random transforms at test time.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
logging.info("Loading training dataset")
if (args.method.endswith("svrg") or args.method == "scsg") and args.transform_locking and args.opt_vr:
train_dataset = CIFAR10_Wrapper(
root=data_dir, train=True,
download=True, transform=transform)
else:
train_dataset = torchvision.datasets.CIFAR10(
root=data_dir, train=True,
download=True, transform=transform)

    # The VR sampler does not support SAGA; fail fast rather than silently
    # falling through to the generic sampler.
    if args.method == "saga" and args.opt_vr:
        raise Exception("VR sampler currently doesn't support saga")

    if args.method.endswith("svrg") and args.opt_vr:
        logging.info("VR Sampler with order=perm")
        sampler = VRSampler(order="perm",
                            batch_size=args.batch_size,
                            dataset_size=len(train_dataset))
        train_loader = UpdatedDataLoader.DataLoader(
            train_dataset, batch_sampler=sampler, **kwargs)
    else:
        sampler = RandomSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, sampler=sampler, batch_size=args.batch_size, **kwargs)
    # Note: len(sampler) is the number of batches for the VR batch sampler,
    # but the number of individual examples for RandomSampler.
    args.nbatches = len(sampler)
logging.info("Loading test dataset")
test_dataset = torchvision.datasets.CIFAR10(
root=data_dir, train=False,
download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=args.batch_size,
shuffle=False, **kwargs)
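
    # Both loaders now yield normalized 3x32x32 batches; only the training
    # pipeline applies random augmentation (and, when transform locking is
    # enabled, the replayable wrapper above).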

    nonlinearity = F.relu

    class Net(nn.Module):
        """Small LeNet-style baseline: two conv/pool stages, three FC layers."""

        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            logging.info("Initializing fully connected layers")
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
            if args.batchnorm:
                logging.info("Using batchnorm")
                self.bn1 = nn.BatchNorm2d(6)
                self.bn2 = nn.BatchNorm2d(16)
                self.bn3 = nn.BatchNorm1d(120)
                self.bn4 = nn.BatchNorm1d(84)
            logging.info("initialized")

        def forward(self, x):
            x = self.conv1(x)
            if args.batchnorm:
                x = self.bn1(x)
            x = nonlinearity(x)
            x = self.pool(x)
            x = self.conv2(x)
            if args.batchnorm:
                x = self.bn2(x)
            x = nonlinearity(x)
            x = self.pool(x)
            # Flatten: two conv/pool stages reduce the 32x32 input to
            # 16 feature maps of 5x5 (32 -> 28 -> 14 -> 10 -> 5).
            x = x.view(-1, 16 * 5 * 5)
            x = self.fc1(x)
            if args.batchnorm:
                x = self.bn3(x)
            x = nonlinearity(x)
            x = self.fc2(x)
            if args.batchnorm:
                x = self.bn4(x)
            x = nonlinearity(x)
            x = self.fc3(x)
            return x
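
    # Architecture dispatch: the non-default models come from the
    # project-local resnet/densenet modules and take the same batchnorm and
    # nonlinearity options as the baseline Net above.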
logging.info("Loading architecture")
if args.architecture == "default":
logging.info("default architecture")
model = Net()
elif args.architecture == "resnet110":
model = resnet.ResNet110(batchnorm=args.batchnorm, nonlinearity=nonlinearity)
elif args.architecture == "resnet-small":
model = resnet.ResNetSmall(batchnorm=args.batchnorm, nonlinearity=nonlinearity)
elif args.architecture == "densenet-40-36":
model = densenet.densenet(depth=40, growthRate=36, batchnorm=args.batchnorm, nonlinearity=nonlinearity)
model = torch.nn.DataParallel(model)
else:
raise Exception("architecture not recognised:", args.architecture)

    # Attach the sampler so callers can reach it alongside the model.
    model.sampler = sampler
    return train_loader, test_loader, model, train_dataset
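

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# fields on `args` are inferred from how cifar10() reads them; the real
# training script may use different names or defaults. Plain SGD is chosen
# so that only torchvision is needed (no project-local VR machinery).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    logging.basicConfig(level=logging.INFO)
    args = argparse.Namespace(
        cuda=torch.cuda.is_available(),
        method="sgd",              # avoids the VR sampler / wrapper paths
        transform_locking=False,
        opt_vr=False,
        batch_size=128,
        batchnorm=True,
        architecture="default",
    )
    train_loader, test_loader, model, train_dataset = cifar10(args)

    # Take a single optimization step as a smoke test of the returned
    # loader/model pair.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        break  # one step is enough for the smoke test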