in MTRF/algorithms/softlearning/algorithms/sql.py
def _init_svgd_update(self):
    """Create a minimization operation for policy update (SVGD)."""
    policy_inputs = flatten_input_structure({
        name: tf.reshape(
            tf.tile(
                self._placeholders['observations'][name][:, None, :],
                (1, self._kernel_n_particles, 1)),
            (-1, *self._placeholders['observations'][name].shape[1:]))
        for name in self._policy.observation_keys
    })
    actions = self._policy.actions(policy_inputs)
    action_shape = actions.shape[1:]

    actions = tf.reshape(
        actions, (-1, self._kernel_n_particles, *action_shape))
    assert_shape(
        actions, (None, self._kernel_n_particles, *action_shape))
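    # Each observation is tiled to `self._kernel_n_particles` copies so the
    # policy samples K actions per state: observations of shape (B, D) become
    # (B * K, D), the policy returns (B * K, |A|) actions, and the reshape
    # above groups them back into (B, K, |A|).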
    # SVGD requires computing two empirical expectations over actions
    # (see Appendix C1.1.). To that end, we first sample a single set of
    # actions, and later split them into two sets: `fixed_actions` are used
    # to evaluate the expectation indexed by `j` and `updated_actions`
    # the expectation indexed by `i`.
    n_updated_actions = int(
        self._kernel_n_particles * self._kernel_update_ratio)
    n_fixed_actions = self._kernel_n_particles - n_updated_actions

    fixed_actions, updated_actions = tf.split(
        actions, [n_fixed_actions, n_updated_actions], axis=1)
    fixed_actions = tf.stop_gradient(fixed_actions)
    assert_shape(fixed_actions, [None, n_fixed_actions, *action_shape])
    assert_shape(updated_actions, [None, n_updated_actions, *action_shape])
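    # For example, with `kernel_n_particles=16` and `kernel_update_ratio=0.5`
    # (illustrative values), 8 particles are held fixed and detached from the
    # graph to estimate the inner expectation, while the remaining 8 keep
    # their dependence on the policy parameters and receive the SVGD update.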
    Q_observations = {
        name: tf.reshape(
            tf.tile(
                self._placeholders['observations'][name][:, None, :],
                (1, n_fixed_actions, 1)),
            (-1, *self._placeholders['observations'][name].shape[1:]))
        for name in self._policy.observation_keys
    }
    Q_actions = tf.reshape(fixed_actions, (-1, *action_shape))
    Q_inputs = flatten_input_structure({
        **Q_observations, 'actions': Q_actions})

    Q_log_targets = tuple(Q(Q_inputs) for Q in self._Qs)
    min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)
    svgd_target_values = tf.reshape(
        min_Q_log_target, (-1, n_fixed_actions, 1))
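    # Observations are tiled again, this time to `n_fixed_actions` copies, so
    # each fixed action is paired with its originating state. Taking the
    # elementwise minimum over the Q ensemble gives a conservative (clipped
    # double-Q style) soft value to use as the SVGD target density.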
    # Target log-density. Q_soft in Equation 13:
    assert self._policy._squash
    squash_correction = tf.reduce_sum(
        tf.math.log(1 - fixed_actions ** 2 + EPS), axis=-1, keepdims=True)
    log_probs = svgd_target_values + squash_correction

    grad_log_probs = tf.gradients(log_probs, fixed_actions)[0]
    grad_log_probs = tf.expand_dims(grad_log_probs, axis=2)
    grad_log_probs = tf.stop_gradient(grad_log_probs)
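    # The squash correction is the log-determinant of the tanh squashing
    # Jacobian, sum_d log(1 - a_d^2), so the target density is over squashed
    # actions; EPS is a small numerical constant (defined at module scope in
    # this file) that keeps the log finite. The stop_gradient treats
    # grad log p(a_j) as a constant in the surrogate loss below.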
    assert_shape(grad_log_probs, [None, n_fixed_actions, 1, *action_shape])

    kernel_dict = self._kernel_fn(xs=fixed_actions, ys=updated_actions)

    # Kernel function in Equation 13:
    kappa = kernel_dict["output"][..., tf.newaxis]
    assert_shape(kappa, [None, n_fixed_actions, n_updated_actions, 1])
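    # `self._kernel_fn` is expected to return a dict with "output" holding the
    # pairwise kernel values k(a_j, a_i) of shape [B, n_fixed, n_updated] and
    # "gradient" holding the kernel gradient w.r.t. the fixed particles, of
    # shape [B, n_fixed, n_updated, |A|]; a hedged sketch of such a kernel is
    # given after this method.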
    # Stein Variational Gradient in Equation 13:
    action_gradients = tf.reduce_mean(
        kappa * grad_log_probs + kernel_dict["gradient"], axis=1)
    assert_shape(action_gradients, [None, n_updated_actions, *action_shape])
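    # That is, for each updated particle a_i,
    #   delta(a_i) = E_{a_j}[ k(a_j, a_i) * grad_{a_j} log p(a_j)
    #                         + grad_{a_j} k(a_j, a_i) ],
    # with the empirical expectation taken over the fixed particles (axis 1).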
    # Propagate the gradient through the policy network (Equation 14).
    gradients = tf.gradients(
        updated_actions,
        self._policy.trainable_variables,
        grad_ys=action_gradients)

    surrogate_loss = tf.reduce_sum([
        tf.reduce_sum(w * tf.stop_gradient(g))
        for w, g in zip(self._policy.trainable_variables, gradients)
    ])
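    # The surrogate loss exists only so a standard optimizer can apply the
    # precomputed SVGD direction: d(surrogate_loss)/dw = stop_gradient(g), so
    # minimizing -surrogate_loss steps each policy variable along its Stein
    # variational gradient.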
    self._policy_optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=self._policy_lr,
        name='policy_optimizer')

    if self._train_policy:
        svgd_training_op = self._policy_optimizer.minimize(
            loss=-surrogate_loss,
            var_list=self._policy.trainable_variables)

        self._training_ops.update({
            'svgd': svgd_training_op
        })
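

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of sql.py): `self._kernel_fn` above is assumed
# to be an adaptive isotropic Gaussian (RBF) kernel with a median-heuristic
# bandwidth, in the style of the Soft Q-Learning reference code. The
# hypothetical function below only documents the contract used in
# `_init_svgd_update`: a dict with "output" of shape [B, Kx, Ky] and
# "gradient" (d kappa / d xs) of shape [B, Kx, Ky, D]. The name, the `h_min`
# floor, and the imports are assumptions for illustration, not the
# repository's implementation.
import numpy as np
import tensorflow as tf


def example_adaptive_gaussian_kernel(xs, ys, h_min=1e-3):
    """xs: [B, Kx, D] fixed particles; ys: [B, Ky, D] updated particles."""
    Kx, D = xs.get_shape().as_list()[-2:]
    Ky, D2 = ys.get_shape().as_list()[-2:]
    assert D == D2

    # Pairwise differences and squared distances.
    diff = xs[:, :, None, :] - ys[:, None, :, :]      # [B, Kx, Ky, D]
    dist_sq = tf.reduce_sum(diff ** 2, axis=-1)       # [B, Kx, Ky]

    # Median heuristic: approximate per-batch median of the squared distances.
    flat = tf.reshape(dist_sq, (-1, Kx * Ky))
    median = tf.nn.top_k(flat, k=Kx * Ky // 2 + 1).values[:, -1]  # [B]

    # Adaptive bandwidth, detached so it acts as a constant during training.
    h = median / np.maximum(np.log(Kx), 1e-8)
    h = tf.stop_gradient(tf.maximum(h, h_min))

    kappa = tf.exp(-dist_sq / h[:, None, None])       # k(x_j, y_i), [B, Kx, Ky]
    kappa_grad = kappa[..., None] * (-2.0 * diff / h[:, None, None, None])

    return {"output": kappa, "gradient": kappa_grad}


# Hypothetical usage mirroring the call in `_init_svgd_update`:
# kernel_dict = example_adaptive_gaussian_kernel(
#     xs=fixed_actions, ys=updated_actions)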