def compute_eigenvalues()

in plot_path_tools.py [0:0]


def _top_eigenvalues(grads, params, n_eigs, imaginary, symmetric):
    """
    Return ``n_eigs`` leading eigenvalues of the Jacobian of ``grads`` w.r.t. ``params``.

    The Jacobian is never materialised: it is wrapped in ``JacobianVectorProduct``
    (defined elsewhere; presumably a matrix-free linear operator) and handed to the
    ARPACK solvers in ``linalg``.

    Parameters
    ----------
    grads: list of Tensor
        gradients built with ``create_graph=True`` so they can be differentiated again
    params: list of Tensor
        parameters the gradients are differentiated against
    n_eigs: int
        number of eigenvalues requested (``k``)
    imaginary: bool
        if True, use ``linalg.eigs`` with ``which='LI'`` (largest imaginary part)
    symmetric: bool
        when ``imaginary`` is False, use ``linalg.eigsh`` (symmetric solver) if True,
        otherwise the general ``linalg.eigs``; both with their default ``which``
    """
    t0 = time.time()
    operator = JacobianVectorProduct(grads, params)
    if imaginary:
        eig_values = linalg.eigs(operator, k=n_eigs, which='LI')[0]
    elif symmetric:
        eig_values = linalg.eigsh(operator, k=n_eigs)[0]
    else:
        eig_values = linalg.eigs(operator, k=n_eigs)[0]
    print("Time to compute Eig-values: %.2f" % (time.time() - t0))
    return eig_values


def compute_eigenvalues(gen, dis, dataloader, config,
                        model_loss_gen, model_loss_dis,
                        device=None, n_eigs=20, verbose=False, imaginary=False):
    """
    Compute leading eigenvalues of the generator, discriminator and game Jacobians.

    The generator and discriminator loss gradients are averaged over the whole
    ``dataloader``, with each batch weighted by its size. The gradients are built
    with ``create_graph=True`` so they can be differentiated a second time. Three
    eigen-problems are then solved:

    * generator: Jacobian of the generator gradient w.r.t. the generator parameters
      (the Hessian of the generator loss);
    * discriminator: the same for the discriminator;
    * game: Jacobian of the concatenated (generator, discriminator) gradient field
      w.r.t. all parameters. This matrix is not symmetric, so the general solver is
      always used for it.

    Parameters
    ----------
    gen: Generator
        maps noise of shape ``(batch, config.nz, 1, 1)`` to samples
    dis: Discriminator
        must provide ``get_penalty(x_true, x_gen)`` when ``config.model == 'wgan_gp'``
    dataloader: pytorch DataLoader
        real data loader; element ``[0]`` of every item is taken as the real batch
        (e.g. ``(images, labels)`` pairs)
    config: Namespace
        configuration; reads ``nz`` and ``model``, plus ``gp_lambda`` for ``'wgan_gp'``
    model_loss_gen: callable
        ``(x_gen, dis, device) -> (gen_loss, _)``
    model_loss_dis: callable
        ``(x_true, x_gen, dis, device) -> (dis_loss, _, _)``
    device: torch.device or None
        device the real batch and the noise are moved to
    n_eigs: int
        number of eigenvalues computed for each of the three operators
    verbose: bool
        print batch indices, the first 5 eigenvalues of each set and total runtime
    imaginary: bool
        if True, request the eigenvalues with largest imaginary part (``which='LI'``)
        for all three operators; otherwise use the solvers' default ordering

    Returns
    -------
    (gen_eigs, dis_eigs, game_eigs)
        eigenvalue arrays returned by the ``linalg`` solvers

    Raises
    ------
    ValueError
        if ``dataloader`` yields no samples (the gradient average is undefined)
    """
    start_time = time.time()

    gen_params = list(gen.parameters())
    dis_params = list(dis.parameters())

    grad_gen_epoch = [torch.zeros_like(p) for p in gen_params]
    grad_dis_epoch = [torch.zeros_like(p) for p in dis_params]
    n_data = 0
    for batch_idx, batch in enumerate(dataloader):
        if verbose:
            print(batch_idx)
        x_true = batch[0]
        batch_size = x_true.size(0)
        z = torch.randn(batch_size, config.nz, 1, 1)

        x_true = x_true.to(device)
        z = z.to(device)

        ################# Compute Loss #########################
        x_gen = gen(z)
        dis_loss, _, _ = model_loss_dis(x_true, x_gen, dis, device)
        gen_loss, _ = model_loss_gen(x_gen, dis, device)
        # NOTE: the gradient penalty is still special-cased on config.model here
        # rather than being folded into model_loss_dis.
        if config.model == 'wgan_gp':
            penalty = dis.get_penalty(x_true, x_gen).mean()
            # Out-of-place add: an in-place update of a graph tensor can break
            # autograd's version check or mutate a tensor aliased by the loss fn.
            dis_loss = dis_loss + config.gp_lambda * penalty

        # create_graph=True keeps every batch's graph alive until the eigen-solves
        # finish (they differentiate these gradients again); memory grows with the
        # dataset size.
        grad_gen = autograd.grad(gen_loss, gen_params, create_graph=True)
        grad_dis = autograd.grad(dis_loss, dis_params, create_graph=True)

        # Weight by batch size so a short final batch is averaged correctly.
        grad_gen_epoch = [acc + g * batch_size for acc, g in zip(grad_gen_epoch, grad_gen)]
        grad_dis_epoch = [acc + g * batch_size for acc, g in zip(grad_dis_epoch, grad_dis)]
        n_data += batch_size

    if n_data == 0:
        raise ValueError("dataloader yielded no samples; cannot average gradients")

    grad_gen_epoch = [g / n_data for g in grad_gen_epoch]
    grad_dis_epoch = [g / n_data for g in grad_dis_epoch]

    # Per-player Jacobians are Hessians (symmetric) -> eigsh; the game Jacobian is not.
    gen_eigs = _top_eigenvalues(grad_gen_epoch, gen_params, n_eigs, imaginary, symmetric=True)
    dis_eigs = _top_eigenvalues(grad_dis_epoch, dis_params, n_eigs, imaginary, symmetric=True)
    game_eigs = _top_eigenvalues(grad_gen_epoch + grad_dis_epoch,
                                 gen_params + dis_params,
                                 n_eigs, imaginary, symmetric=False)

    if verbose:
        print(gen_eigs[:5])
        print(dis_eigs[:5])
        print(game_eigs[:5])
        print("Time to finish: %.2f minutes" % ((time.time() - start_time) / 60.))

    return gen_eigs, dis_eigs, game_eigs