def logging_pass_end()

in torch_svrg.py [0:0]


    def logging_pass_end(self, batch_idx):
        """Compute and log gradient-variance diagnostics for the current iterate.

        Two per-batch gradient tables are read from the optimizer state of each
        parameter:

        * ``logging_gktbl`` / ``logging_gavg``: per-batch gradients and their
          average at the current iterate x;
        * ``gktbl`` / ``gavg``: per-batch gradients and their average at the
          reference point x~ (``tilde_x``).

        Averaged over the ``self.nbatches`` batches it computes:

        * the gradient variance ``mean_i ||g_i(x) - gavg(x)||^2``;
        * the variance-reduced step variance
          ``mean_i ||g_i(x) - g_i(x~) + gavg(x~) - gavg(x)||^2``;
        * the variance of the per-batch squared norms ``||g_i(x) - gavg(x)||^2``
          (logged as ``forth``), used to report the relative standard deviation
          and standard error of the gradient-variance estimate.

        The gradient variance, step variance, ``batch_idx`` and the distance
        ``||x - x~||`` are appended to ``self.gradient_variances``,
        ``self.vr_step_variances``, ``self.batch_indices`` and
        ``self.iterate_distances`` respectively.

        Args:
            batch_idx: Value recorded in ``self.batch_indices`` for this pass.

        Note:
            ``self.nbatches`` is assumed to be positive (it is used as a divisor).
        """
        m = self.nbatches
        logging.info("logging diagnostics computation")

        # Euclidean distance between the current iterate and x~, accumulated as a
        # plain Python float (the previous tensor-valued accumulator was only
        # converted implicitly by math.sqrt).
        dist_sq_acum = 0.0
        for group in self.param_groups:
            for p in group['params']:
                iterate_diff = p.data - self.state[p]['tilde_x']
                dist_sq_acum += iterate_diff.pow(2.0).sum().item()
        dist = math.sqrt(dist_sq_acum)

        # The averages do not depend on the batch, so convert them once per
        # parameter rather than once per (batch, parameter) pair. The list keeps
        # the original param_groups/params order so per-batch sums are unchanged.
        per_param = []
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                per_param.append((
                    p.data,
                    param_state['gktbl'],
                    param_state['gavg'].type_as(p.data).cpu(),
                    param_state['logging_gktbl'],
                    param_state['logging_gavg'].type_as(p.data).cpu(),
                ))

        gradient_sqs = []
        vr_step_sqs = []
        for batch_id in range(m):
            grad_norm_acum = 0.0
            vr_norm_acum = 0.0
            for data, gktbl, gavg, logging_gktbl, logging_gavg in per_param:
                # Logging versions are at the current location x, compared to
                # gavg/gktbl which are at x~.
                gi = gktbl[batch_id, :].type_as(data).cpu()
                logging_gi = logging_gktbl[batch_id, :].type_as(data).cpu()

                vr_step = (logging_gi - gi + gavg) - logging_gavg
                gi_step = logging_gi - logging_gavg
                grad_norm_acum += gi_step.pow(2.0).sum().item()
                vr_norm_acum += vr_step.pow(2.0).sum().item()
            gradient_sqs.append(grad_norm_acum)
            vr_step_sqs.append(vr_norm_acum)

        # Compute variance numbers
        gradient_variance = sum(gradient_sqs)/m
        vr_step_variance = sum(vr_step_sqs)/m
        # E[s^2] - E[s]^2 is >= 0 in exact arithmetic but can round slightly
        # negative (e.g. when every s is equal); clamp so math.sqrt cannot fail.
        fourth_moment = max(
            0.0, sum(s**2 for s in gradient_sqs)/m - gradient_variance**2)

        # A zero gradient variance (identical per-batch gradients, or m == 1)
        # would make every ratio a ZeroDivisionError; report nan instead.
        if gradient_variance > 0:
            vr_ratio = vr_step_variance/gradient_variance
            rel_std = math.sqrt(fourth_moment)/gradient_variance
            rel_se = math.sqrt(fourth_moment/m)/gradient_variance
        else:
            vr_ratio = rel_std = rel_se = float('nan')

        logging.info("gradient variance: {} vr: {}, ratio vr/g: {}".format(
            gradient_variance, vr_step_variance, vr_ratio))
        logging.info(f"forth: {fourth_moment} relative std: {rel_std} rel SE: {rel_se}")
        logging.info("self.logging_evals: {}".format(self.logging_evals))

        self.gradient_variances.append(gradient_variance)
        self.vr_step_variances.append(vr_step_variance)
        self.batch_indices.append(batch_idx)
        self.iterate_distances.append(dist)