in examples/contrib/gp/sv-dkl.py [0:0]
def main(args):
    """Train and evaluate a deep-kernel sparse variational GP classifier on MNIST.

    Builds MNIST train/test loaders, warps an RBF kernel with a CNN feature
    extractor, initializes inducing points from training batches, then runs
    ``args.epochs`` epochs of variational training with a (Jit)TraceMeanField
    ELBO, reporting per-epoch wall-clock time.
    """
    data_dir = args.data_dir if args.data_dir is not None else get_data_directory(__file__)
    # Both loaders use the standard MNIST normalization constants.
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    train_loader = get_data_loader(
        dataset_name='MNIST',
        data_dir=data_dir,
        batch_size=args.batch_size,
        dataset_transforms=[normalize],
        is_training_set=True,
        shuffle=True,
    )
    test_loader = get_data_loader(
        dataset_name='MNIST',
        data_dir=data_dir,
        batch_size=args.test_batch_size,
        dataset_transforms=[normalize],
        is_training_set=False,
        shuffle=False,
    )
    if args.cuda:
        train_loader.num_workers = 1
        test_loader.num_workers = 1

    # Create a deep kernel by warping RBF with a CNN: the CNN maps each
    # high-dimensional image down to a low-dimensional (10-d) feature vector,
    # and the RBF kernel computes its covariance matrix on those CNN outputs.
    cnn = CNN()
    rbf = gp.kernels.RBF(input_dim=10, lengthscale=torch.ones(10))
    deep_kernel = gp.kernels.Warping(rbf, iwarping_fn=cnn)

    # Initialize inducing points by taking them randomly from the dataset:
    # accumulate just enough shuffled training batches to cover
    # args.num_inducing examples, then truncate.
    collected = []
    last_needed_batch = (args.num_inducing - 1) // args.batch_size
    for batch_idx, (images, _) in enumerate(train_loader):
        collected.append(images)
        if batch_idx >= last_needed_batch:
            break
    Xu = torch.cat(collected)[:args.num_inducing].clone()

    if args.binary:
        likelihood = gp.likelihoods.Binary()
        latent_shape = torch.Size([])
    else:
        # MultiClass likelihood wraps a Categorical distribution, so the GP
        # model must return one probability per class; hence it is required
        # to use latent_shape = 10.
        likelihood = gp.likelihoods.MultiClass(num_classes=10)
        latent_shape = torch.Size([10])

    # Turning on the "whiten" flag helps optimization for variational models.
    gpmodule = gp.models.VariationalSparseGP(
        X=Xu, y=None, kernel=deep_kernel, Xu=Xu,
        likelihood=likelihood, latent_shape=latent_shape,
        num_data=60000, whiten=True, jitter=2e-6,
    )
    if args.cuda:
        gpmodule.cuda()

    optimizer = torch.optim.Adam(gpmodule.parameters(), lr=args.lr)
    elbo = infer.JitTraceMeanField_ELBO() if args.jit else infer.TraceMeanField_ELBO()
    loss_fn = elbo.differentiable_loss

    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train(args, train_loader, gpmodule, optimizer, loss_fn, epoch)
        with torch.no_grad():
            test(args, test_loader, gpmodule)
        print("Amount of time spent for epoch {}: {}s\n"
              .format(epoch, int(time.time() - start_time)))