in sample_info/modules/ntk.py [0:0]
def compute_jacobian(self, model, dataset, cpu=True, description="", output_key='pred',
                     max_num_examples=2 ** 30, num_workers=0, seed=42, **kwargs):
    """Compute per-example, per-output Jacobians of model predictions w.r.t. parameters.

    Iterates over `dataset` one example at a time, and for each scalar output of the
    model takes the gradient with respect to every model parameter. Depending on
    `self.projection`, each gradient is stored as-is ('none'), projected onto a random
    subset of coordinates ('random-subset'), or projected with a very sparse random
    matrix ('very-sparse').

    Args:
        model: module whose `forward(inputs=..., labels=..., loader=..., **kwargs)`
            returns a dict containing `output_key`. Presumably predictions have shape
            (batch, n_outputs) — TODO confirm against callers.
        dataset: indexable dataset yielding (inputs, labels) pairs.
        cpu: if True, move each gradient to CPU before storing/projecting.
        description: progress-bar text.
        output_key: key of the prediction tensor in the model's output dict.
        max_num_examples: cap on the number of examples processed.
        num_workers: DataLoader workers; >0 switches multiprocessing to spawn mode.
        seed: numpy RNG seed (used by the random projection helpers).
        **kwargs: forwarded to `model.forward`.

    Returns:
        dict mapping parameter name -> tensor of stacked (possibly projected)
        gradients with first dimension n_examples * n_outputs.

    Raises:
        ValueError: if `self.projection` is not one of
            'none', 'random-subset', 'very-sparse'.
    """
    np.random.seed(seed)
    model.eval()
    if num_workers > 0:
        # spawn + file_system sharing avoids fd/fork issues with multiprocess loading
        torch.multiprocessing.set_sharing_strategy('file_system')
        torch.multiprocessing.set_start_method('spawn', force=True)

    n_examples = min(len(dataset), max_num_examples)
    loader = DataLoader(dataset=Subset(dataset, range(n_examples)),
                        batch_size=1, shuffle=False, num_workers=num_workers)

    jacobians = defaultdict(list)
    n_outputs = None
    # loop over the dataset, one example per batch
    for inputs_batch, labels_batch in tqdm(loader, desc=description):
        if isinstance(inputs_batch, torch.Tensor):
            inputs_batch = [inputs_batch]
        if not isinstance(labels_batch, list):
            labels_batch = [labels_batch]

        # grad mode must be on even if the caller wrapped us in no_grad,
        # otherwise preds would be detached from the graph
        with torch.set_grad_enabled(True):
            outputs = model.forward(inputs=inputs_batch, labels=labels_batch, loader=loader, **kwargs)
            preds = outputs[output_key][0]  # batch_size == 1: take the single example
        n_outputs = preds.shape[-1]

        for output_idx in range(n_outputs):
            # keep the graph alive until the last output dimension has been used
            retain_graph = (output_idx != n_outputs - 1)
            with torch.set_grad_enabled(True):
                cur_jacobians = torch.autograd.grad(preds[output_idx], model.parameters(),
                                                    retain_graph=retain_graph)
            if cpu:
                cur_jacobians = [utils.to_cpu(v) for v in cur_jacobians]

            # NOTE(review): the _prepare_* helpers are called per output/example;
            # presumably they are idempotent after the first call — confirm before hoisting.
            if self.projection == 'none':
                for (k, _), v in zip(model.named_parameters(), cur_jacobians):
                    jacobians[k].append(v)
            elif self.projection == 'random-subset':
                self._prepare_random_subset_proj_indices(model.named_parameters())
                for (k, _), v in zip(model.named_parameters(), cur_jacobians):
                    v = v.flatten()
                    n_select = len(self._random_subset_proj_indices[k])
                    # rescale so squared norms are preserved in expectation
                    v_proj = v[self._random_subset_proj_indices[k]] * np.sqrt(v.shape[0] / n_select)
                    jacobians[k].append(v_proj)
            elif self.projection == 'very-sparse':
                self._prepare_very_sparse_proj_matrix(model.named_parameters())
                for (k, _), v in zip(model.named_parameters(), cur_jacobians):
                    # now that the projection matrix is ready, we can project v into the smaller subspace
                    v = v.flatten()
                    v_proj = self._very_sparse_proj_matrix[k].T.dot(utils.to_numpy(v))
                    v_proj = torch.tensor(v_proj, dtype=v.dtype, device=v.device)
                    jacobians[k].append(v_proj)
            else:
                # previously an unknown mode silently produced an empty result
                raise ValueError(f"Unknown projection mode: {self.projection!r}")

    for k in jacobians:
        jacobians[k] = torch.stack(jacobians[k])  # n_samples * n_outputs x n_params
        assert len(jacobians[k]) == n_outputs * n_examples
    return jacobians