in sample_info/modules/ntk.py [0:0]
from collections import defaultdict

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# NOTE: compute_exp_matrix (used below) is assumed to be defined elsewhere in this module.

def get_weights_at_time_t(t, eta, init_params, ntk, init_preds, Y, jacobians=None, continuous=False,
                          ntk_inv=None, large_model_regime=False, model=None, dataset=None, batch_size=256):
    """ Computes the weights at time t. Since storing the full Jacobian can be impossible for large networks,
    the code can work in two regimes (selected by `large_model_regime`):
        1. [small model regime] the full Jacobian should be passed using the `jacobians` argument;
        2. [large model regime] `model`, `dataset` (and optionally `example_indices`) should be specified,
           so we can go over the examples and compute their Jacobians one by one.
    :param t: time/iteration. If t=None, the final weights are returned (assuming the ntk is invertible).
    :param eta: the learning rate. NOTE: if the network's loss is a mean over examples (instead of a sum),
           then eta = learning_rate / n_samples.
    :param init_params: a dictionary containing the parameters at initialization.
    :param jacobians: a dictionary of Jacobians at initialization.
    :param ntk: the neural tangent kernel.
    :param init_preds: predictions at initialization  # (n_samples, n_outputs)
    :param Y: labels  # (n_samples, n_outputs)
    :param continuous: True => continuous GD, False => discrete GD.
    :param ntk_inv: the inverse of the ntk matrix. This argument is optional.
    :param large_model_regime: whether to use the large model regime (True) or the small model regime (False).
    # the following parameters are used in the large model regime only.
    :param model: the model.
    :param dataset: the training dataset for which the ntk was computed.
    :param batch_size: the batch_size argument to pass to the DataLoader.
    """
    # check that the needed arguments are not None
    if large_model_regime:
        assert (model is not None) and (dataset is not None)
    else:
        assert (jacobians is not None)
    exp_matrix = compute_exp_matrix(t=t, eta=eta, ntk=ntk, continuous=continuous)
    # compute the ntk inverse and multiply it with (I - exp_matrix) (f(X) - Y)
    if ntk_inv is None:
        ntk_inv = torch.inverse(ntk)
    init_preds = init_preds.reshape((-1, 1))  # (n_samples * n_outputs, 1)
    Y = Y.reshape((-1, 1))  # (n_samples * n_outputs, 1)
    identity_matrix = torch.eye(ntk.shape[0], dtype=torch.float, device=ntk.device)
    rhs_vector = -torch.mm(ntk_inv, torch.mm(identity_matrix - exp_matrix, init_preds - Y))
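    # The line above follows the linearized (NTK) dynamics: with J the Jacobian at initialization and
    # K = J J^T the ntk, the weight change is w(t) - w(0) = -J^T K^{-1} (I - exp_matrix) (f_0(X) - Y).
    # rhs_vector therefore holds everything except the leading J^T factor, which is applied below.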
    # Now we need to compute jacobians * rhs_vector. This corresponds to the sum of all output gradients weighted
    # with the rhs_vector coefficients. Therefore, we can go over (example, output) pairs and sum their gradients.
    out = defaultdict(lambda: None)
    # if the Jacobians are already computed, we can just use them
    if not large_model_regime:
        for k, v in jacobians.items():
            v = v.reshape((v.shape[0], -1))  # (n_samples * n_outputs, n_dim)
            out[k] = torch.mm(v.T, rhs_vector)[:, 0]  # (n_dim,)
    else:  # we need to go over the examples and compute the Jacobians
        model.eval()
        loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
        # We want to compute \frac{\partial preds}{\partial W} and take a linear combination of its columns
        # with the coefficients given by rhs_vector. To avoid going over all examples/outputs one by one,
        # we can instead compute \frac{\partial <preds, coefficients>}{\partial W}.
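        # This works because \frac{\partial <preds, coefficients>}{\partial W}
        # = \sum_i coefficients_i \frac{\partial preds_i}{\partial W} = J_batch^T coefficients,
        # i.e. a single backward pass per batch yields the required linear combination of Jacobian rows.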
        # loop over the dataset
        example_output_index = 0
        for inputs_batch, labels_batch in tqdm(loader):
            if isinstance(inputs_batch, torch.Tensor):
                inputs_batch = [inputs_batch]
            if not isinstance(labels_batch, list):
                labels_batch = [labels_batch]
            with torch.set_grad_enabled(True):
                outputs = model.forward(inputs=inputs_batch, labels=labels_batch, loader=loader)
                preds = outputs['pred']
                preds = preds.reshape((-1, 1))  # (n_examples * n_outputs, 1)
                coefficients = rhs_vector[example_output_index:example_output_index + preds.shape[0]].to(preds.device)
                loss = torch.sum(preds * coefficients)
                cur_jacobians = torch.autograd.grad(loss, model.parameters())
            for (k, _), v in zip(model.named_parameters(), cur_jacobians):
                v = v.detach().to(ntk.device).flatten()
                if out[k] is None:
                    out[k] = v
                else:
                    out[k] += v
            example_output_index += preds.shape[0]
        assert example_output_index == rhs_vector.shape[0]
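    # at this point out[k] holds the flattened weight change J_k^T @ rhs_vector for each parameter k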
    # add the parameter values at initialization and restore the original parameter shapes
    for k, v in init_params.items():
        out[k] += v.flatten()
        out[k] = out[k].reshape(v.shape)
    return out
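

if __name__ == "__main__":
    # Minimal usage sketch for the small model regime. The toy model, data, and the way the Jacobian
    # and ntk are constructed below are illustrative assumptions and are not part of the original module.
    torch.manual_seed(0)
    n_samples, n_dim, n_outputs = 5, 20, 1
    X = torch.randn(n_samples, n_dim)
    labels = torch.randn(n_samples, n_outputs)
    linear = torch.nn.Linear(n_dim, n_outputs, bias=False)

    # For a bias-free linear model f(x) = W x with a single output, the Jacobian of the predictions
    # w.r.t. the flattened weights is X itself, and the ntk is X X^T (full rank since n_samples < n_dim).
    toy_jacobians = {'weight': X.clone()}  # (n_samples * n_outputs, n_dim)
    toy_ntk = torch.mm(toy_jacobians['weight'], toy_jacobians['weight'].T)
    toy_init_params = {'weight': linear.weight.detach().clone()}
    toy_init_preds = linear(X).detach()  # (n_samples, n_outputs)

    # eta = learning_rate / n_samples, assuming the training loss is a mean over examples.
    weights_at_t = get_weights_at_time_t(
        t=100,
        eta=0.01 / n_samples,
        init_params=toy_init_params,
        ntk=toy_ntk,
        init_preds=toy_init_preds,
        Y=labels,
        jacobians=toy_jacobians,
        continuous=False,
        large_model_regime=False,
    )
    print({k: tuple(v.shape) for k, v in weights_at_t.items()})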