Dassl.pytorch/dassl/utils/meters.py (42 lines of code) (raw):

from collections import defaultdict import torch __all__ = ["AverageMeter", "MetricMeter"] class AverageMeter: """Compute and store the average and current value. Examples:: >>> # 1. Initialize a meter to record loss >>> losses = AverageMeter() >>> # 2. Update meter after every mini-batch update >>> losses.update(loss_value, batch_size) """ def __init__(self, ema=False): """ Args: ema (bool, optional): apply exponential moving average. """ self.ema = ema self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): if isinstance(val, torch.Tensor): val = val.item() self.val = val self.sum += val * n self.count += n if self.ema: self.avg = self.avg * 0.9 + self.val * 0.1 else: self.avg = self.sum / self.count class MetricMeter: """Store the average and current value for a set of metrics. Examples:: >>> # 1. Create an instance of MetricMeter >>> metric = MetricMeter() >>> # 2. Update using a dictionary as input >>> input_dict = {'loss_1': value_1, 'loss_2': value_2} >>> metric.update(input_dict) >>> # 3. Convert to string and print >>> print(str(metric)) """ def __init__(self, delimiter=" "): self.meters = defaultdict(AverageMeter) self.delimiter = delimiter def update(self, input_dict): if input_dict is None: return if not isinstance(input_dict, dict): raise TypeError( "Input to MetricMeter.update() must be a dictionary" ) for k, v in input_dict.items(): if isinstance(v, torch.Tensor): v = v.item() self.meters[k].update(v) def __str__(self): output_str = [] for name, meter in self.meters.items(): output_str.append(f"{name} {meter.val:.4f} ({meter.avg:.4f})") return self.delimiter.join(output_str)