in sample_info/modules/stability.py [0:0]
def weight_stability(t, n, eta, init_params, jacobians, ntk, init_preds, Y, continuous=False, without_sgd=True,
                     l2_reg_coef=0.0, large_model_regime=False, model=None, dataset=None, return_change_vectors=True,
                     **kwargs):
    """
    Measure how far the weights at time ``t`` move when each training sample is left out.

    The weights for the full training set and for each of the ``n`` leave-one-out subsets are
    obtained from ``get_weights_at_time_t``. For every left-out sample the difference
    ``new_weights - old_weights`` is recorded together with a scalar summary of its size.

    :param t: time passed through to ``get_weights_at_time_t``.
    :param n: number of training samples; one leave-one-out run is performed per sample.
    :param eta: learning rate of the full problem. The leave-one-out runs use ``eta * n / (n - 1)``
        (presumably to compensate for averaging over one sample fewer -- confirm against
        ``get_weights_at_time_t``).
    :param init_params: initial parameters, passed through to ``get_weights_at_time_t``.
    :param jacobians: dict of tensors whose first dimension is indexed by (sample, output) pairs,
        i.e. has ``n * n_outputs`` rows. Not subset when ``large_model_regime`` is True.
    :param ntk: square kernel matrix over the same ``n * n_outputs`` (sample, output) pairs.
    :param init_preds: initial predictions; first dimension indexes samples, last dimension
        gives ``n_outputs``.
    :param Y: targets; first dimension indexes samples.
    :param continuous: passed through to ``get_weights_at_time_t``.
    :param without_sgd: if True, the per-sample quantity is the squared Euclidean norm
        ``||w1 - w2||^2`` summed over all parameters. Otherwise it is
        ``(w1 - w2)^T S^{-1} (w1 - w2)``, where ``S`` solves the continuous Lyapunov equation
        ``H S + S H^T = Sigma`` with ``H = J^T J + l2_reg_coef * I`` and ``Sigma`` the SGD noise
        covariance evaluated at the full-data weights.
    :param l2_reg_coef: if positive, added to the diagonal of ``ntk``; always added (scaled
        identity) to ``H`` in the SGD case.
    :param large_model_regime: if True, jacobians are not subset for the leave-one-out runs
        (``None`` is passed instead). Incompatible with ``without_sgd=False``.
    :param model: model used to evaluate the SGD covariance; required when ``without_sgd=False``.
    :param dataset: training dataset; a ``Subset`` without the left-out sample is passed to each
        leave-one-out run. Required when ``without_sgd=False``.
    :param return_change_vectors: if False, the per-sample weight differences are not collected.
    :param kwargs: forwarded to ``get_weights_at_time_t`` and ``get_sgd_covariance_full``.
    :return: ``(change_vectors, change_quantities)``. ``change_vectors`` is a list with one dict
        per sample mapping parameter name to the CPU tensor ``new_weights - old_weights`` (empty
        list if ``return_change_vectors`` is False). ``change_quantities`` is a list with one
        scalar per sample (a 1x1 tensor in the SGD case).
    :raises ValueError: if ``without_sgd=False`` is combined with ``large_model_regime=True`` or
        with a missing ``model``/``dataset``, or if ``n == 1``.
    """
    # Validate up front so that a bad configuration fails before any expensive linear algebra.
    if not without_sgd:
        if large_model_regime:
            raise ValueError("SGD formula works only for small models")
        if model is None or dataset is None:
            raise ValueError("model and dataset must be provided when without_sgd=False")
    if n == 1:
        # Leaving out the only sample gives an empty problem and a division by zero in the
        # learning-rate rescaling below.
        raise ValueError("weight_stability requires at least 2 samples for leave-one-out, got n=1")

    if l2_reg_coef > 0:
        ntk = ntk + l2_reg_coef * torch.eye(ntk.shape[0], dtype=torch.float, device=ntk.device)

    ntk_inv = torch.inverse(ntk)
    old_weights = get_weights_at_time_t(t=t, eta=eta, init_params=init_params, jacobians=jacobians,
                                        ntk=ntk, ntk_inv=ntk_inv, init_preds=init_preds, Y=Y,
                                        continuous=continuous, large_model_regime=large_model_regime,
                                        model=model, dataset=dataset, **kwargs)

    steady_state_inv_cov = None
    if not without_sgd:
        # SGD noise covariance evaluated at the full-data weights.
        with utils.SetTemporaryParams(model=model, params=old_weights):
            sgd_cov = get_sgd_covariance_full(model=model, dataset=dataset, cpu=False, **kwargs)
        # Small isotropic jitter to keep sgd_cov well-conditioned.
        sgd_cov += 1e-10 * torch.eye(sgd_cov.shape[0], device=sgd_cov.device, dtype=torch.float)

        # H = J^T J + l2_reg_coef * I, with J of shape (n_samples * n_outputs, n_params).
        jacobians_cat = torch.cat([v.view((v.shape[0], -1)) for v in jacobians.values()], dim=1)
        H = torch.mm(jacobians_cat.T, jacobians_cat) + l2_reg_coef * torch.eye(jacobians_cat.shape[1],
                                                                                device=ntk.device,
                                                                                dtype=torch.float)

        # The steady-state covariance S solves H S + S H^T = sgd_cov.
        with utils.Timing(description="Solving the Lyapunov equation"):
            steady_state_cov = solve_continuous_lyapunov(a=utils.to_numpy(H), q=utils.to_numpy(sgd_cov))
        steady_state_cov = torch.tensor(steady_state_cov, dtype=torch.float, device=ntk.device)
        # Small isotropic jitter to make steady_state_cov invertible.
        steady_state_cov += 1e-10 * torch.eye(steady_state_cov.shape[0], device=steady_state_cov.device,
                                              dtype=torch.float)
        # The weight differences live on the CPU, so move the inverse there once instead of
        # copying it from the device on every leave-one-out iteration.
        steady_state_inv_cov = torch.inverse(steady_state_cov).cpu()

    change_vectors = []
    change_quantities = []
    n_outputs = init_preds.shape[-1]
    # Learning rate used for every leave-one-out run (loop-invariant).
    loo_eta = eta * n / (n - 1)

    for sample_idx in tqdm(range(n)):
        example_indices = [i for i in range(n) if i != sample_idx]
        # Rows/columns of the kernel and jacobians belonging to the kept samples.
        example_output_indices = [i * n_outputs + j for i in example_indices for j in range(n_outputs)]

        # Indexing with a list already returns a new tensor, so no explicit clone is needed.
        new_ntk = ntk[example_output_indices][:, example_output_indices]
        new_ntk_inv = misc.update_ntk_inv(ntk=ntk, ntk_inv=ntk_inv, keep_indices=example_output_indices)
        new_init_preds = init_preds[example_indices]
        new_Y = Y[example_indices]

        if large_model_regime:
            new_jacobians = None
        else:
            new_jacobians = {k: v[example_output_indices] for k, v in jacobians.items()}

        new_dataset = Subset(dataset, example_indices)

        new_weights = get_weights_at_time_t(t=t, eta=loo_eta, init_params=init_params,
                                            jacobians=new_jacobians, ntk=new_ntk, ntk_inv=new_ntk_inv,
                                            init_preds=new_init_preds, Y=new_Y, continuous=continuous,
                                            large_model_regime=large_model_regime, model=model,
                                            dataset=new_dataset, **kwargs)

        # Keep the differences on the CPU to save GPU memory.
        param_changes = {k: (new_weights[k] - old_weights[k]).cpu() for k in old_weights}

        if return_change_vectors:
            change_vectors.append(param_changes)

        if without_sgd:
            # Squared Euclidean norm of the weight change.
            total_change = 0.0
            for k in old_weights:
                total_change += torch.sum(param_changes[k] ** 2)
        else:
            # Quadratic form d^T S^{-1} d with the flattened weight change d.
            flat_change = torch.cat([v.flatten() for v in param_changes.values()], dim=0)
            total_change = torch.mm(flat_change.view((1, -1)),
                                    torch.mm(steady_state_inv_cov, flat_change.view(-1, 1)))

        change_quantities.append(total_change)

    return change_vectors, change_quantities