in tf_agents/bandits/policies/boltzmann_reward_prediction_policy.py
def _distribution(self, time_step, policy_state):
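  """Builds the action distribution for a batch of time steps.

  Runs the reward network on the (optionally split) observation, masks out
  invalid actions, and chooses actions either through Boltzmann-Gumbel
  exploration or by sampling a temperature-scaled categorical distribution.

  Args:
    time_step: A batched `TimeStep` containing the observations.
    policy_state: The policy state passed through the reward network.

  Returns:
    A `PolicyStep` whose distribution is deterministic at the chosen actions,
    together with the updated policy state and the populated policy info.
  """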
  observation = time_step.observation
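  # When an observation/action-constraint splitter is configured, keep only
  # the observation part here; the action mask is re-derived below together
  # with any other constraints.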
  if self.observation_and_action_constraint_splitter is not None:
    observation, _ = self.observation_and_action_constraint_splitter(
        observation)
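  # Forward pass through the reward network. Heteroscedastic networks return a
  # structure whose `q_value_logits` field holds the per-action reward
  # estimates; plain networks return the estimates directly.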
  predictions, policy_state = self._reward_network(
      observation, time_step.step_type, policy_state)
  batch_size = tf.shape(predictions)[0]
  if isinstance(self._reward_network,
                heteroscedastic_q_network.HeteroscedasticQNetwork):
    predicted_reward_values = predictions.q_value_logits
  else:
    predicted_reward_values = predictions
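  # The reward estimates are expected to be shaped [batch_size, num_actions]
  # (an extra leading time dimension is tolerated by the rank checks below),
  # with the last dimension matching the number of actions.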
  predicted_reward_values.shape.with_rank_at_least(2)
  predicted_reward_values.shape.with_rank_at_most(3)
  if predicted_reward_values.shape[
      -1] is not None and predicted_reward_values.shape[
          -1] != self._expected_num_actions:
    raise ValueError(
        'The number of actions ({}) does not match the reward_network output'
        ' size ({}).'.format(self._expected_num_actions,
                             predicted_reward_values.shape[-1]))
  mask = constr.construct_mask_from_multiple_sources(
      time_step.observation, self._observation_and_action_constraint_splitter,
      self._constraints, self._expected_num_actions)
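  # Two exploration modes: with a Boltzmann-Gumbel exploration constant set,
  # actions come from an argmax over Gumbel-perturbed logits; otherwise they
  # are sampled from a temperature-scaled categorical distribution.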
  if self._boltzmann_gumbel_exploration_constant is not None:
    logits = predicted_reward_values
    # Apply masking if needed. Overwrite the logits for invalid actions to
    # logits.dtype.min.
    if mask is not None:
      almost_neg_inf = tf.constant(logits.dtype.min, dtype=logits.dtype)
      logits = tf.compat.v2.where(
          tf.cast(mask, tf.bool), logits, almost_neg_inf)
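    # Boltzmann-Gumbel exploration: perturb each arm's logit with Gumbel noise
    # scaled by C / sqrt(N_k), where C is the exploration constant and N_k is
    # the recorded sample count for arm k, then take the argmax. divide_no_nan
    # yields a zero weight for arms whose count is still zero.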
    gumbel_dist = tfp.distributions.Gumbel(loc=0., scale=1.)
    gumbel_samples = gumbel_dist.sample(tf.shape(logits))
    num_samples_list_float = tf.stack(
        [tf.cast(x.read_value(), tf.float32) for x in self._num_samples_list],
        axis=-1)
    exploration_weights = tf.math.divide_no_nan(
        self._boltzmann_gumbel_exploration_constant,
        tf.sqrt(num_samples_list_float))
    final_logits = logits + exploration_weights * gumbel_samples
    actions = tf.cast(
        tf.math.argmax(final_logits, axis=1), self._action_spec.dtype)
    # Log probability is not available in closed form. We treat this as a
    # deterministic policy at the moment.
    log_probability = tf.zeros([batch_size], tf.float32)
  else:
    # Apply the temperature scaling, needed for Boltzmann exploration.
    logits = predicted_reward_values / self._get_temperature_value()
    # Apply masking if needed. Overwrite the logits for invalid actions to
    # logits.dtype.min.
    if mask is not None:
      almost_neg_inf = tf.constant(logits.dtype.min, dtype=logits.dtype)
      logits = tf.compat.v2.where(
          tf.cast(mask, tf.bool), logits, almost_neg_inf)
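    # If the action ids do not start at zero, sample from a shifted
    # categorical so the returned actions already include the offset.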
    if self._action_offset != 0:
      distribution = shifted_categorical.ShiftedCategorical(
          logits=logits,
          dtype=self._action_spec.dtype,
          shift=self._action_offset)
    else:
      distribution = tfp.distributions.Categorical(
          logits=logits,
          dtype=self._action_spec.dtype)
    actions = distribution.sample()
    log_probability = distribution.log_prob(actions)
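  # Mark every batch element as having been chosen by the Boltzmann policy;
  # this is surfaced in the policy info only when requested.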
  bandit_policy_values = tf.fill([batch_size, 1],
                                 policy_utilities.BanditPolicyType.BOLTZMANN)
  if self._accepts_per_arm_features:
    # Saving the features for the chosen action to the policy_info.
    def gather_observation(obs):
      return tf.gather(params=obs, indices=actions, batch_dims=1)
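    # With batch_dims=1, this picks, for each batch element, the per-arm
    # feature row corresponding to that element's chosen action.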
    chosen_arm_features = tf.nest.map_structure(
        gather_observation,
        observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
    policy_info = policy_utilities.PerArmPolicyInfo(
        log_probability=log_probability if
        policy_utilities.InfoFields.LOG_PROBABILITY in self._emit_policy_info
        else (),
        predicted_rewards_mean=(
            predicted_reward_values if policy_utilities.InfoFields
            .PREDICTED_REWARDS_MEAN in self._emit_policy_info else ()),
        bandit_policy_type=(bandit_policy_values
                            if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                            in self._emit_policy_info else ()),
        chosen_arm_features=chosen_arm_features)
  else:
    policy_info = policy_utilities.PolicyInfo(
        log_probability=log_probability if
        policy_utilities.InfoFields.LOG_PROBABILITY in self._emit_policy_info
        else (),
        predicted_rewards_mean=(
            predicted_reward_values if policy_utilities.InfoFields
            .PREDICTED_REWARDS_MEAN in self._emit_policy_info else ()),
        bandit_policy_type=(bandit_policy_values
                            if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                            in self._emit_policy_info else ()))
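  # The action was already sampled above, so the emitted distribution is
  # deterministic at the chosen actions.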
  return policy_step.PolicyStep(
      tfp.distributions.Deterministic(loc=actions), policy_state, policy_info)