in tools/stats.py [0:0]
def get_epoch_averages(self, epoch=None):
    """Collect per-epoch averages of every recorded stat, grouped by stat set.

    Args:
        epoch: Which epoch(s) to report.
            * ``None`` (default): the current epoch, ``self.epoch``.
            * ``-1``: every epoch in ``range(self.epoch)``.
            * a single (non-iterable) epoch index: just that epoch.
            * any iterable of epoch indices: those epochs. One-shot
              iterators (e.g. generators) are materialized into a list
              first so the same epochs are applied to every stat.

    Returns:
        A dict mapping each key of ``self.stats`` (a stat set) to a dict
        holding the bookkeeping entries ``'epoch'`` (the resolved epoch
        selection), ``'it'`` (``self.it[stat_set]``) and ``'epoch_max'``
        (``self.epoch``), plus one entry per stat whose ``count`` is
        non-zero. For a single epoch, that entry is the result of the
        stat object's ``get_epoch_averages(epoch=epoch)``. For multiple
        epochs, it is a list built by indexing the result of the stat
        object's ``get_epoch_averages()`` with each selected epoch.
    """
    if epoch is None:
        epoch = self.epoch

    # Decide once whether we are reporting several epochs or a single one.
    multi_epoch = isinstance(epoch, Iterable)
    if not multi_epoch and epoch == -1:
        # -1 is a sentinel for "all epochs so far"; only compared for
        # non-iterables so array-like selections never hit an ambiguous ==.
        epoch = list(range(self.epoch))
        multi_epoch = True
    elif multi_epoch and iter(epoch) is epoch:
        # A one-shot iterator would be exhausted by the first stat and leave
        # every later stat with an empty list; freeze it once up front.
        epoch = list(epoch)

    outvals = {}
    for stat_set, stat_objs in self.stats.items():
        outvals[stat_set] = {'epoch': epoch,
                             'it': self.it[stat_set],
                             'epoch_max': self.epoch}
        for stat, stat_obj in stat_objs.items():
            if stat_obj.count == 0:
                continue  # nothing recorded for this stat; omit it
            if multi_epoch:
                all_avgs = stat_obj.get_epoch_averages()
                avgs = [all_avgs[e] for e in epoch]
            else:
                avgs = stat_obj.get_epoch_averages(epoch=epoch)
            outvals[stat_set][stat] = avgs
    return outvals