def calculate_and_log_all_metrics_train()

in lib/utils/metrics.py [0:0]


    def calculate_and_log_all_metrics_train(
            self, curr_iter, timer, suffix=''):
        """Calculate, accumulate and periodically log training metrics.

        Reads the current learning rate, batch size, loss and (for
        single-label models) top-1/top-5 errors from the workspace, adds
        batch-size-weighted values to the running aggregates on ``self``,
        and prints a one-line summary every ``cfg.LOG_PERIOD`` iterations.

        Args:
            curr_iter: 0-based index of the iteration that just ran.
            timer: timing object; ``average_time`` is used for the ETA and
                ``diff`` is printed as the iteration time (presumably the
                duration of the most recent timed interval -- confirm
                against the timer class).
            suffix: forwarded unchanged to
                ``compute_multi_gpu_topk_accuracy``.

        Side effects:
            Sets ``self.lr``; increments ``self.aggr_loss`` and
            ``self.aggr_batch_size``; when ``cfg.MODEL.MULTI_LABEL`` is
            false, also increments ``self.aggr_err`` and ``self.aggr_err5``.
            Prints a log line when ``(curr_iter + 1)`` is a multiple of
            ``cfg.LOG_PERIOD``.
        """

        # To be safe, we always read lr from workspace.
        self.lr = float(
            workspace.FetchBlob('gpu_{}/lr'.format(cfg.ROOT_GPU_ID)))

        # To be safe, we only trust what we load from workspace.
        cur_batch_size = get_batch_size_from_workspace()

        # We only compute loss for train. The 'loss' blobs are combined by
        # sum_multi_gpu_blob and then reduced to a single Python float.
        # NOTE(review): an older comment claimed the loss is multiplied by
        # cfg.MODEL.GRAD_ACCUM_PASS to calibrate, but no such scaling is
        # applied here -- confirm whether that calibration is needed.
        cur_loss = sum_multi_gpu_blob('loss')
        cur_loss = float(np.sum(cur_loss))

        # Running aggregates are weighted by batch size (presumably so a
        # caller can later divide by self.aggr_batch_size).
        self.aggr_loss += cur_loss * cur_batch_size
        self.aggr_batch_size += cur_batch_size

        if not cfg.MODEL.MULTI_LABEL:
            accuracy_metrics = compute_multi_gpu_topk_accuracy(
                top_k=1, split=self.split, suffix=suffix)
            accuracy5_metrics = compute_multi_gpu_topk_accuracy(
                top_k=5, split=self.split, suffix=suffix)

            # Errors are expressed as percentages.
            cur_err = (1.0 - accuracy_metrics['topk_accuracy']) * 100
            cur_err5 = (1.0 - accuracy5_metrics['topk_accuracy']) * 100

            self.aggr_err += cur_err * cur_batch_size
            self.aggr_err5 += cur_err5 * cur_batch_size

        if (curr_iter + 1) % cfg.LOG_PERIOD == 0:
            rem_iters = cfg.SOLVER.MAX_ITER - curr_iter - 1
            eta_seconds = timer.average_time * rem_iters
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))

            # Force float division. Under Python 2 integer semantics both
            # divisions would floor, which yields whole-number epochs and a
            # ZeroDivisionError when DATASET_SIZE < BATCH_SIZE. Results are
            # unchanged under true division.
            iters_per_epoch = (
                float(cfg.TRAIN.DATASET_SIZE) / cfg.TRAIN.BATCH_SIZE)
            epoch = (curr_iter + 1) / iters_per_epoch

            log_str = ' '.join((
                '| Train ETA: {} LR: {:.8f}',
                ' Iters [{}/{}]',
                '[{:.2f}ep]',
                ' Time {:0.3f}',
                ' Loss {:7.4f}',
            )).format(
                eta, self.lr,
                curr_iter + 1, cfg.SOLVER.MAX_ITER,
                epoch,
                timer.diff,
                cur_loss,
            )

            # cur_err / cur_err5 only exist for single-label models.
            if not cfg.MODEL.MULTI_LABEL:
                log_str += ' top1 {:7.3f} top5 {:7.3f}'.format(
                    cur_err, cur_err5)
            print(log_str)