in sample_info/modules/ntk.py [0:0]
def prepare_needed_items(model, train_data, test_data=None, projection='none', cpu=False,
                         batch_size=256, **kwargs):
    """Compute the quantities needed for linearized (NTK) analysis: per-example
    Jacobians, predictions and parameters at initialization, the train-train and
    test-train NTK Gram matrices, and the label tensors."""
    jacobian_estimator = JacobianEstimator(projection=projection, **kwargs)
    train_jacobians = jacobian_estimator.compute_jacobian(model=model, dataset=train_data,
                                                          output_key='pred', cpu=cpu)
    test_jacobians = None
    if test_data is not None:
        test_jacobians = jacobian_estimator.compute_jacobian(model=model, dataset=test_data,
                                                             output_key='pred', cpu=cpu)

    # predictions of the model at initialization
    train_init_preds = utils.apply_on_dataset(model=model, dataset=train_data, cpu=cpu,
                                              batch_size=batch_size)['pred']
    test_init_preds = None
    if test_data is not None:
        test_init_preds = utils.apply_on_dataset(model=model, dataset=test_data, cpu=cpu,
                                                 batch_size=batch_size)['pred']

    # snapshot of the parameters at initialization (moved to CPU if requested)
    init_params = dict(model.named_parameters())
    if cpu:
        for k, v in init_params.items():
            init_params[k] = v.to('cpu')

    ntk = compute_ntk(jacobians=train_jacobians)
    # The NTK Gram matrix is symmetric, so eigvalsh gives its (real) eigenvalue
    # spectrum; torch.eig is deprecated/removed in recent PyTorch.
    lamb = torch.linalg.eigvalsh(ntk)
    logging.info(f'Min eigenvalue of NTK: {torch.min(lamb).item():.3f}\t'
                 f'Max eigenvalue of NTK: {torch.max(lamb).item():.3f}')
    if torch.min(lamb).item() < 0:
        logging.warning('The smallest eigenvalue of the NTK is negative; '
                        'consider adding at least a small amount of weight decay.')
    test_train_ntk = None
    if test_data is not None:
        test_train_ntk = compute_test_train_ntk(train_jacobians=train_jacobians,
                                                test_jacobians=test_jacobians)

    def extract_labels(data):
        ys = [utils.to_tensor(y, device=ntk.device).view((-1,)) for x, y in data]
        return torch.stack(ys).float()

    train_Y = extract_labels(train_data)
    test_Y = None
    if test_data is not None:
        test_Y = extract_labels(test_data)
    return {
        'jacobian_estimator': jacobian_estimator,
        'train_jacobians': train_jacobians,
        'test_jacobians': test_jacobians,
        'train_init_preds': train_init_preds,
        'test_init_preds': test_init_preds,
        'init_params': init_params,
        'ntk': ntk,
        'test_train_ntk': test_train_ntk,
        'train_Y': train_Y,
        'test_Y': test_Y,
    }
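
# A minimal usage sketch, assuming `model` is a torch.nn.Module and `train_set` /
# `test_set` are map-style datasets yielding (x, y) pairs; these names are
# hypothetical and only illustrate how the returned dictionary is consumed:
#
#     items = prepare_needed_items(model=model, train_data=train_set,
#                                  test_data=test_set, projection='none', cpu=True)
#     ntk = items['ntk']                        # train-train NTK Gram matrix
#     test_train_ntk = items['test_train_ntk']  # test-train kernel block
#     train_Y = items['train_Y']                # stacked, flattened training labels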