def __update_plots()

in compare_models.py [0:0]


    def __update_plots(self, valid_ret):
        ''' Refresh the validation plots from the result of one validation run.

        valid_ret is the dictionary returned by a _run_inference. The keys
        read here are:
          - 'inputs': unpacked as (scmap, race, inp, targ).
          - 'outputs': iterated in lockstep with self.models.
          - 'loss': value i is appended to self.state.running_valids[i].
          - 'metrics': maps a metric name to a sequence indexed like
            self.models; the last element of entry i is appended to
            self.state.running_f1_scores[i][metric].

        Metrics whose name contains '_L1' are drawn in self.valid_dist_pane,
        all the others in self.valid_F1_pane. A pane is left untouched when no
        metric falls into its group. The visdom environment self.env_id is
        saved at the end.
        '''
        logging.info("Updating validation plots...")

        # Disabled: heatmaps for one fixed replay.
        # fixed_gn = "18/TL_PvZ_GG5758.tcr" + ("" if not self.args.reduced else ".npz")
        # scmap, race, inp, targ, _, _ = self.other_dl.one(fixed_gn)
        # last_outs, _ = self._do_model_step(scmap, race, inp, targ, None, optimize=False)
        # # Replay is 10960 frames long (8fps)
        # frame = min(inp.data.cpu().numpy().shape[0] - 1,
        #             5000 // self.args.combine_frames)
        # self.__plot_heatmaps(last_outs, inp, targ, frame, "fxd ")

        scmap, race, inp, targ = valid_ret['inputs']
        last_outs = valid_ret['outputs']
        length = inp.data.cpu().numpy().shape[0]
        # Every per-sample plot below uses the middle index of the first axis.
        mid = length // 2
        self.__plot_heatmaps(last_outs, inp, targ, mid)

        # Disabled: heatmap of the encoded hidden state of each model.
        # [hmap(self.vis, 'encoded hidden for {}'.format(model.model_name),
        #     model.encoder(model.conv1x1(model.trunk(scmap, race, inp))
        #         .contiguous()).data.cpu().numpy()[length//2]
        #             .reshape(model.enc_embsize, -1))
        #     for model in self.models]

        n_g = self.state.n_samples
        # The target slice does not depend on the model: convert it once
        # instead of once per model.
        targ_mid = targ.data.cpu().numpy()[mid]
        for model, output in zip(self.models, last_outs):
            bar(self.vis,
                '{}: units by type for {}'.format(n_g, model.model_name),
                targ_mid.sum(axis=2).sum(axis=1),
                output[REGRESSION].cpu().numpy()[mid].sum(axis=2).sum(axis=1),
                select=True)

        for i, v in enumerate(valid_ret['loss']):
            self.state.running_valids[i].append(v)
        losses = np.array(self.state.running_valids)  # losses[model_id][valid_time]
        self.vis.line(
            np.log10(losses.transpose()),
            np.arange(losses.shape[1]),
            win=self.valid_loss_pane,
            opts={
                'legend': [m.model_name for m in self.models],
                'title': '(log10) Validation loss',
            })

        f1_curves, f1_legend = [], []  # f1_curves[F1score*model_id][valid_time]
        dist_curves, dist_legend = [], []
        for metric, valids in valid_ret['metrics'].items():
            # The destination group depends only on the metric name, so it is
            # picked once per metric rather than once per model.
            if '_L1' in metric:
                curves, legend = dist_curves, dist_legend
            else:
                curves, legend = f1_curves, f1_legend
            for i, v in enumerate(valids):
                history = self.state.running_f1_scores[i][metric]
                history.append(v[-1])
                curves.append(history)
                legend.append(self.models[i].model_name + "_" + metric)

        metric_panes = (
            (f1_curves, f1_legend, self.valid_F1_pane,
             'F1 validation scores'),
            (dist_curves, dist_legend, self.valid_dist_pane,
             'dist validation scores'),
        )
        for group_curves, group_legend, pane, title in metric_panes:
            if not group_curves:
                # np.array([]) is 1-D, so shape[1] below would raise
                # IndexError; there is nothing to draw for this group anyway.
                continue
            data = np.array(group_curves)
            self.vis.line(
                data.transpose(),
                np.arange(data.shape[1]),
                win=pane,
                opts={
                    'legend': group_legend,
                    'title': title,
                })
        self.vis.save([self.env_id])
        logging.info("Done updating validation plots")