in sample_info/modules/sgd.py
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from sample_info.modules import utils  # assumed sibling module providing to_cpu

def get_sgd_covariance_diagonal(model, dataset, cpu=True, max_num_examples=2**30, num_workers=0, seed=42, **kwargs):
    r""" Returns the diagonal of the SGD noise covariance matrix, estimated from per-example gradients.
    The formula is \Sigma = \frac{1}{n} \sum_{i=1}^n g_i g_i^T - \bar{g} \bar{g}^T, where g_i is the gradient
    of the loss on the i-th example and \bar{g} = \frac{1}{n} \sum_{i=1}^n g_i is the mean (full-batch) gradient.
    Note that weight decay can be ignored here: it shifts every g_i and \bar{g} by the same amount, so it
    cancels out of the covariance.
    """
    np.random.seed(seed)
    model.eval()

    if num_workers > 0:
        torch.multiprocessing.set_sharing_strategy('file_system')
        torch.multiprocessing.set_start_method('spawn', force=True)

    # Use batch_size=1 so that each gradient computed below is the gradient of a single example.
    n_examples = min(len(dataset), max_num_examples)
    loader = DataLoader(dataset=Subset(dataset, range(n_examples)),
                        batch_size=1, shuffle=False, num_workers=num_workers)

    # Running sums of per-example gradients and of their element-wise squares, keyed by
    # parameter name; storing only these sums keeps memory at O(P) rather than O(n * P).
    grad_sum = defaultdict(lambda: None)
    grad_squared_sum = defaultdict(lambda: None)
    # loop over the dataset, one example at a time
    for inputs_batch, labels_batch in tqdm(loader, desc='Computing SGD noise covariance...'):
        # standardize inputs and labels to lists, as expected by the model interface
        if isinstance(inputs_batch, torch.Tensor):
            inputs_batch = [inputs_batch]
        if not isinstance(labels_batch, list):
            labels_batch = [labels_batch]

        with torch.set_grad_enabled(True):
            outputs = model.forward(inputs=inputs_batch, labels=labels_batch, loader=loader, **kwargs)
            batch_losses, outputs = model.compute_loss(inputs=inputs_batch, labels=labels_batch, outputs=outputs,
                                                       loader=loader, dataset=loader.dataset)
            batch_total_loss = sum(batch_losses.values())
            grad = torch.autograd.grad(batch_total_loss, model.parameters())

        if cpu:
            grad = [utils.to_cpu(v) for v in grad]
        # accumulate the first and second moments of the per-example gradients
        for (k, _), v in zip(model.named_parameters(), grad):
            if grad_sum[k] is None:
                grad_sum[k] = v.clone()
                grad_squared_sum[k] = v ** 2
            else:
                grad_sum[k] += v
                grad_squared_sum[k] += v ** 2
    # diag(Sigma) = E[g^2] - E[g]^2, computed element-wise for each parameter tensor
    out = dict()
    for k in grad_sum.keys():
        out[k] = grad_squared_sum[k] / n_examples - (grad_sum[k] / n_examples) ** 2
    return out
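

# A minimal sanity check of the diagonal formula above (an illustrative sketch, not part of
# the original module). For per-example gradients stacked into a matrix g of shape (n, p),
# the diagonal of Sigma reduces to mean(g**2, dim=0) - mean(g, dim=0)**2, which we compare
# against the diagonal of the population covariance computed directly.
def _check_diagonal_formula():
    torch.manual_seed(0)
    n, p = 100, 5
    g = torch.randn(n, p)  # stand-in for n per-example gradient vectors
    diag = (g ** 2).mean(dim=0) - g.mean(dim=0) ** 2
    g_centered = g - g.mean(dim=0, keepdim=True)
    full_cov = g_centered.T @ g_centered / n  # population covariance, shape (p, p)
    assert torch.allclose(diag, torch.diagonal(full_cov), atol=1e-6)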
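

# Hypothetical usage sketch (all names below are placeholders; the model must implement the
# forward(inputs=..., labels=..., loader=...) and compute_loss(...) interface assumed above):
#
#   diag = get_sgd_covariance_diagonal(model, dataset, max_num_examples=1000)
#   trace = sum(v.sum().item() for v in diag.values())  # tr(Sigma), a common scalar summary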