# Extracted block: def minibatch_stats()
# Origin: diagnostics.py [0:0]

def minibatch_stats():
    """Diagnostic: gradient correlation between individual examples of one minibatch.

    Intended behavior (per the guarded body): take a single batch, compute the
    gradient of the loss for example 0 alone and store it per-parameter as a
    "baseline", then for examples 1..4 compute each example's gradient and log
    (a) its cosine similarity with the baseline per parameter tensor and
    (b) an aggregate similarity over all parameters.

    NOTE(review): the entire body is inside ``if False:`` and so never runs.
    It also references names not defined in this scope (``train_loader``,
    ``args``, ``Variable``, ``model``, ``optimizer``, ``criterion``,
    ``torch``, ``math``, ``logging``) — it is a template copied from run.py
    and must be adapted before use, as the comment below says.
    """
    # This is just copied from run.py, needs to be modified to work.
    if False:
        # Pull one minibatch off the training loader.
        batch_idx, (data, target) = next(enumerate(train_loader))
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # Variable() wrapping is the legacy pre-0.4 PyTorch autograd API.
        data, target = Variable(data), Variable(target)

        idx = 0
        ###
        optimizer.zero_grad()
        output = model(data)
        #pdb.set_trace()
        # Loss for example 0 only; output[idx, None] keeps a batch dimension
        # of size 1 so criterion sees a batched prediction.
        loss = criterion(output[idx, None], target[idx])
        loss.backward()

        # Accumulates ||g_0||^2 summed over all parameter tensors.
        baseline_sq = 0.0

        # Stash example 0's gradient per parameter in the optimizer's state
        # dict as the baseline for later comparisons.
        for group in optimizer.param_groups:
            for p in group['params']:
                gk = p.grad.data
                param_state = optimizer.state[p]
                param_state['baseline'] = gk.clone()
                # NOTE(review): torch.dot expects 1-D tensors; p.grad.data is
                # typically multi-dim, so this may need .view(-1) — confirm
                # against the PyTorch version this template targets.
                baseline_sq += torch.dot(gk, gk)

        # Compare the gradients of examples 1..4 against the baseline.
        for idx in range(1, 5):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output[idx, None], target[idx])
            loss.backward()

            total_dot = 0.0      # sum over params of <g_0, g_idx>
            square_norm = 0.0    # sum over params of ||g_idx||^2
            corrs = []           # per-parameter-tensor cosine similarities

            for group in optimizer.param_groups:
                for p in group['params']:
                    gk = p.grad.data
                    param_state = optimizer.state[p]
                    baseline = param_state['baseline']
                    # Compute correlation
                    # (cosine similarity for this single parameter tensor)
                    dp = torch.dot(baseline, gk)
                    corr = dp/(torch.norm(baseline)*torch.norm(gk))
                    corrs.append(corr)

                    total_dot += dp
                    square_norm += torch.dot(gk, gk)

            # Aggregate cosine similarity across all parameters:
            # <g_0, g_idx> / (||g_idx|| * ||g_0||).
            total_corr = total_dot/math.sqrt(square_norm*baseline_sq)
            logging.info("i={}, corr: {}, layers: {}".format(idx, total_corr, corrs))