def _init_diagnostics_ops()

in MTRF/algorithms/softlearning/algorithms/multi_sac.py [0:0]


    def _init_diagnostics_ops(self):
        """Create the per-goal diagnostic ops.

        For every goal index in ``range(self._num_goals)`` an ordered
        mapping of named quantities is collected (Q values, Q losses,
        policy loss, alpha, reward terms and, when enabled, VAE or RAE
        metrics).  Each quantity is then passed through four reducers
        (mean, std, max, min) and the results are stored in
        ``self._diagnostics_ops_per_goal`` as a list with one
        ``OrderedDict`` per goal, keyed by ``'<quantity>-<reducer>'``.
        """
        # Stage 1: the quantities that are recorded for every goal.
        per_goal = []
        for goal in range(self._num_goals):
            entries = OrderedDict()
            entries[f'Q_value_{goal}'] = self._Q_values_per_policy[goal]
            entries[f'Q_loss_{goal}'] = self._Q_losses_per_policy[goal]
            entries[f'policy_loss_{goal}'] = self._policy_losses[goal]
            entries[f'alpha_{goal}'] = self._alphas[goal]
            per_goal.append(entries)

        # Stage 2: reward diagnostics. Intrinsic (RND) and extrinsic terms
        # are only recorded when their reward coefficient is non-zero,
        # i.e. when that reward is actually used.
        for goal, entries in enumerate(per_goal):
            if self._rnd_int_rew_coeffs[goal]:
                entries.update((
                    (f'rnd_reward_{goal}', self._int_rewards[goal]),
                    (f'rnd_error_{goal}', self._rnd_errors[goal]),
                    (f'running_rnd_reward_std_{goal}',
                     self._placeholders['reward'][
                         f'running_int_rew_std_{goal}']),
                ))

            if self._ext_reward_coeffs[goal]:
                entries.update((
                    (f'ext_reward_{goal}', self._ext_rewards[goal]),
                    (f'normalized_ext_reward_{goal}',
                     self._normalized_ext_rewards[goal]),
                    (f'unnormalized_ext_reward_{goal}',
                     self._unscaled_ext_rewards[goal]),
                ))

            # These two are recorded regardless of the coefficients above.
            entries.update((
                (f'running_ext_reward_std_{goal}',
                 self._placeholders['reward'][f'running_ext_rew_std_{goal}']),
                (f'total_reward_{goal}', self._total_rewards[goal]),
            ))

        # Stage 3: optional encoder metrics. VAE takes precedence; RAE
        # metrics are only added when VAE is not in use.
        if self._uses_vae:
            for goal, entries in enumerate(per_goal):
                for metric, logs in self._vae_metrics_per_policy[goal].items():
                    entries.update(
                        (f'vae/{name}/{metric}', log)
                        for name, log in logs.items())
        elif self._uses_rae:
            for goal, entries in enumerate(per_goal):
                for metric, logs in self._rae_metrics_per_policy[goal].items():
                    entries.update(
                        (f'rae/{name}/{metric}', log)
                        for name, log in logs.items())

        # Stage 4: apply every reducer to every collected quantity. The
        # reducers are kept in a fixed order so the resulting keys are
        # emitted as mean, std, max, min for each quantity.
        reducers = (
            ('mean', tf.reduce_mean),
            ('std', lambda x: tfp.stats.stddev(x, sample_axis=None)),
            ('max', tf.math.reduce_max),
            ('min', tf.math.reduce_min),
        )

        ops_per_goal = []
        for entries in per_goal:
            goal_ops = OrderedDict()
            for key, values in entries.items():
                for metric_name, metric_fn in reducers:
                    goal_ops[f'{key}-{metric_name}'] = metric_fn(values)
            ops_per_goal.append(goal_ops)

        self._diagnostics_ops_per_goal = ops_per_goal