in tf_agents/bandits/policies/linear_bandit_policy.py [0:0]
def _distribution(self, time_step, policy_state):
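  # Strip the action mask (if a constraint splitter is set) and cast the
  # observation to the policy's dtype before splitting it into global and
  # per-arm features.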
  observation = time_step.observation
  if self.observation_and_action_constraint_splitter is not None:
    observation, _ = self.observation_and_action_constraint_splitter(
        observation)
  observation = tf.nest.map_structure(lambda o: tf.cast(o, dtype=self._dtype),
                                      observation)
  global_observation, arm_observations = self._split_observation(observation)
  if self._add_bias:
    # The bias is added via a constant 1 feature.
    global_observation = tf.concat(
        [
            global_observation,
            tf.ones([tf.shape(global_observation)[0], 1], dtype=self._dtype)
        ],
        axis=1)
  # Check the shape of the observation matrix. The observations can be
  # batched.
  if not global_observation.shape.is_compatible_with(
      [None, self._global_context_dim]):
    raise ValueError(
        'Global observation shape is expected to be {}. Got {}.'.format(
            [None, self._global_context_dim],
            global_observation.shape.as_list()))
  global_observation = tf.reshape(global_observation,
                                  [-1, self._global_context_dim])
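  # For each arm, estimate the mean reward and the (squared) confidence
  # width from the linear model defined by the covariance matrix A and the
  # data vector b.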
  est_rewards = []
  confidence_intervals = []
  for k in range(self._num_actions):
    current_observation = self._get_current_observation(
        global_observation, arm_observations, k)
    model_index = policy_utilities.get_model_index(
        k, self._accepts_per_arm_features)
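    # Solve A^{-1} x for the current observations, either through the cached
    # eigendecomposition of the covariance matrix or with conjugate gradient
    # on the Tikhonov-regularized system.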
    if self._use_eigendecomp:
      q_t_b = tf.matmul(
          self._eig_matrix[model_index],
          tf.linalg.matrix_transpose(current_observation),
          transpose_a=True)
      lambda_inv = tf.divide(
          tf.ones_like(self._eig_vals[model_index]),
          self._eig_vals[model_index] + self._tikhonov_weight)
      a_inv_x = tf.matmul(self._eig_matrix[model_index],
                          tf.einsum('j,jk->jk', lambda_inv, q_t_b))
    else:
      a_inv_x = linalg.conjugate_gradient(
          self._cov_matrix[model_index] + self._tikhonov_weight *
          tf.eye(self._overall_context_dim, dtype=self._dtype),
          tf.linalg.matrix_transpose(current_observation))
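    # Estimated mean reward b^T A^{-1} x and squared confidence width
    # x^T A^{-1} x for every observation in the batch.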
    est_mean_reward = tf.einsum('j,jk->k', self._data_vector[model_index],
                                a_inv_x)
    est_rewards.append(est_mean_reward)
    ci = tf.reshape(
        tf.linalg.tensor_diag_part(tf.matmul(current_observation, a_inv_x)),
        [-1, 1])
    confidence_intervals.append(ci)
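  # Turn the estimates into per-arm scores: upper confidence bounds
  # (optimistic) or samples from a Normal centered at the estimated mean
  # reward (sampling).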
  if self._exploration_strategy == ExplorationStrategy.optimistic:
    optimistic_estimates = [
        tf.reshape(mean_reward, [-1, 1]) + self._alpha * tf.sqrt(confidence)
        for mean_reward, confidence in zip(est_rewards, confidence_intervals)
    ]
    # Keeping the batch dimension during the squeeze, even if batch_size == 1.
    rewards_for_argmax = tf.squeeze(
        tf.stack(optimistic_estimates, axis=-1), axis=[1])
  elif self._exploration_strategy == ExplorationStrategy.sampling:
    mu_sampler = tfd.Normal(
        loc=tf.stack(est_rewards, axis=-1),
        scale=self._alpha *
        tf.sqrt(tf.squeeze(tf.stack(confidence_intervals, axis=-1), axis=1)))
    rewards_for_argmax = mu_sampler.sample()
  else:
    raise ValueError('Exploration strategy %s not implemented.' %
                     self._exploration_strategy)
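  # Restrict the argmax to allowed actions when an action mask is available.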
  mask = constraints.construct_mask_from_multiple_sources(
      time_step.observation, self._observation_and_action_constraint_splitter,
      (), self._num_actions)
  if mask is not None:
    chosen_actions = policy_utilities.masked_argmax(
        rewards_for_argmax,
        mask,
        output_type=tf.nest.flatten(self._action_spec)[0].dtype)
  else:
    chosen_actions = tf.argmax(
        rewards_for_argmax,
        axis=-1,
        output_type=tf.nest.flatten(self._action_spec)[0].dtype)
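  # The policy is deterministic given the scores; wrap the chosen actions in
  # a Deterministic distribution and attach the requested policy info.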
  action_distributions = tfp.distributions.Deterministic(loc=chosen_actions)
  policy_info = policy_utilities.populate_policy_info(
      arm_observations, chosen_actions, rewards_for_argmax,
      tf.stack(est_rewards, axis=-1), self._emit_policy_info,
      self._accepts_per_arm_features)
  return policy_step.PolicyStep(action_distributions, policy_state,
                                policy_info)