MTRF/algorithms/softlearning/algorithms/multi_sac.py [652:687]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        diagnostic_metrics = OrderedDict((
            ('mean', tf.reduce_mean),
            ('std', lambda x: tfp.stats.stddev(x, sample_axis=None)),
            ('max', tf.math.reduce_max),
            ('min', tf.math.reduce_min),
        ))

        self._diagnostics_ops_per_goal = [
            OrderedDict([
                (f'{key}-{metric_name}', metric_fn(values))
                for key, values in diagnosables.items()
                for metric_name, metric_fn in diagnostic_metrics.items()
            ])
            for diagnosables in diagnosables_per_goal
        ]

    def _training_batch(self, batch_size=None):
        """Return a random batch from the sampler of the current goal.

        The sampler is picked from ``self._samplers`` using
        ``self._goal_index``; ``batch_size`` is forwarded unchanged to that
        sampler's ``random_batch``.
        """
        active_sampler = self._samplers[self._goal_index]
        return active_sampler.random_batch(batch_size)

    def _evaluation_batches(self, batch_size=None):
        """Return a list with one random batch per goal, in goal-index order.

        For every index in ``range(self._num_goals)`` the matching entry of
        ``self._samplers`` is asked for ``random_batch(batch_size)``.
        """
        goal_indices = range(self._num_goals)
        return [
            self._samplers[goal].random_batch(batch_size)
            for goal in goal_indices
        ]

    def _update_target(self, i, tau=None):
        """Soft-update the target Q-networks belonging to policy ``i``.

        Every target network's weights are replaced by the interpolation
        ``tau * source + (1 - tau) * target`` between the paired Q-network's
        weights and the target's own current weights.

        Args:
            i: Index into ``self._Qs_per_policy`` and
                ``self._Q_targets_per_policy`` selecting whose networks
                are updated.
            tau: Interpolation coefficient. ``None`` (the default) means
                "use ``self._tau``". An explicit ``0`` is honoured, so the
                interpolation reduces to the current target weights.
        """
        # Test against None instead of truthiness: `tau or self._tau` would
        # silently replace an explicit tau of 0 with the default.
        if tau is None:
            tau = self._tau

        for Q, Q_target in zip(
                self._Qs_per_policy[i], self._Q_targets_per_policy[i]):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])

    def _epoch_before_hook(self, *args, **kwargs):
        super(SAC, self)._epoch_before_hook(*args, **kwargs)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



MTRF/algorithms/softlearning/algorithms/phased_sac.py [457:492]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        diagnostic_metrics = OrderedDict((
            ('mean', tf.reduce_mean),
            ('std', lambda x: tfp.stats.stddev(x, sample_axis=None)),
            ('max', tf.math.reduce_max),
            ('min', tf.math.reduce_min),
        ))

        self._diagnostics_ops_per_goal = [
            OrderedDict([
                (f'{key}-{metric_name}', metric_fn(values))
                for key, values in diagnosables.items()
                for metric_name, metric_fn in diagnostic_metrics.items()
            ])
            for diagnosables in diagnosables_per_goal
        ]

    def _training_batch(self, batch_size=None):
        """Return a random batch from the sampler of the current goal.

        The sampler is picked from ``self._samplers`` using
        ``self._goal_index``; ``batch_size`` is forwarded unchanged to that
        sampler's ``random_batch``.
        """
        active_sampler = self._samplers[self._goal_index]
        return active_sampler.random_batch(batch_size)

    def _evaluation_batches(self, batch_size=None):
        """Return a list with one random batch per goal, in goal-index order.

        For every index in ``range(self._num_goals)`` the matching entry of
        ``self._samplers`` is asked for ``random_batch(batch_size)``.
        """
        goal_indices = range(self._num_goals)
        return [
            self._samplers[goal].random_batch(batch_size)
            for goal in goal_indices
        ]

    def _update_target(self, i, tau=None):
        """Soft-update the target Q-networks belonging to policy ``i``.

        Every target network's weights are replaced by the interpolation
        ``tau * source + (1 - tau) * target`` between the paired Q-network's
        weights and the target's own current weights.

        Args:
            i: Index into ``self._Qs_per_policy`` and
                ``self._Q_targets_per_policy`` selecting whose networks
                are updated.
            tau: Interpolation coefficient. ``None`` (the default) means
                "use ``self._tau``". An explicit ``0`` is honoured, so the
                interpolation reduces to the current target weights.
        """
        # Test against None instead of truthiness: `tau or self._tau` would
        # silently replace an explicit tau of 0 with the default.
        if tau is None:
            tau = self._tau

        for Q, Q_target in zip(
                self._Qs_per_policy[i], self._Q_targets_per_policy[i]):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])

    def _epoch_before_hook(self, *args, **kwargs):
        super(SAC, self)._epoch_before_hook(*args, **kwargs)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



