in tools/stats.py [0:0]
def plot_stats(self, visdom_env=None, plot_file=None,
               visdom_server=None, visdom_port=None):
    """Plot the per-epoch averages of every logged variable.

    For each name in ``self.log_vars`` the epoch averages of every stat set
    in ``self.stats`` that has data for it are collected into one chart
    (one curve per stat set). The charts are sent to a visdom server when a
    connection can be established, and are additionally exported as a grid
    of matplotlib subplots to ``plot_file`` when that is set.

    Args:
        visdom_env: visdom environment name; defaults to ``self.visdom_env``.
            All existing windows in this env are closed before plotting.
        plot_file: path of the image file to export; defaults to
            ``self.plot_file``. Nothing is exported when it is falsy.
        visdom_server: visdom server host; defaults to
            ``self.visdom_server``.
        visdom_port: visdom server port; defaults to ``self.visdom_port``.
    """
    # use the cached settings for anything not supplied
    if visdom_env is None:
        visdom_env = self.visdom_env
    if visdom_server is None:
        visdom_server = self.visdom_server
    if visdom_port is None:
        visdom_port = self.visdom_port
    if plot_file is None:
        plot_file = self.plot_file

    stat_sets = list(self.stats.keys())

    # %s (not %d) so that a missing (None) port cannot crash this log line
    print("printing charts to visdom env '%s' (%s:%s)" %
          (visdom_env, visdom_server, visdom_port))

    novisdom = False
    viz = get_visdom_connection(server=visdom_server, port=visdom_port)
    if not viz.check_connection():
        print("no visdom server! -> skipping visdom plots")
        novisdom = True

    if not novisdom:
        # win=None: close every window in the env before re-plotting
        viz.close(env=visdom_env, win=None)

    # Collect one (stat_set_names, stat_name, x, values) entry per stat;
    # `values` has one column per stat set that reported data for the stat.
    lines = []
    for stat in self.log_vars:
        vals = []
        stat_sets_now = []
        for stat_set in stat_sets:
            val = self.stats[stat_set][stat].get_epoch_averages()
            if val is None:
                continue
            stat_sets_now.append(stat_set)
            vals.append(np.array(val)[:, None])
        if not vals:
            continue
        # NOTE(review): concatenating along axis 1 assumes every stat set has
        # the same number of epoch averages for this stat -- confirm.
        vals = np.concatenate(vals, axis=1)
        x = np.arange(vals.shape[0])
        lines.append((stat_sets_now, stat, x, vals))

    if not novisdom:
        for tmodes, stat, x, vals in lines:
            opts = dict(title="%s" % stat, legend=list(tmodes))
            if vals.shape[1] == 1:
                # a single curve is passed as a 1-D array
                vals = vals[:, 0]
            viz.line(Y=vals, X=x, env=visdom_env, opts=opts)

    if plot_file:
        print("exporting stats to %s" % plot_file)
        ncol = 3
        nrow = int(np.ceil(float(len(lines)) / ncol))
        matplotlib.rcParams.update({'font.size': 5})
        color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
        gcolor = np.array(mcolors.to_rgba('lightgray'))
        fig = plt.figure(1)
        plt.clf()
        for idx, (tmodes, stat, x, vals) in enumerate(lines):
            c = next(color)
            plt.subplot(nrow, ncol, idx + 1)
            for vali, vals_ in enumerate(vals.T):
                # Darken the RGB channels for each successive stat set while
                # keeping alpha intact. Clamp the factor: an unclamped
                # 1 - 0.3 * vali goes negative from the 5th set on, which
                # matplotlib rejects as an invalid RGBA color.
                shade = max(1. - float(vali) * 0.3, 0.1)
                c_ = np.concatenate([c[:3] * shade, c[3:]])
                plt.plot(x, vals_, c=c_, linewidth=1)
            plt.ylabel(stat)
            plt.xlabel("epoch")
            plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
            plt.legend(tmodes)
            # Visibility is passed positionally: the `b=` keyword was renamed
            # to `visible=` in Matplotlib 3.5 and later removed.
            plt.grid(True, which='major', color=gcolor,
                     linestyle='-', linewidth=0.4)
            plt.grid(True, which='minor', color=gcolor,
                     linestyle='--', linewidth=0.2)
            plt.minorticks_on()
        plt.tight_layout()
        # Save before show(): an interactive show() can block and close the
        # figure, after which savefig would write an empty image.
        fig.savefig(plot_file)
        plt.show()