def inverse_hvp_lissa()

in sample_info/modules/influence_functions.py [0:0]


def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass whenever one ends.

    Raises:
        ValueError: if a complete pass over ``loader`` yields no batches (an empty
            dataset would otherwise loop forever / silently do nothing).
    """
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            raise ValueError('inverse_hvp_lissa: dataset yielded no batches')


def inverse_hvp_lissa(model, dataset, v, batch_size=128, scale=10,
                      damping=0.0, num_samples=1, recursion_depth=1000, num_workers=0):
    """Approximate an inverse-Hessian-vector product with the LiSSA recursion.

    Each of ``num_samples`` independent runs starts from ``cur = v`` and applies
    ``recursion_depth`` updates

        cur <- v + (1 - damping) * cur - (H_b cur) / scale

    where ``H_b cur`` is the Hessian-vector product of the model's summed loss on
    one shuffled mini-batch. If the recursion converges, its fixed point satisfies
    ``cur / scale = (H + damping * scale * I)^{-1} v``. Convergence requires the
    eigenvalues of ``(1 - damping) I - H / scale`` to lie in (-1, 1), which is why
    ``scale`` must be large relative to the Hessian's spectrum. Each run's final
    ``cur`` is divided by ``scale``, and the runs are averaged.

    reference: https://github.com/kohpangwei/influence-release/blob/master/influence/genericNeuralNet.py#L475

    Args:
        model: object exposing ``forward(inputs=..., labels=...)``,
            ``compute_loss(inputs=..., labels=..., outputs=...)`` returning
            ``(losses_dict, _)``, and ``parameters()``. It is switched to eval mode
            and left there.
        dataset: dataset whose batches unpack as ``(x, y)``.
        v: sequence of tensors combined elementwise with the Hessian-vector
            products. Presumably it has one entry per ``model.parameters()`` entry
            with matching shapes; TODO confirm against ``hessian_vector_product``.
        batch_size: mini-batch size for each Hessian-vector product.
        scale: divisor applied to the Hessian term and to the final estimate.
        damping: damping coefficient in the recursion (see above).
        num_samples: number of independent recursion runs to average; must be >= 1.
        recursion_depth: number of updates per run.
        num_workers: ``DataLoader`` worker count.

    Returns:
        list of tensors: the averaged estimate, aligned with ``v``.

    Raises:
        ValueError: if ``num_samples < 1`` or the dataset yields no batches.
    """
    if num_samples < 1:
        raise ValueError(f'num_samples must be >= 1, got {num_samples}')

    model.eval()
    params = tuple(model.parameters())
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # One persistent stream of batches for all runs. Re-creating the loader
    # iterator on every step (as `for ...: break` does) reshuffles from scratch
    # and, with num_workers > 0, restarts worker processes each time.
    batches = _endless_batches(loader)

    inverse_hvp = None
    for _sample in range(num_samples):
        cur_estimate = list(v)  # list(...) also accepts a tuple `v`

        for _step in tqdm(range(recursion_depth), desc='computing inverse hvp'):
            x, y = next(batches)
            with torch.set_grad_enabled(True):
                outputs = model.forward(inputs=[x], labels=[y])
                losses, _ = model.compute_loss(inputs=[x], labels=[y], outputs=outputs)
                total_loss = sum(losses.values())
                hv = hessian_vector_product(loss=total_loss, params=params, v=cur_estimate)

            # The recurrence needs no gradients. Running it without autograd keeps
            # the estimate from accumulating a graph across recursion_depth steps
            # if `v` or `hv` happen to carry history.
            with torch.no_grad():
                cur_estimate = [a + (1 - damping) * b - c / scale
                                for a, b, c in zip(v, cur_estimate, hv)]

        if inverse_hvp is None:
            inverse_hvp = [b / scale for b in cur_estimate]
        else:
            inverse_hvp = [a + b / scale for a, b in zip(inverse_hvp, cur_estimate)]

    return [a / num_samples for a in inverse_hvp]