in tools/stats.py [0:0]
def synchronize_logged_vars(self, log_vars, default_val=float('NaN')):
    """Make every stat set track exactly the stats named in ``log_vars``.

    For each stat set in ``self.stats``:

    * stats whose name is not in ``log_vars`` are dropped;
    * stats named in ``log_vars`` but absent from the set are created and
      filled with ``default_val`` for every epoch in ``0..self.epoch``;
    * stats whose history is non-empty but does not hold exactly
      ``self.epoch + 1`` entries are reset and refilled the same way;
    * stats with an empty history (never updated) are left untouched.

    ``self.log_vars`` is overwritten with ``log_vars``. Progress messages
    are written to stdout with ``print``.

    Args:
        log_vars: names of the stats to keep; only iterated and tested
            with ``in`` here.
        default_val: value fed to ``AverageMeter.update`` once per epoch
            when a stat has to be (re)generated.

    Raises:
        AssertionError: if a regenerated stat reports a different epoch
            (via ``get_epoch()``) than the reference stat of its set.
    """
    stat_sets = list(self.stats.keys())

    # remove the additional log_vars
    for stat_set in stat_sets:
        for stat in self.stats[stat_set]:
            if stat not in log_vars:
                print("additional stat %s:%s -> removing" %
                      (stat_set, stat))
        self.stats[stat_set] = {
            stat: v for stat, v in self.stats[stat_set].items()
            if stat in log_vars
        }

    self.log_vars = log_vars  # !!!

    for stat_set in stat_sets:
        set_stats = self.stats[stat_set]
        # The first surviving stat is the epoch reference. A set can be
        # empty after the filtering above; indexing [0] would then raise
        # IndexError, so fall back to None and pick a reference lazily.
        reference_stat = next(iter(set_stats), None)

        for stat in log_vars:
            if stat not in set_stats:
                print("missing stat %s:%s -> filling with default values (%1.2f)" %
                      (stat_set, stat, default_val))
            else:
                h = set_stats[stat].history
                if len(h) == self.epoch + 1:
                    # already in sync with the current epoch
                    continue
                if len(h) == 0:  # just never updated stat ... skip
                    continue
                print("incomplete stat %s:%s -> reseting with default values (%1.2f)" %
                      (stat_set, stat, default_val))

            # (Re)generate the stat with one default entry per epoch.
            # AverageMeter is expected to be defined elsewhere in this module.
            meter = AverageMeter()
            meter.reset()
            for ep in range(self.epoch + 1):
                meter.update(default_val, n=1, epoch=ep)
            set_stats[stat] = meter

            if reference_stat is None:
                # The set had no surviving stats: the first generated
                # meter serves as the reference for the rest.
                reference_stat = stat

            epoch_self = set_stats[reference_stat].get_epoch()
            epoch_generated = meter.get_epoch()
            assert epoch_self == epoch_generated, \
                "bad epoch of synchronized log_var! %d vs %d" % \
                (epoch_self, epoch_generated)