def _train()

in tf_agents/agents/ppo/ppo_agent.py [0:0]


  def _train(self, experience, weights):
    """Runs `self._num_epochs` PPO update epochs on one batch of experience.

    When `self._compute_value_and_advantage_in_train` is set, the batch is
    first passed through `self._preprocess`. The log-probabilities of the
    collected actions are computed under the action distribution rebuilt from
    the stored `dist_params`. Then one optimizer step per epoch is taken on
    `self.get_loss(...).loss`, and `self.train_step_counter` is incremented
    once per epoch. Afterwards:
    - the adaptive KL beta is updated if `self._initial_adaptive_kl_beta > 0`;
    - the observation/reward normalizers are updated if
      `self.update_normalizers_in_train` is set;
    - scalar summaries of the epoch-averaged loss components and the learning
      rate are written;
    - variable histograms are written when requested.

    Args:
      experience: A batch of experience accepted by `self._as_trajectory`.
        After the optional preprocessing, its `policy_info` must contain the
        keys 'dist_params', 'return', 'advantage' and 'value_prediction'.
      weights: Optional weights multiplied with the trajectory validity mask
        from `ppo_utils.make_trajectory_mask`. `None` means the mask alone
        is used.

    Returns:
      The loss-info structure returned by `self.get_loss` on the final epoch,
      with every leaf passed through `tf.identity`.

    Raises:
      ValueError: If `self._num_epochs` is less than 1.
    """
    # Fail fast. With zero epochs nothing would be trained, yet the KL-beta
    # and normalizer updates below would still run before `tf.add_n([])`
    # failed in the loss-summary section.
    if self._num_epochs < 1:
      raise ValueError(
          'num_epochs must be at least 1, got {}.'.format(self._num_epochs))

    experience = self._as_trajectory(experience)

    if self._compute_value_and_advantage_in_train:
      processed_experience = self._preprocess(experience)
    else:
      processed_experience = experience

    # Histogram summaries are skipped when any TPU logical device is present.
    # TODO(b/171573175): remove the condition once histograms are
    # supported on TPUs.
    histograms_supported = not tf.config.list_logical_devices('TPU')

    # Mask trajectories that cannot be used for training.
    valid_mask = ppo_utils.make_trajectory_mask(processed_experience)
    if weights is None:
      masked_weights = valid_mask
    else:
      masked_weights = weights * valid_mask

    # Reconstruct the per-timestep collect-time policy distribution from the
    # distribution parameters stored in `policy_info`.
    old_action_distribution_parameters = processed_experience.policy_info[
        'dist_params']
    old_actions_distribution = ppo_utils.distribution_from_spec(
        self._action_distribution_spec,
        old_action_distribution_parameters,
        legacy_distribution_network=isinstance(self._actor_net,
                                               network.DistributionNetwork))

    # Log probability of the actions taken during data collection, under the
    # collect policy distribution. This is passed to `get_loss` below.
    old_act_log_probs = common.log_probability(old_actions_distribution,
                                               processed_experience.action,
                                               self._action_spec)

    if self._debug_summaries and histograms_supported:
      actions_list = tf.nest.flatten(processed_experience.action)
      # Only suffix an index when the action has more than one component.
      show_action_index = len(actions_list) != 1
      for i, single_action in enumerate(actions_list):
        action_name = ('actions_{}'.format(i)
                       if show_action_index else 'actions')
        tf.compat.v2.summary.histogram(
            name=action_name, data=single_action, step=self.train_step_counter)

    time_steps = ts.TimeStep(
        step_type=processed_experience.step_type,
        reward=processed_experience.reward,
        discount=processed_experience.discount,
        observation=processed_experience.observation)
    actions = processed_experience.action
    returns = processed_experience.policy_info['return']
    advantages = processed_experience.policy_info['advantage']

    normalized_advantages = _normalize_advantages(advantages,
                                                  variance_epsilon=1e-8)

    if self._debug_summaries and histograms_supported:
      tf.compat.v2.summary.histogram(
          name='advantages_normalized',
          data=normalized_advantages,
          step=self.train_step_counter)
    old_value_predictions = processed_experience.policy_info['value_prediction']

    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]

    # Components of `loss_info.extra` collected every epoch and averaged
    # across epochs for the summaries below. The tuple order is also the
    # order in which the scalar summaries are written.
    loss_component_names = (
        'policy_gradient_loss',
        'value_estimation_loss',
        'l2_regularization_loss',
        'entropy_regularization_loss',
        'kl_penalty_loss',
    )
    per_epoch_losses = {name: [] for name in loss_component_names}

    loss_info = None  # TODO(b/123627451): Remove.
    # The identity set drops variables that appear in both networks' weight
    # lists. Sorting by name ensures tensors on different processes end up in
    # the same order.
    variables_to_train = sorted(
        object_identity.ObjectIdentitySet(self._actor_net.trainable_weights +
                                          self._value_net.trainable_weights),
        key=lambda x: x.name)

    for i_epoch in range(self._num_epochs):
      with tf.name_scope('epoch_%d' % i_epoch):
        # Only save debug summaries for first and last epochs.
        debug_summaries = (
            self._debug_summaries and
            (i_epoch == 0 or i_epoch == self._num_epochs - 1))

        with tf.GradientTape() as tape:
          loss_info = self.get_loss(
              time_steps,
              actions,
              old_act_log_probs,
              returns,
              normalized_advantages,
              old_action_distribution_parameters,
              masked_weights,
              self.train_step_counter,
              debug_summaries,
              old_value_predictions=old_value_predictions,
              training=True)

        grads = tape.gradient(loss_info.loss, variables_to_train)
        if self._gradient_clipping > 0:
          grads, _ = tf.clip_by_global_norm(grads, self._gradient_clipping)

        # Materialize as a tuple. In py3 `zip` is a one-shot iterator, and the
        # pairs are consumed both by the summaries and by the optimizer.
        grads_and_vars = tuple(zip(grads, variables_to_train))

        if self._summarize_grads_and_vars and debug_summaries:
          eager_utils.add_gradients_summaries(grads_and_vars,
                                              self.train_step_counter)
          eager_utils.add_variables_summaries(grads_and_vars,
                                              self.train_step_counter)

        self._optimizer.apply_gradients(grads_and_vars)
        self.train_step_counter.assign_add(1)

        for name in loss_component_names:
          per_epoch_losses[name].append(getattr(loss_info.extra, name))

    # TODO(b/161365079): Move this logic to PPOKLPenaltyAgent.
    if self._initial_adaptive_kl_beta > 0:
      # After the update epochs, update the adaptive KL beta from the KL
      # between the stored (collect-time) action distribution and the collect
      # policy's current distribution on the same time steps.
      policy_state = self._collect_policy.get_initial_state(batch_size)
      kl_divergence = self._kl_divergence(
          time_steps, old_action_distribution_parameters,
          self._collect_policy.distribution(time_steps, policy_state).action)
      self.update_adaptive_kl_beta(kl_divergence)

    if self.update_normalizers_in_train:
      self.update_observation_normalizer(time_steps.observation)
      self.update_reward_normalizer(processed_experience.reward)

    loss_info = tf.nest.map_structure(tf.identity, loss_info)

    # Summaries of each loss component averaged across all epochs. Every list
    # holds exactly one entry per epoch, so they all have the same length.
    with tf.name_scope('Losses/'):
      num_epochs = len(per_epoch_losses['policy_gradient_loss'])
      mean_losses = {
          name: tf.add_n(losses) / num_epochs
          for name, losses in per_epoch_losses.items()
      }
      for name, mean_loss in mean_losses.items():
        tf.compat.v2.summary.scalar(
            name=name, data=mean_loss, step=self.train_step_counter)

      total_abs_loss = (
          tf.abs(mean_losses['policy_gradient_loss']) +
          tf.abs(mean_losses['value_estimation_loss']) +
          tf.abs(mean_losses['entropy_regularization_loss']) +
          tf.abs(mean_losses['l2_regularization_loss']) +
          tf.abs(mean_losses['kl_penalty_loss']))

      tf.compat.v2.summary.scalar(
          name='total_abs_loss',
          data=total_abs_loss,
          step=self.train_step_counter)

    with tf.name_scope('LearningRate/'):
      learning_rate = ppo_utils.get_learning_rate(self._optimizer)
      tf.compat.v2.summary.scalar(
          name='learning_rate',
          data=learning_rate,
          step=self.train_step_counter)

    if self._summarize_grads_and_vars and histograms_supported:
      with tf.name_scope('Variables/'):
        # Identity-based de-duplication, as for `variables_to_train`. A
        # variable listed by both networks gets exactly one histogram per
        # step. Listing order (actor first, then value) is otherwise kept.
        summarized = object_identity.ObjectIdentitySet()
        for var in (self._actor_net.trainable_weights +
                    self._value_net.trainable_weights):
          if var in summarized:
            continue
          summarized.add(var)
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    return loss_info