def test_pred_stability()

in sample_info/modules/stability.py [0:0]


def test_pred_stability(t, n, eta, ntk, test_train_ntk, train_init_preds, test_init_preds,
                        train_Y, l2_reg_coef=0.0, continuous=False):
    """Measure how much test predictions change when each training example is left out.

    Test predictions at time ``t`` are computed with ``get_test_predictions_at_time_t`` once
    using all ``n`` training examples, and then once per training example with that example
    removed (leave-one-out). The NTK rows/columns are laid out example-major: example ``i``
    owns indices ``i * n_outputs`` through ``(i + 1) * n_outputs - 1``, where
    ``n_outputs = train_init_preds.shape[-1]``.

    Args:
        t: Time argument forwarded to ``get_test_predictions_at_time_t``.
        n: Number of training examples; each index in ``range(n)`` is left out in turn.
        eta: Learning rate for the full training set. The leave-one-out runs use
            ``eta * n / (n - 1)``.
        ntk: Square train-train NTK indexed by (example, output) pairs as described above.
        test_train_ntk: Test-train NTK whose columns use the same layout as ``ntk``.
        train_init_preds: Initial predictions on the training set, indexed by example along
            dim 0; its last dim gives ``n_outputs``.
        test_init_preds: Initial predictions on the test set (unchanged across runs).
        train_Y: Training targets, indexed by example along dim 0.
        l2_reg_coef: If positive, this value is added to the diagonal of ``ntk`` before it is
            inverted and used for every prediction.
        continuous: Forwarded to ``get_test_predictions_at_time_t``.

    Returns:
        ``(change_vectors, change_quantities)``, two lists of length ``n``.
        ``change_vectors[i]`` is ``new_preds - old_preds`` after removing example ``i``.
        ``change_quantities[i]`` is that difference squared, summed over dim 1 and averaged
        over dim 0 (presumably the mean squared L2 change per test sample, assuming predictions
        are shaped (n_test, n_outputs); confirm against ``get_test_predictions_at_time_t``).

    Raises:
        ValueError: If ``n == 1``. Leaving out the only example leaves nothing to train on,
            and the ``n / (n - 1)`` learning-rate rescaling is undefined.
    """
    if n == 1:
        raise ValueError("test_pred_stability needs at least 2 training examples (got n=1)")

    if l2_reg_coef > 0:
        # Build the identity in ntk's own dtype/device. A hard-coded float32 identity would
        # silently upcast lower-precision kernels.
        ntk = ntk + l2_reg_coef * torch.eye(ntk.shape[0], dtype=ntk.dtype, device=ntk.device)

    ntk_inv = torch.inverse(ntk)

    old_preds = get_test_predictions_at_time_t(t=t, eta=eta,
                                               ntk=ntk,
                                               test_train_ntk=test_train_ntk,
                                               train_Y=train_Y,
                                               train_init_preds=train_init_preds,
                                               test_init_preds=test_init_preds,
                                               continuous=continuous,
                                               ntk_inv=ntk_inv)

    n_outputs = train_init_preds.shape[-1]
    all_output_indices = list(range(n * n_outputs))
    change_vectors = []
    change_quantities = []
    for sample_idx in tqdm(range(n)):
        # Every example except sample_idx, in ascending order.
        example_indices = list(range(sample_idx)) + list(range(sample_idx + 1, n))
        # Every NTK row/column except sample_idx's n_outputs-wide block, in ascending order.
        start, stop = sample_idx * n_outputs, (sample_idx + 1) * n_outputs
        example_output_indices = all_output_indices[:start] + all_output_indices[stop:]

        # Indexing with a list (advanced indexing) already returns a copy, so ntk is untouched.
        new_ntk = ntk[example_output_indices][:, example_output_indices]
        new_test_train_ntk = test_train_ntk[:, example_output_indices]
        new_ntk_inv = misc.update_ntk_inv(ntk=ntk, ntk_inv=ntk_inv, keep_indices=example_output_indices)

        new_train_init_preds = train_init_preds[example_indices]
        new_train_Y = train_Y[example_indices]

        new_preds = get_test_predictions_at_time_t(
            t=t,
            # NOTE(review): presumably rescaled so the per-example step size is unchanged when
            # the loss is averaged over n - 1 instead of n examples; confirm.
            eta=eta * n / (n - 1),
            train_Y=new_train_Y,
            train_init_preds=new_train_init_preds,
            test_init_preds=test_init_preds,
            continuous=continuous,
            ntk=new_ntk,
            ntk_inv=new_ntk_inv,
            test_train_ntk=new_test_train_ntk)

        diff = new_preds - old_preds
        change_vectors.append(diff)
        change_quantities.append(torch.sum(diff ** 2, dim=1).mean(dim=0))

    return change_vectors, change_quantities