in sample_info/modules/stability.py [0:0]
def test_pred_stability(t, n, eta, ntk, test_train_ntk, train_init_preds, test_init_preds,
                        train_Y, l2_reg_coef=0.0, continuous=False):
    """Measure how test predictions change when each training sample is left out.

    First computes test predictions at time ``t`` using all ``n`` training samples.
    Then, for every ``sample_idx`` in ``range(n)``, removes that sample from the
    kernel quantities and recomputes the predictions via
    ``get_test_predictions_at_time_t``. The two sets of predictions are compared.

    Args:
        t: Time argument forwarded to ``get_test_predictions_at_time_t``.
        n: Number of training samples. Sample ``i`` owns rows/columns
            ``i * n_outputs`` .. ``(i + 1) * n_outputs - 1`` of ``ntk`` and the same
            columns of ``test_train_ntk``; the indexing below relies on this layout.
        eta: Learning rate for the full-data run. Leave-one-out runs use
            ``eta * n / (n - 1)``.
        ntk: Square train-train kernel (its first dimension is used for the
            regularizer identity).
        test_train_ntk: Test-train kernel whose columns follow the per-sample
            layout of ``ntk``.
        train_init_preds: Initial training predictions, indexed by sample along
            dim 0; the last dimension gives ``n_outputs``.
        test_init_preds: Initial test predictions, forwarded unchanged.
        train_Y: Training targets, indexed by sample along dim 0.
        l2_reg_coef: If positive, ``l2_reg_coef * I`` is added to ``ntk`` before
            the inverse and all leave-one-out sub-kernels are computed.
        continuous: Forwarded to ``get_test_predictions_at_time_t``.

    Returns:
        ``(change_vectors, change_quantities)``, two lists of length ``n``.
        ``change_vectors[i]`` is ``new_preds - old_preds`` with sample ``i`` removed.
        ``change_quantities[i]`` is that difference squared, summed over dim 1 and
        averaged over dim 0. Assuming predictions are ``(n_test, n_outputs)`` (TODO
        confirm), this is the mean per-test-example squared L2 change.

    Raises:
        ValueError: If ``n == 1``. Removing the only sample leaves nothing to
            train on, and the learning-rate rescaling would divide by zero.
    """
    if n == 1:
        raise ValueError("test_pred_stability needs at least 2 training samples (got n=1)")

    if l2_reg_coef > 0:
        # Build the identity in the kernel's own dtype. This keeps the coefficient from
        # being rounded to float32 and avoids silently promoting low-precision kernels.
        ntk = ntk + l2_reg_coef * torch.eye(ntk.shape[0], dtype=ntk.dtype, device=ntk.device)
    ntk_inv = torch.inverse(ntk)

    old_preds = get_test_predictions_at_time_t(t=t, eta=eta,
                                               ntk=ntk,
                                               test_train_ntk=test_train_ntk,
                                               train_Y=train_Y,
                                               train_init_preds=train_init_preds,
                                               test_init_preds=test_init_preds,
                                               continuous=continuous,
                                               ntk_inv=ntk_inv)

    n_outputs = train_init_preds.shape[-1]
    total_outputs = n * n_outputs
    # With one of n samples dropped, the learning rate is rescaled by n / (n - 1).
    # The value is loop-invariant, so it is computed once.
    loo_eta = eta * n / (n - 1)

    change_vectors = []
    change_quantities = []
    for sample_idx in tqdm(range(n)):
        # Sample indices kept after removing sample_idx (ascending order).
        keep_samples = list(range(sample_idx)) + list(range(sample_idx + 1, n))
        # Kernel rows/columns kept: everything except sample_idx's n_outputs-wide block.
        start, stop = sample_idx * n_outputs, (sample_idx + 1) * n_outputs
        keep_outputs = list(range(start)) + list(range(stop, total_outputs))

        # List (advanced) indexing already returns a copy, so no clone() is needed.
        new_ntk = ntk[keep_outputs][:, keep_outputs]
        new_test_train_ntk = test_train_ntk[:, keep_outputs]
        # Derive the reduced inverse from the full one. Presumably misc.update_ntk_inv
        # avoids a fresh inversion; see its definition.
        new_ntk_inv = misc.update_ntk_inv(ntk=ntk, ntk_inv=ntk_inv, keep_indices=keep_outputs)

        new_preds = get_test_predictions_at_time_t(
            t=t,
            eta=loo_eta,
            train_Y=train_Y[keep_samples],
            train_init_preds=train_init_preds[keep_samples],
            test_init_preds=test_init_preds,
            continuous=continuous,
            ntk=new_ntk,
            ntk_inv=new_ntk_inv,
            test_train_ntk=new_test_train_ntk)

        diff = new_preds - old_preds
        change_vectors.append(diff)
        change_quantities.append(torch.sum(diff ** 2, dim=1).mean(dim=0))
    return change_vectors, change_quantities