def _init_svgd_update()

in MTRF/algorithms/softlearning/algorithms/sql.py


    def _init_svgd_update(self):
        """Create a minimization operation for policy update (SVGD)."""

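        # Tile each observation `_kernel_n_particles` times along a new
        # particle axis, then flatten back to a plain batch so that the policy
        # samples `_kernel_n_particles` actions for every observation.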
        policy_inputs = flatten_input_structure({
            name: tf.reshape(
                tf.tile(
                    self._placeholders['observations'][name][:, None, :],
                    (1, self._kernel_n_particles, 1)),
                (-1, *self._placeholders['observations'][name].shape[1:]))
            for name in self._policy.observation_keys
        })
        actions = self._policy.actions(policy_inputs)
        action_shape = actions.shape[1:]
        actions = tf.reshape(
            actions, (-1, self._kernel_n_particles, *action_shape))

        assert_shape(
            actions, (None, self._kernel_n_particles, *action_shape))

        # SVGD requires computing two empirical expectations over actions
        # (see Appendix C1.1.). To that end, we first sample a single set of
        # actions and later split it into two sets: `fixed_actions` are used
        # to evaluate the expectation indexed by `j`, and `updated_actions`
        # to evaluate the expectation indexed by `i`.
        n_updated_actions = int(
            self._kernel_n_particles * self._kernel_update_ratio)
        n_fixed_actions = self._kernel_n_particles - n_updated_actions
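        # For example, with 16 particles and an update ratio of 0.5, 8
        # particles serve as fixed kernel reference points and 8 receive the
        # SVGD update.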

        fixed_actions, updated_actions = tf.split(
            actions, [n_fixed_actions, n_updated_actions], axis=1)
        fixed_actions = tf.stop_gradient(fixed_actions)
        assert_shape(fixed_actions,
                     [None, n_fixed_actions, *action_shape])
        assert_shape(updated_actions,
                     [None, n_updated_actions, *action_shape])

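        # Evaluate the soft Q-functions at the fixed particles: tile the
        # observations `n_fixed_actions` times and flatten observations and
        # actions into a single (batch * n_fixed_actions, ...) batch.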
        Q_observations = {
            name: tf.reshape(
                tf.tile(
                    self._placeholders['observations'][name][:, None, :],
                    (1, n_fixed_actions, 1)),
                (-1, *self._placeholders['observations'][name].shape[1:]))
            for name in self._policy.observation_keys
        }
        Q_actions = tf.reshape(fixed_actions, (-1, *action_shape))
        Q_inputs = flatten_input_structure({
            **Q_observations, 'actions': Q_actions})
        Q_log_targets = tuple(Q(Q_inputs) for Q in self._Qs)
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)
        svgd_target_values = tf.reshape(
            min_Q_log_target, (-1, n_fixed_actions, 1))

        # Target log-density. Q_soft in Equation 13:
        assert self._policy._squash
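        # The policy outputs tanh-squashed actions, so the target log-density
        # includes the log-determinant of the tanh Jacobian,
        # sum_i log(1 - a_i^2), with EPS added for numerical stability.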
        squash_correction = tf.reduce_sum(
            tf.math.log(1 - fixed_actions ** 2 + EPS), axis=-1, keepdims=True)
        log_probs = svgd_target_values + squash_correction

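        # Gradient of the target log-density with respect to the fixed action
        # particles. The extra axis (axis=2) broadcasts it against the
        # (fixed, updated) kernel matrix below, and stop_gradient treats it as
        # a constant target.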
        grad_log_probs = tf.gradients(log_probs, fixed_actions)[0]
        grad_log_probs = tf.expand_dims(grad_log_probs, axis=2)
        grad_log_probs = tf.stop_gradient(grad_log_probs)
        assert_shape(grad_log_probs,
                     [None, n_fixed_actions, 1, *action_shape])

        kernel_dict = self._kernel_fn(xs=fixed_actions, ys=updated_actions)

        # Kernel function in Equation 13:
        kappa = kernel_dict["output"][..., tf.newaxis]
        assert_shape(kappa, [None, n_fixed_actions, n_updated_actions, 1])

        # Stein Variational Gradient in Equation 13:
        action_gradients = tf.reduce_mean(
            kappa * grad_log_probs + kernel_dict["gradient"], axis=1)
        assert_shape(action_gradients,
                     [None, n_updated_actions, *action_shape])
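        # The reduce_mean over axis=1 is the empirical expectation over the
        # fixed particles, so each updated particle receives its own Stein
        # variational gradient.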

        # Propagate the gradient through the policy network (Equation 14).
        gradients = tf.gradients(
            updated_actions,
            self._policy.trainable_variables,
            grad_ys=action_gradients)

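        # Since d/dw sum(w * stop_gradient(g)) = g, the surrogate below has
        # exactly `gradients` as its gradient with respect to the policy
        # variables; minimizing its negative therefore applies the SVGD
        # direction through Adam.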
        surrogate_loss = tf.reduce_sum([
            tf.reduce_sum(w * tf.stop_gradient(g))
            for w, g in zip(self._policy.trainable_variables, gradients)
        ])

        self._policy_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=self._policy_lr,
            name='policy_optimizer'
        )

        if self._train_policy:
            svgd_training_op = self._policy_optimizer.minimize(
                loss=-surrogate_loss,
                var_list=self._policy.trainable_variables)
            self._training_ops.update({
                'svgd': svgd_training_op
            })
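
The kernel function `self._kernel_fn` is supplied elsewhere in the class; the
code above only relies on it returning a dict whose "output" entry has shape
(batch, n_fixed_actions, n_updated_actions) and whose "gradient" entry has
shape (batch, n_fixed_actions, n_updated_actions, |A|); in the SVGD update
(Equation 13) the latter is the derivative of the kernel with respect to the
fixed particles. A minimal sketch of such a kernel, assuming an isotropic
Gaussian with the median-heuristic bandwidth commonly used for SVGD (the
`rbf_kernel` name and the `h_min` default are illustrative, not the
repository's implementation):

    import numpy as np
    import tensorflow as tf

    def rbf_kernel(xs, ys, h_min=1e-3):
        """Gaussian kernel kappa(x_j, y_i) with a median-heuristic bandwidth.

        xs: (batch, Kx, D) fixed particles; ys: (batch, Ky, D) updated particles.
        Returns {"output": (batch, Kx, Ky), "gradient": (batch, Kx, Ky, D)},
        where "gradient" is d kappa / d x_j.
        """
        Kx = xs.shape.as_list()[-2]
        Ky = ys.shape.as_list()[-2]

        diff = xs[:, :, None, :] - ys[:, None, :, :]   # (batch, Kx, Ky, D)
        dist_sq = tf.reduce_sum(diff ** 2, axis=-1)    # (batch, Kx, Ky)

        # Median heuristic: bandwidth h ~ median(dist_sq) / log(Kx).
        top_k, _ = tf.nn.top_k(
            tf.reshape(dist_sq, (-1, Kx * Ky)), k=Kx * Ky // 2 + 1)
        medians_sq = top_k[..., -1]                    # (batch,)
        h = tf.maximum(medians_sq / np.log(Kx), h_min)
        h = tf.stop_gradient(h)[:, None, None]         # (batch, 1, 1)

        kappa = tf.exp(-dist_sq / h)                   # (batch, Kx, Ky)
        # d kappa / d x_j = -2 * (x_j - y_i) / h * kappa
        kappa_grad = -2. * diff / h[..., None] * kappa[..., None]
        return {'output': kappa, 'gradient': kappa_grad}

With this shape convention, `kappa * grad_log_probs + kernel_dict["gradient"]`
broadcasts exactly as the assert_shape checks above expect.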