in compare_models.py [0:0]
def __update_plots(self, valid_ret):
    ''' Refresh the validation plots from the result of one inference run.

    valid_ret is the dictionary returned by a _run_inference. The keys read
    here are:
      'inputs':  a 4-tuple (scmap, race, inp, targ); only inp and targ
                 are used for plotting.
      'outputs': one output per model, in the same order as self.models.
      'loss':    one loss value per model; each is appended to
                 self.state.running_valids[model_id].
      'metrics': mapping metric name -> one sequence per model; the last
                 element of each sequence is appended to
                 self.state.running_f1_scores[model_id][metric].

    Metrics whose name contains '_L1' are drawn in the "dist" pane, all
    other metrics in the "F1" pane. Both kinds are stored in
    self.state.running_f1_scores despite its name.

    Side effects: mutates the running histories in self.state, draws on
    self.vis (presumably a visdom client -- TODO confirm) and saves the
    environment self.env_id.
    '''

    def _draw_lines(series, pane, legend, title, transform=None):
        ''' Draw one line per entry of `series` into `pane`.

        An empty `series` is skipped: np.array([]) is one-dimensional, so
        reading shape[1] from it would raise IndexError.
        '''
        if len(series) == 0:
            logging.warning("No data for plot '%s', skipping it", title)
            return
        values = np.array(series)  # values[series_id][valid_time]
        y = values.transpose()
        if transform is not None:
            y = transform(y)
        self.vis.line(
            y,
            np.arange(values.shape[1]),
            win=pane,
            opts={
                'legend': legend,
                'title': title,
            })

    logging.info("Updating validation plots...")

    # scmap and race come with the inputs but are not needed for plotting.
    _scmap, _race, inp, targ = valid_ret['inputs']
    last_outs = valid_ret['outputs']
    # All per-frame plots below use the frame in the middle of the sequence.
    mid_frame = inp.data.cpu().numpy().shape[0] // 2
    self.__plot_heatmaps(last_outs, inp, targ, mid_frame)

    n_samples = self.state.n_samples
    # The target frame is the same for every model, so fetch it only once.
    # The reduction itself stays inside the loop so that every call to bar
    # receives its own freshly allocated array.
    target_frame = targ.data.cpu().numpy()[mid_frame]
    for model, output in zip(self.models, last_outs):
        # Summing over the two trailing axes leaves one value per entry of
        # the leading axis (presumably the unit type, going by the plot
        # title -- TODO confirm).
        bar(self.vis,
            '{}: units by type for {}'.format(n_samples, model.model_name),
            target_frame.sum(axis=2).sum(axis=1),
            output[REGRESSION].cpu().numpy()[mid_frame]
            .sum(axis=2).sum(axis=1),
            select=True)

    # Extend the per-model loss history, then plot it on a log10 scale.
    for model_id, loss in enumerate(valid_ret['loss']):
        self.state.running_valids[model_id].append(loss)
    _draw_lines(self.state.running_valids,  # [model_id][valid_time]
                self.valid_loss_pane,
                [m.model_name for m in self.models],
                '(log10) Validation loss',
                transform=np.log10)

    # Split the metrics into two groups of (history, legend entry) pairs,
    # one pair per (metric, model) combination.
    f1_series, f1_legend = [], []
    dist_series, dist_legend = [], []
    for metric, valids in valid_ret['metrics'].items():
        if '_L1' in metric:
            series, legend = dist_series, dist_legend
        else:
            series, legend = f1_series, f1_legend
        for model_id, values in enumerate(valids):
            history = self.state.running_f1_scores[model_id][metric]
            # Only the last value reported for this run is recorded.
            history.append(values[-1])
            series.append(history)
            legend.append(self.models[model_id].model_name + "_" + metric)

    _draw_lines(f1_series, self.valid_F1_pane, f1_legend,
                'F1 validation scores')
    _draw_lines(dist_series, self.valid_dist_pane, dist_legend,
                'dist validation scores')

    self.vis.save([self.env_id])
    logging.info("Done updating validation plots")