def epoch_diagnostics()

in torch_svrg.py


    def epoch_diagnostics(self, train_loss, train_err, test_loss, test_err):
        """
        Called after recalibrate; computes per-batch gradient diagnostics
        and saves the statistics to disk.
        """
        m = self.nbatches
        logging.info("Epoch diagnostics computation")

        # Accumulate, per layer, the table-average gradient (gavg) and the
        # total squared norm of the average gradient.
        layernum = 0
        layer_gradient_norm_sqs = []
        gavg_norm_acum = 0.0
        gavg_acum = []
        for group in self.param_groups:
            for p in group['params']:
                layer_gradient_norm_sqs.append([])
                gavg = self.state[p]['gavg'].cpu()
                gavg_acum.append(gavg.numpy())
                gavg_norm_acum += gavg.norm()**2 #torch.dot(gavg, gavg)
                layernum += 1

        gradient_norm_sqs = []
        vr_step_variance = []
        cos_acums = []
        variances = []

        # For each minibatch in the gradient table, measure how far its stored
        # gradient lies from the batch average, the variance of the corresponding
        # variance-reduced (SVRG) step, and its cosine similarity with the average.
        for batch_id in range(m):
            norm_acum = 0.0
            ginorm_acum = 0.0
            vr_acum = 0.0
            layernum = 0
            cos_acum = 0.0
            var_acum = 0.0
            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]

                    gktbl = param_state['gktbl']
                    gavg = param_state['gavg'].type_as(p.data).cpu()

                    # gi is the stored gradient for this minibatch; accumulate its
                    # squared distance from the batch-average gradient gavg.
                    gi = gktbl[batch_id, :]
                    var_norm_sq = (gi-gavg).norm()**2 #torch.dot(gi-gavg, gi-gavg)
                    norm_acum += var_norm_sq
                    ginorm_acum += gi.norm()**2 #torch.dot(gi, gi)
                    layer_gradient_norm_sqs[layernum].append(var_norm_sq)

                    # SVRG step for this batch: the current gradient corrected by the
                    # stored snapshot gradient and the snapshot average.
                    gktbl_old = param_state['gktbl_old']
                    gavg_old = param_state['gavg_old'].type_as(p.data).cpu()
                    gi_old = gktbl_old[batch_id, :]
                    vr_step = gi - gi_old + gavg_old
                    vr_acum += (vr_step - gavg).norm()**2 #torch.dot(vr_step - gavg, vr_step - gavg)
                    cos_acum += torch.sum(gavg*gi)

                    var_acum += var_norm_sq

                    layernum += 1
            gradient_norm_sqs.append(norm_acum)
            vr_step_variance.append(vr_acum)
            cosim = cos_acum/math.sqrt(ginorm_acum*gavg_norm_acum)
            cos_acums.append(cosim)
            variances.append(var_acum)

        # Average the per-batch gradient variance over the epoch.
        variance = sum(variances)/len(variances)

        print("mean cosine: {}".format(sum(cos_acums)/len(cos_acums)))


        # Note: open() will not create the stats/ directory; it must already exist.
        with open('stats/{}fastdiagnostics_epoch{}.pkl'.format(self.test_name, self.epoch), 'wb') as output:
            pickle.dump({
                'train_loss': train_loss,
                'train_err': train_err,
                'test_loss': test_loss,
                'test_err': test_err,
                'epoch': self.epoch,
                #'layer_gradient_norm_sqs': layer_gradient_norm_sqs,
                #'gradient_norm_sqs': gradient_norm_sqs,
                #'vr_step_variance': vr_step_variance,
                #'cosine_distances': cos_acums,
                #'variances': variances,
                'variance': variance,
                #'gavg_norm': gavg_norm_acum,
                #'gavg': gavg_acum,
                #'iterate_distances': self.inrun_iterate_distances,
                #'grad_distances': self.inrun_grad_distances,
            }, output)
        print("Epoch diagnostics saved")

        self.inrun_iterate_distances = []
        self.inrun_grad_distances = []
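
For reference, a minimal sketch (not part of torch_svrg.py) of how the pickled per-epoch statistics could be read back for plotting or analysis. The glob pattern assumes the stats/{test_name}fastdiagnostics_epoch{epoch}.pkl naming used above, and the dictionary keys are the ones written by the pickle.dump call:

    import glob
    import pickle

    # Load every per-epoch diagnostics file written by epoch_diagnostics()
    # and print the quantities it records.
    for path in sorted(glob.glob('stats/*fastdiagnostics_epoch*.pkl')):
        with open(path, 'rb') as f:
            stats = pickle.load(f)
        print("epoch {}: train_loss={} train_err={} test_loss={} test_err={} variance={}".format(
            stats['epoch'], stats['train_loss'], stats['train_err'],
            stats['test_loss'], stats['test_err'], stats['variance']))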