in MTRF/algorithms/softlearning/algorithms/multi_sac.py [0:0]
def _init_diagnostics_ops(self):
    """Create per-goal diagnostic summary ops and store them on ``self``.

    For every goal index ``i`` in ``range(self._num_goals)`` this gathers the
    values worth monitoring into an ordered mapping:

    - Q value and Q loss;
    - policy loss and alpha;
    - reward terms;
    - optional VAE/RAE metrics.

    It then applies four reductions to each value: mean, std, max and min.

    Side effects:
        Sets ``self._diagnostics_ops_per_goal`` to a list with one
        ``OrderedDict`` per goal, mapping ``'<value_key>-<statistic>'`` to
        the reduced op. Keys follow the order in which values are collected
        below. Within each value, the statistics appear in the order mean,
        std, max, min.
    """
    # Diagnostics recorded unconditionally for every goal.
    diagnosables_per_goal = []
    for goal in range(self._num_goals):
        entries = OrderedDict()
        entries[f'Q_value_{goal}'] = self._Q_values_per_policy[goal]
        entries[f'Q_loss_{goal}'] = self._Q_losses_per_policy[goal]
        entries[f'policy_loss_{goal}'] = self._policy_losses[goal]
        entries[f'alpha_{goal}'] = self._alphas[goal]
        diagnosables_per_goal.append(entries)

    # Reward diagnostics. An intrinsic (RND) or extrinsic reward term is only
    # logged when its coefficient for this goal is truthy (non-zero), i.e.
    # when that term actually contributes to the reward.
    for goal, entries in enumerate(diagnosables_per_goal):
        if self._rnd_int_rew_coeffs[goal]:
            entries.update((
                (f'rnd_reward_{goal}', self._int_rewards[goal]),
                (f'rnd_error_{goal}', self._rnd_errors[goal]),
                (f'running_rnd_reward_std_{goal}',
                 self._placeholders['reward'][f'running_int_rew_std_{goal}']),
            ))
        if self._ext_reward_coeffs[goal]:
            entries.update((
                (f'ext_reward_{goal}', self._ext_rewards[goal]),
                (f'normalized_ext_reward_{goal}',
                 self._normalized_ext_rewards[goal]),
                (f'unnormalized_ext_reward_{goal}',
                 self._unscaled_ext_rewards[goal]),
                (f'running_ext_reward_std_{goal}',
                 self._placeholders['reward'][f'running_ext_rew_std_{goal}']),
            ))
        entries[f'total_reward_{goal}'] = self._total_rewards[goal]

    def _add_representation_metrics(tag, metrics_per_goal):
        # metrics_per_goal[goal] maps metric -> {name: value}; each value is
        # logged under the key '<tag>/<name>/<metric>'.
        for goal, entries in enumerate(diagnosables_per_goal):
            for metric, logs in metrics_per_goal[goal].items():
                for name, log in logs.items():
                    entries[f'{tag}/{name}/{metric}'] = log

    # At most one representation learner is logged: VAE takes precedence
    # over RAE.
    # NOTE(review): unlike the keys above, these keys carry no goal suffix.
    # Confirm that downstream consumers keep the per-goal dicts separate.
    if self._uses_vae:
        _add_representation_metrics('vae', self._vae_metrics_per_policy)
    elif self._uses_rae:
        _add_representation_metrics('rae', self._rae_metrics_per_policy)

    # Reductions applied to every collected value, in output-key order.
    statistics = OrderedDict([
        ('mean', tf.reduce_mean),
        ('std', lambda x: tfp.stats.stddev(x, sample_axis=None)),
        ('max', tf.math.reduce_max),
        ('min', tf.math.reduce_min),
    ])

    # Build everything locally; assign to self only once construction has
    # fully succeeded.
    ops_per_goal = []
    for entries in diagnosables_per_goal:
        goal_ops = OrderedDict()
        for key, value in entries.items():
            for stat_name, reduce_fn in statistics.items():
                goal_ops[f'{key}-{stat_name}'] = reduce_fn(value)
        ops_per_goal.append(goal_ops)
    self._diagnostics_ops_per_goal = ops_per_goal