def get_diagnostics()

in MTRF/algorithms/softlearning/algorithms/multi_sac.py [0:0]


    def get_diagnostics(self,
                        iteration,
                        batches,
                        training_paths_per_policy,
                        evaluation_paths_per_policy):
        """Return diagnostic information as a dictionary.

        Also calls the `draw` method of the plotter, if plotter defined.

        Args:
            iteration: Current training iteration. Used to build the feed
                dict, to decide whether reconstructions are saved on this
                call (`iteration % self._save_reconstruction_frequency`),
                and to name the files written to disk.
            batches: Indexable by goal index; `batches[i]` is the batch
                passed to `self._get_feed_dict` for goal `i`, and
                `batches[i]['observations'][name]` is read for every
                observation key of policy `i`.
            training_paths_per_policy: Unused by this method.
            evaluation_paths_per_policy: Pickled to
                `eval_paths_<iteration // epoch_length>.pkl` in the current
                working directory when `self._save_eval_paths` is truthy.

        Returns:
            dict: Per-goal diagnostics from the session, per-policy
            diagnostics prefixed with `policy_<i>/`, and, when the RAE
            preprocessors are evaluated, the tracked-latent statistics.

        Raises:
            AssertionError: If VAE/RAE evaluation is requested but no pixel
                observations were available to evaluate on.
        """
        original_goal_index = self._goal_index
        diagnostics = {}
        # Stays None when there are no goals, so the pixel sampling below
        # can be skipped instead of hitting an unbound name.
        feed_dict = None

        try:
            for i in range(self._num_goals):
                # Set before building the feed dict, as in the original code;
                # presumably `_get_feed_dict` depends on it -- confirm.
                self._goal_index = i
                feed_dict = self._get_feed_dict(iteration, batches[i])
                diagnostics.update(self._session.run(
                    dict(self._diagnostics_ops_per_goal[i]), feed_dict))

                policy = self._policies[i]
                policy_diagnostics = policy.get_diagnostics(
                    flatten_input_structure({
                        name: batches[i]['observations'][name]
                        for name in policy.observation_keys
                    }))
                diagnostics.update({
                    f'policy_{i}/{key}': value
                    for key, value in policy_diagnostics.items()
                })
        finally:
            # Always restore the caller's goal index, even if a goal failed.
            self._goal_index = original_goal_index

        should_save = iteration % self._save_reconstruction_frequency == 0

        # Generate random pixels to evaluate the preprocessors. The pixels
        # come from the feed dict of the last goal processed above.
        eval_pixels = None
        observation_placeholders = self._placeholders['observations']
        if feed_dict is not None and 'pixels' in observation_placeholders:
            pixels = feed_dict[observation_placeholders['pixels']]
            random_idxs = np.random.choice(
                pixels.shape[0],
                size=self._n_preprocessor_evals_per_epoch)
            eval_pixels = pixels[random_idxs]

        # NOTE(review): the encoder/decoder/sigmoid calls below create new
        # tensors on every invocation; presumably this grows the TF graph
        # over time -- confirm and consider building these ops once.
        if self._uses_vae and should_save:
            # `is not None` rather than truthiness: the truth value of a
            # multi-element numpy array raises ValueError.
            assert eval_pixels is not None, (
                "VAE diagnostics require 'pixels' observations.")
            save_path = os.path.join(os.getcwd(), 'vae')
            os.makedirs(save_path, exist_ok=True)

            for i, preprocessors in enumerate(self._preprocessors_per_policy):
                if self._ext_reward_coeffs[i] == 0:
                    continue

                for name, vae in preprocessors.items():
                    z_mean, z_logvar, z = self._session.run(
                        vae.encoder(eval_pixels))
                    reconstructions = self._session.run(
                        tf.math.sigmoid(vae.decoder(z)))
                    # Originals and reconstructions side by side (axis=2),
                    # then all samples stacked vertically into one image.
                    concat = np.concatenate([
                        eval_pixels,
                        skimage.util.img_as_ubyte(reconstructions)
                    ], axis=2)

                    sampled_z = np.random.normal(
                        size=(eval_pixels.shape[0], vae.latent_dim))
                    decoded_samples = self._session.run(
                        tf.math.sigmoid(vae.decoder.output),
                        feed_dict={vae.decoder.input: sampled_z}
                    )

                    skimage.io.imsave(
                        os.path.join(
                            save_path,
                            f'{name}_reconstruction_{iteration}.png'),
                        np.vstack(concat))
                    skimage.io.imsave(
                        os.path.join(
                            save_path,
                            f'{name}_sample_{iteration}.png'),
                        np.vstack(decoded_samples))
        elif self._uses_rae:
            assert eval_pixels is not None, (
                "RAE diagnostics require 'pixels' observations.")
            save_path = os.path.join(os.getcwd(), 'rae')
            if should_save:
                os.makedirs(save_path, exist_ok=True)

            for i, preprocessors in enumerate(self._preprocessors_per_policy):
                if self._ext_reward_coeffs[i] == 0:
                    continue

                for name, rae in preprocessors.items():
                    z, reconstructions = self._session.run(
                        rae(eval_pixels, include_reconstructions=True))

                    if should_save:
                        concat = np.concatenate([
                            eval_pixels,
                            skimage.util.img_as_ubyte(reconstructions)
                        ], axis=2)
                        skimage.io.imsave(
                            os.path.join(
                                save_path,
                                f'{name}_reconstruction_{iteration}.png'),
                            np.vstack(concat))

                    # Track latents: the first pixels ever evaluated become
                    # the fixed reference set for all later calls.
                    # NOTE(review): `_fixed_eval_latents` is a single
                    # attribute shared by every preprocessor; with several
                    # preprocessors the "previous" latents compared against
                    # may belong to a different one -- confirm intent.
                    if self._fixed_eval_pixels is None:
                        self._fixed_eval_pixels = eval_pixels
                        self._fixed_eval_latents = np.zeros(z.shape)

                    z_fixed, reconstructions_fixed = self._session.run(
                        rae(self._fixed_eval_pixels,
                            include_reconstructions=True))

                    if should_save:
                        concat_fixed = np.concatenate([
                            self._fixed_eval_pixels,
                            skimage.util.img_as_ubyte(reconstructions_fixed)
                        ], axis=2)
                        skimage.io.imsave(
                            os.path.join(
                                save_path,
                                f'{name}_fixed_reconstruction_{iteration}.png'),
                            np.vstack(concat_fixed))

                    z_diff = np.linalg.norm(
                        z_fixed - self._fixed_eval_latents, axis=1)
                    diagnostics.update({
                        f'rae/{name}/tracked-latent-l2-difference-with-prev-mean': np.mean(z_diff),
                        f'rae/{name}/tracked-latent-l2-difference-with-prev-std': np.std(z_diff),
                    })
                    # Save the previous latents to compare in the next epoch
                    self._fixed_eval_latents = z_fixed

        if self._save_eval_paths:
            import pickle
            file_name = f'eval_paths_{iteration // self.epoch_length}.pkl'
            # Binary write mode: pickle emits bytes and the file is created
            # here, so the default text read mode cannot work.
            with open(os.path.join(os.getcwd(), file_name), 'wb') as f:
                pickle.dump(evaluation_paths_per_policy, f)

        if self._plotter:
            self._plotter.draw()

        return diagnostics