def _distribution()

in tf_agents/bandits/policies/boltzmann_reward_prediction_policy.py


  def _distribution(self, time_step, policy_state):
    observation = time_step.observation
    if self.observation_and_action_constraint_splitter is not None:
      observation, _ = self.observation_and_action_constraint_splitter(
          observation)

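    # Run the reward network to get per-action reward predictions (and the
    # updated policy state) for the batch of observations.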
    predictions, policy_state = self._reward_network(
        observation, time_step.step_type, policy_state)
    batch_size = tf.shape(predictions)[0]

    if isinstance(self._reward_network,
                  heteroscedastic_q_network.HeteroscedasticQNetwork):
      predicted_reward_values = predictions.q_value_logits
    else:
      predicted_reward_values = predictions

    predicted_reward_values.shape.with_rank_at_least(2)
    predicted_reward_values.shape.with_rank_at_most(3)
    if (predicted_reward_values.shape[-1] is not None and
        predicted_reward_values.shape[-1] != self._expected_num_actions):
      raise ValueError(
          'The number of actions ({}) does not match the reward_network output'
          ' size ({}).'.format(self._expected_num_actions,
                               predicted_reward_values.shape[-1]))

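    # Combine all masking sources (the observation/action constraint splitter
    # and any configured constraints) into a single per-action mask; `mask` is
    # None when no masking source is present.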
    mask = constr.construct_mask_from_multiple_sources(
        time_step.observation, self._observation_and_action_constraint_splitter,
        self._constraints, self._expected_num_actions)

    if self._boltzmann_gumbel_exploration_constant is not None:
      logits = predicted_reward_values

      # Apply masking if needed. Overwrite the logits for invalid actions to
      # logits.dtype.min.
      if mask is not None:
        almost_neg_inf = tf.constant(logits.dtype.min, dtype=logits.dtype)
        logits = tf.compat.v2.where(
            tf.cast(mask, tf.bool), logits, almost_neg_inf)

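      # Boltzmann-Gumbel exploration: add Gumbel noise to each logit, weighted
      # by C / sqrt(number of samples seen for that action), then act greedily
      # on the perturbed logits.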
      gumbel_dist = tfp.distributions.Gumbel(loc=0., scale=1.)
      gumbel_samples = gumbel_dist.sample(tf.shape(logits))
      num_samples_list_float = tf.stack(
          [tf.cast(x.read_value(), tf.float32) for x in self._num_samples_list],
          axis=-1)
      exploration_weights = tf.math.divide_no_nan(
          self._boltzmann_gumbel_exploration_constant,
          tf.sqrt(num_samples_list_float))
      final_logits = logits + exploration_weights * gumbel_samples
      actions = tf.cast(
          tf.math.argmax(final_logits, axis=1), self._action_spec.dtype)
      # Log probability is not available in closed form. We treat this as a
      # deterministic policy at the moment.
      log_probability = tf.zeros([batch_size], tf.float32)
    else:
      # Apply the temperature scaling, needed for Boltzmann exploration.
      logits = predicted_reward_values / self._get_temperature_value()

      # Apply masking if needed. Overwrite the logits for invalid actions to
      # logits.dtype.min.
      if mask is not None:
        almost_neg_inf = tf.constant(logits.dtype.min, dtype=logits.dtype)
        logits = tf.compat.v2.where(
            tf.cast(mask, tf.bool), logits, almost_neg_inf)

      if self._action_offset != 0:
        distribution = shifted_categorical.ShiftedCategorical(
            logits=logits,
            dtype=self._action_spec.dtype,
            shift=self._action_offset)
      else:
        distribution = tfp.distributions.Categorical(
            logits=logits,
            dtype=self._action_spec.dtype)

      actions = distribution.sample()
      log_probability = distribution.log_prob(actions)

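    # Mark every action in the batch as produced by the Boltzmann policy so it
    # can be reported in the policy info below.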
    bandit_policy_values = tf.fill([batch_size, 1],
                                   policy_utilities.BanditPolicyType.BOLTZMANN)

    if self._accepts_per_arm_features:
      # Saving the features for the chosen action to the policy_info.
      def gather_observation(obs):
        return tf.gather(params=obs, indices=actions, batch_dims=1)

      chosen_arm_features = tf.nest.map_structure(
          gather_observation,
          observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
      policy_info = policy_utilities.PerArmPolicyInfo(
          log_probability=log_probability if
          policy_utilities.InfoFields.LOG_PROBABILITY in self._emit_policy_info
          else (),
          predicted_rewards_mean=(
              predicted_reward_values if policy_utilities.InfoFields
              .PREDICTED_REWARDS_MEAN in self._emit_policy_info else ()),
          bandit_policy_type=(bandit_policy_values
                              if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                              in self._emit_policy_info else ()),
          chosen_arm_features=chosen_arm_features)
    else:
      policy_info = policy_utilities.PolicyInfo(
          log_probability=log_probability if
          policy_utilities.InfoFields.LOG_PROBABILITY in self._emit_policy_info
          else (),
          predicted_rewards_mean=(
              predicted_reward_values if policy_utilities.InfoFields
              .PREDICTED_REWARDS_MEAN in self._emit_policy_info else ()),
          bandit_policy_type=(bandit_policy_values
                              if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                              in self._emit_policy_info else ()))

    return policy_step.PolicyStep(
        tfp.distributions.Deterministic(loc=actions), policy_state, policy_info)