in c3dm/tools/stats.py [0:0]
def plot_stats(self, visdom_env=None, plot_file=None,
               visdom_server=None, visdom_port=None):
    """Plot the per-epoch averages of all logged stats.

    For every stat name in ``self.log_vars`` the epoch averages of each
    stat set in ``self.stats`` are collected as one column per stat set.
    The resulting curves are sent to a visdom server (when one is
    reachable) and, if a plot file is configured, drawn into a matplotlib
    figure that is shown and saved to that file.

    Args:
        visdom_env: visdom environment to plot into; ``self.visdom_env``
            is used when None.
        plot_file: path the matplotlib figure is saved to;
            ``self.plot_file`` is used when None. A falsy value skips the
            matplotlib export.
        visdom_server: visdom server address; ``self.visdom_server`` is
            used when None.
        visdom_port: visdom server port; ``self.visdom_port`` is used
            when None.
    """
    # use the cached settings for every argument that was not supplied
    if visdom_env is None:
        visdom_env = self.visdom_env
    if visdom_server is None:
        visdom_server = self.visdom_server
    if visdom_port is None:
        visdom_port = self.visdom_port
    if plot_file is None:
        plot_file = self.plot_file

    stat_sets = list(self.stats.keys())

    print("printing charts to visdom env '%s' (%s:%d)"
          % (visdom_env, visdom_server, visdom_port))

    viz = get_visdom_connection(server=visdom_server, port=visdom_port)
    novisdom = not viz.check_connection()
    if novisdom:
        print("no visdom server! -> skipping visdom plots")

    # one (stat set names, stat name, x values, value matrix) tuple per stat
    lines = []

    # plot metrics
    if not novisdom:
        # win=None: clear the whole environment before re-plotting
        viz.close(env=visdom_env, win=None)

    for stat in self.log_vars:
        vals = []
        stat_sets_now = []
        for stat_set in stat_sets:
            val = self.stats[stat_set][stat].get_epoch_averages()
            if val is None:
                # this stat set has no averages for the stat -> leave it out
                continue
            # turn the averages into a column so the sets can be stacked
            vals.append(np.array(val)[:, None])
            stat_sets_now.append(stat_set)

        if len(vals) == 0:
            continue

        # pad for skipped test evals: repeat the last value of the shorter
        # columns so that all columns have the same number of rows
        size = np.max([val.shape[0] for val in vals])
        vals = [
            np.pad(val, ((0, size - val.shape[0]), (0, 0)), mode='edge')
            for val in vals
        ]
        try:
            vals = np.concatenate(vals, axis=1)
        except (TypeError, ValueError):
            # best effort: a stat that cannot be stacked is simply skipped
            print('cant plot %s!' % stat)
            continue

        x = np.arange(vals.shape[0])
        lines.append((stat_sets_now, stat, x, vals))

    if not novisdom:
        for tmodes, stat, x, vals in lines:
            if vals.shape[1] == 1:  # eval
                # stats with a single column are not sent to visdom
                continue
            opts = dict(title="%s" % stat, legend=list(tmodes))
            try:
                viz.line(Y=vals, X=x, env=visdom_env, opts=opts)
            except Exception as err:
                # plotting is best effort; report the failure and carry on
                # (Exception, not a bare except, so Ctrl-C still works)
                print("Warning: problem adding data point",
                      x.shape, vals.shape, err)

    if plot_file:
        print("exporting stats to %s" % plot_file)
        ncol = 3
        nrow = int(np.ceil(float(len(lines)) / ncol))
        matplotlib.rcParams.update({'font.size': 5})
        color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
        fig = plt.figure(1)
        plt.clf()
        gcolor = np.array(mcolors.to_rgba('lightgray'))
        for idx, (tmodes, stat, x, vals) in enumerate(lines):
            c = next(color)
            plt.subplot(nrow, ncol, idx + 1)
            for vali, vals_ in enumerate(vals.T):
                # every further stat set gets a darker shade of the base
                # color; the factor is floored so that it can never become
                # negative (an invalid color) with many stat sets
                shade = max(1. - float(vali) * 0.3, 0.1)
                plt.plot(x, vals_, c=c * shade, linewidth=1)
            plt.ylabel(stat)
            plt.xlabel("epoch")
            plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
            plt.legend(tmodes)
            # the on/off flag is passed positionally: its keyword was
            # renamed from `b` to `visible` in Matplotlib 3.5, so the
            # positional form is the one that works across versions
            plt.grid(True, which='major', color=gcolor,
                     linestyle='-', linewidth=0.4)
            plt.grid(True, which='minor', color=gcolor,
                     linestyle='--', linewidth=0.2)
            plt.minorticks_on()

        plt.tight_layout()
        plt.show()
        fig.savefig(plot_file)