def main()

in examples/contrib/gp/sv-dkl.py
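
The excerpt assumes the imports from the top of sv-dkl.py, roughly as follows (the CNN, train, and test helpers are defined elsewhere in the same file):

import time

import torch
from torchvision import transforms

import pyro.contrib.gp as gp
import pyro.infer as infer
from pyro.contrib.examples.util import get_data_directory, get_data_loader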


def main(args):
    data_dir = args.data_dir if args.data_dir is not None else get_data_directory(__file__)
    train_loader = get_data_loader(dataset_name='MNIST',
                                   data_dir=data_dir,
                                   batch_size=args.batch_size,
                                   dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))],
                                   is_training_set=True,
                                   shuffle=True)
    test_loader = get_data_loader(dataset_name='MNIST',
                                  data_dir=data_dir,
                                  batch_size=args.test_batch_size,
                                  dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))],
                                  is_training_set=False,
                                  shuffle=False)
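    # restrict each data loader to a single worker process when running on CUDA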
    if args.cuda:
        train_loader.num_workers = 1
        test_loader.num_workers = 1

    cnn = CNN()

    # Create a deep kernel by warping RBF with CNN.
    # The CNN maps each high-dimensional image down to a 10-dimensional feature vector,
    # and the RBF kernel computes the covariance matrix on those features; that is,
    # deep_kernel(x, z) = rbf(cnn(x), cnn(z)).
    rbf = gp.kernels.RBF(input_dim=10, lengthscale=torch.ones(10))
    deep_kernel = gp.kernels.Warping(rbf, iwarping_fn=cnn)

    # init inducing points (taken randomly from dataset)
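    # (collect ceil(num_inducing / batch_size) shuffled batches, then keep the
    # first num_inducing points)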
    batches = []
    for i, (data, _) in enumerate(train_loader):
        batches.append(data)
        if i >= ((args.num_inducing - 1) // args.batch_size):
            break
    Xu = torch.cat(batches)[:args.num_inducing].clone()

    if args.binary:
        likelihood = gp.likelihoods.Binary()
        latent_shape = torch.Size([])
    else:
        # use MultiClass likelihood for 10-class classification problem
        likelihood = gp.likelihoods.MultiClass(num_classes=10)
        # Because the MultiClass likelihood uses a Categorical distribution, the GP
        # model must return one latent function value per class. Hence we set
        # latent_shape = torch.Size([10]).
        latent_shape = torch.Size([10])

    # Turning on the "whiten" flag helps optimization for variational models.
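    # num_data=60000 is the size of the full MNIST training set; VariationalSparseGP
    # uses it to rescale each minibatch ELBO to the full dataset.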
    gpmodule = gp.models.VariationalSparseGP(X=Xu, y=None, kernel=deep_kernel, Xu=Xu,
                                             likelihood=likelihood, latent_shape=latent_shape,
                                             num_data=60000, whiten=True, jitter=2e-6)
    if args.cuda:
        gpmodule.cuda()

    optimizer = torch.optim.Adam(gpmodule.parameters(), lr=args.lr)

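    # TraceMeanField_ELBO computes KL terms analytically where possible; the Jit
    # variant additionally compiles the loss with the PyTorch JIT for speed.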
    elbo = infer.JitTraceMeanField_ELBO() if args.jit else infer.TraceMeanField_ELBO()
    loss_fn = elbo.differentiable_loss

    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train(args, train_loader, gpmodule, optimizer, loss_fn, epoch)
        with torch.no_grad():
            test(args, test_loader, gpmodule)
        print("Amount of time spent for epoch {}: {}s\n"
              .format(epoch, int(time.time() - start_time)))
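
For context, the train and test helpers invoked above are defined earlier in sv-dkl.py; the following is a minimal sketch of the standard Pyro GP minibatch pattern they follow (the binarization rule and the accuracy bookkeeping are assumptions, not copied from the file):

def train(args, train_loader, gpmodule, optimizer, loss_fn, epoch):
    for data, target in train_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        if args.binary:
            target = (target % 2 == 1).float()  # e.g. odd-vs-even labels (assumption)
        # point the GP model at the current minibatch, then take one gradient step
        gpmodule.set_data(data, target)
        optimizer.zero_grad()
        loss = loss_fn(gpmodule.model, gpmodule.guide)
        loss.backward()
        optimizer.step()


def test(args, test_loader, gpmodule):
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        if args.binary:
            target = (target % 2 == 1).float()
        # posterior predictive over the latent functions, then predicted labels
        # via the likelihood
        f_loc, f_var = gpmodule(data)
        pred = gpmodule.likelihood(f_loc, f_var)
        correct += pred.eq(target).long().sum().item()
    print("Test accuracy: {}/{}".format(correct, len(test_loader.dataset)))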