in sample_info/modules/stability.py [0:0]
def training_loss_stability(ts, n, eta, ntk, init_preds, Y, l2_reg_coef=0.0, continuous=False):
    """Measure how the training-loss curve changes when each sample is left out.

    For each of the ``n`` training samples, the NTK, initial predictions, and
    labels are restricted to the remaining ``n - 1`` samples, the loss curve
    over the time steps ``ts`` is recomputed, and the deviation from the
    full-data loss curve is recorded.

    :param ts: iterable of time steps at which to evaluate the training loss.
    :param n: number of training samples.
    :param eta: learning rate; rescaled by ``n / (n - 1)`` in the
        leave-one-out runs, presumably to keep the effective per-sample step
        size comparable — TODO confirm against the training dynamics.
    :param ntk: NTK matrix; the indexing below assumes shape
        ``(n * n_outputs, n * n_outputs)`` with each sample owning a
        contiguous block of ``n_outputs`` rows/columns — TODO confirm.
    :param init_preds: initial predictions, indexed by sample along the first
        dimension; last dimension holds the ``n_outputs`` outputs per sample.
    :param Y: training targets, indexed by sample along the first dimension.
    :param l2_reg_coef: if positive, ``l2_reg_coef * I`` is added to the NTK.
    :param continuous: forwarded to ``compute_training_loss_at_time_t``
        (continuous- vs. discrete-time dynamics, by the look of it).
    :returns: ``(change_vectors, change_quantities)``:
        ``change_vectors[i]`` is the per-time-step difference between the
        leave-``i``-out loss curve and the full-data one;
        ``change_quantities[i]`` is the mean of its squared entries.
    """
    if l2_reg_coef > 0:
        # L2 regularization enters the dynamics as a diagonal shift of the NTK.
        # This rebinds `ntk` locally; the caller's tensor is not mutated.
        ntk = ntk + l2_reg_coef * torch.eye(ntk.shape[0], dtype=torch.float, device=ntk.device)

    # Baseline loss curve with all n samples included.
    losses_without_excluding = torch.stack([
        compute_training_loss_at_time_t(t=t,
                                        eta=eta,
                                        ntk=ntk,
                                        init_preds=init_preds,
                                        Y=Y,
                                        continuous=continuous)
        for t in ts
    ])

    n_outputs = init_preds.shape[-1]
    change_quantities = []
    change_vectors = []
    for sample_idx in tqdm(range(n)):
        example_indices = [i for i in range(n) if i != sample_idx]
        # Sample i owns rows/columns [i * n_outputs, (i + 1) * n_outputs) of the NTK.
        example_output_indices = [i * n_outputs + k
                                  for i in example_indices
                                  for k in range(n_outputs)]
        # Advanced indexing with an integer list already returns a copy,
        # so no .clone() is needed before slicing out the submatrix.
        new_ntk = ntk[example_output_indices][:, example_output_indices]
        losses = torch.stack([
            compute_training_loss_at_time_t(t=t,
                                            eta=eta * n / (n - 1),
                                            ntk=new_ntk,
                                            init_preds=init_preds[example_indices],
                                            Y=Y[example_indices],
                                            continuous=continuous)
            for t in ts
        ])
        diff = losses - losses_without_excluding
        change_vectors.append(diff)
        change_quantities.append(torch.mean(diff ** 2))
    return change_vectors, change_quantities