# Function: training_loss_stability
# Source: sample_info/modules/stability.py

def training_loss_stability(ts, n, eta, ntk, init_preds, Y, l2_reg_coef=0.0, continuous=False):
    """Measure leave-one-out stability of the training loss trajectory.

    For each of the ``n`` training samples, the loss curve over the time
    steps ``ts`` is recomputed with that sample excluded and compared to
    the curve obtained on the full training set.

    :param ts: iterable of time steps at which to evaluate the training loss.
    :param n: number of training samples.
    :param eta: learning rate. When a sample is excluded it is rescaled by
        ``n / (n - 1)`` — presumably to compensate for the smaller dataset
        size in a mean-reduced loss; confirm against
        ``compute_training_loss_at_time_t``.
    :param ntk: square NTK Gram matrix of shape
        ``(n * n_outputs, n * n_outputs)``, with the ``n_outputs``
        rows/columns of each sample stored contiguously.
    :param init_preds: initial predictions; the first dimension indexes
        samples and the last dimension has size ``n_outputs``.
    :param Y: training targets, indexed by sample along the first dimension.
    :param l2_reg_coef: if positive, L2 regularization strength added to the
        NTK diagonal before computing any losses.
    :param continuous: forwarded to ``compute_training_loss_at_time_t``
        (continuous- vs. discrete-time training dynamics).
    :returns: tuple ``(change_vectors, change_quantities)`` —
        ``change_vectors[i]`` is the per-time-step loss difference caused by
        excluding sample ``i``; ``change_quantities[i]`` is the mean of its
        squared entries.
    """
    if l2_reg_coef > 0:
        ntk = ntk + l2_reg_coef * torch.eye(ntk.shape[0], dtype=torch.float, device=ntk.device)

    # Baseline: loss curve on the full training set.
    losses_without_excluding = torch.stack([
        compute_training_loss_at_time_t(t=t,
                                        eta=eta,
                                        ntk=ntk,
                                        init_preds=init_preds,
                                        Y=Y,
                                        continuous=continuous)
        for t in ts])

    n_outputs = init_preds.shape[-1]
    change_quantities = []
    change_vectors = []
    for sample_idx in tqdm(range(n)):
        example_indices = [i for i in range(n) if i != sample_idx]
        # Each sample owns n_outputs contiguous rows/columns of the NTK.
        example_output_indices = [i * n_outputs + k
                                  for i in example_indices
                                  for k in range(n_outputs)]

        # Advanced (list) indexing already returns a copy, so no .clone()
        # is needed before slicing out the sample's rows and columns.
        new_ntk = ntk[example_output_indices][:, example_output_indices]
        new_init_preds = init_preds[example_indices]
        new_Y = Y[example_indices]

        losses = torch.stack([
            compute_training_loss_at_time_t(t=t,
                                            eta=eta * n / (n - 1),
                                            ntk=new_ntk,
                                            init_preds=new_init_preds,
                                            Y=new_Y,
                                            continuous=continuous)
            for t in ts])

        # Compute the difference once; it is used for both outputs.
        diff = losses - losses_without_excluding
        change_vectors.append(diff)
        change_quantities.append(torch.mean(diff ** 2))

    return change_vectors, change_quantities