in grok/measure.py [0:0]
def get_sharpness(data_loader, model, subspace_dim=10, epsilon=1e-3, maxiter=10):
"""
Compute the sharpness around some point in weight space, as specified
in Keskar et. al. (2016) Sec 2.2.2:
https://arxiv.org/pdf/1609.04836.pdf
See:
https://gist.github.com/arthurmensch/c55ac413868550f89225a0b9212aa4cd
https://gist.github.com/gngdb/a9f912df362a85b37c730154ef3c294b
https://github.com/keskarnitish/large-batch-training
https://github.com/wenwei202/smoothout
https://github.com/keras-team/keras/pull/3064
"""
x0 = get_weights(model)
f_x0, _ = get_loss_and_grads(x0, model, data_loader)
f_x0 = -f_x0
logging.info("min loss f_x0 = {loss:.4f}".format(loss=f_x0))
if 0 == subspace_dim:
x_min = np.reshape(x0 - epsilon * (np.abs(x0) + 1), (x0.shape[0], 1))
x_max = np.reshape(x0 + epsilon * (np.abs(x0) + 1), (x0.shape[0], 1))
bounds = np.concatenate([x_min, x_max], 1)
func = lambda x: get_loss_and_grads(x, model, data_loader)
init_guess = x0
else:
assert subspace_dim <= x0.shape[0]
# Computed via Keskar, et. al
# https://arxiv.org/pdf/1609.04836.pdf
A_plus = np.random.rand(subspace_dim, x0.shape[0]) * 2.0 - 1.0
A_plus_norm = np.linalg.norm(A_plus, axis=1)
A_plus = A_plus / np.reshape(A_plus_norm, (subspace_dim, 1))
A = np.linalg.pinv(A_plus)
abs_bound = epsilon * (np.abs(np.dot(A_plus, x0)) + 1)
abs_bound = np.reshape(abs_bound, (abs_bound.shape[0], 1))
bounds = np.concatenate([-abs_bound, abs_bound], 1)
def func(y):
f_loss, f_grads = get_loss_and_grads(
x0 + np.dot(A, y),
model,
data_loader,
)
return f_loss, np.dot(np.transpose(A), f_grads)
init_guess = np.zeros(subspace_dim)
minimum_x, f_x, d = scipy.optimize.fmin_l_bfgs_b(
func,
init_guess,
maxiter=maxiter,
bounds=bounds,
disp=1,
)
f_x = -f_x
logging.info("max loss f_x = {loss:.4f}".format(loss=f_x))
# Eq 4 in Keskar
phi = (f_x - f_x0) / (1 + f_x0) * 100
# Restore parameter values
x0 = torch.from_numpy(x0).float()
# x0 = x0.cuda()
x_start = 0
for p in model.parameters():
param_size = p.data.size()
param_idx = 1
for s in param_size:
param_idx *= s
x_part = x0[x_start : x_start + param_idx]
p.data = x_part.view(param_size)
x_start += param_idx
return phi