in diagnostics.py [0:0]
import logging
import math

import torch
from torch.autograd import Variable


def minibatch_stats(model, criterion, optimizer, train_loader, args):
    """Log per-example gradient correlations within one minibatch.

    Adapted from run.py. The gradient of example 0 is stored as a baseline,
    then the gradients of examples 1-4 are compared against it, both per
    parameter tensor and aggregated over the whole model.
    """
    data, target = next(iter(train_loader))
    if args.cuda:
        data, target = data.cuda(), target.cuda()
    data, target = Variable(data), Variable(target)

    # Baseline: gradient of the first example in the batch.
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output[0:1], target[0:1])
    loss.backward()

    baseline_sq = 0.0
    for group in optimizer.param_groups:
        for p in group['params']:
            gk = p.grad.data.view(-1)  # flatten so torch.dot applies
            param_state = optimizer.state[p]
            param_state['baseline'] = gk.clone()
            baseline_sq += torch.dot(gk, gk).item()

    # Compare the gradients of examples 1-4 against the baseline.
    for idx in range(1, 5):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output[idx:idx + 1], target[idx:idx + 1])
        loss.backward()

        total_dot = 0.0
        square_norm = 0.0
        corrs = []
        for group in optimizer.param_groups:
            for p in group['params']:
                gk = p.grad.data.view(-1)
                param_state = optimizer.state[p]
                baseline = param_state['baseline']
                # Per-tensor cosine similarity between baseline and current gradient.
                dp = torch.dot(baseline, gk).item()
                corr = dp / (torch.norm(baseline).item() * torch.norm(gk).item())
                corrs.append(corr)
                total_dot += dp
                square_norm += torch.dot(gk, gk).item()

        # Whole-model cosine similarity between the two per-example gradients.
        total_corr = total_dot / math.sqrt(square_norm * baseline_sq)
        logging.info("i={}, corr: {}, layers: {}".format(idx, total_corr, corrs))