def _distribution()

in tf_agents/bandits/policies/linear_bandit_policy.py


  def _distribution(self, time_step, policy_state):
    observation = time_step.observation
    if self.observation_and_action_constraint_splitter is not None:
      observation, _ = self.observation_and_action_constraint_splitter(
          observation)
    observation = tf.nest.map_structure(lambda o: tf.cast(o, dtype=self._dtype),
                                        observation)
    global_observation, arm_observations = self._split_observation(observation)

    if self._add_bias:
      # The bias is added via a constant 1 feature.
      global_observation = tf.concat(
          [
              global_observation,
              tf.ones([tf.shape(global_observation)[0], 1], dtype=self._dtype),
          ],
          axis=1)
    # Check the shape of the observation matrix. The observations can be
    # batched.
    if not global_observation.shape.is_compatible_with(
        [None, self._global_context_dim]):
      raise ValueError(
          'Global observation shape is expected to be {}. Got {}.'.format(
              [None, self._global_context_dim],
              global_observation.shape.as_list()))
    global_observation = tf.reshape(global_observation,
                                    [-1, self._global_context_dim])

    est_rewards = []
    confidence_intervals = []
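    # For each arm k, with A_k = self._cov_matrix[k] and b_k =
    # self._data_vector[k], solve (A_k + tikhonov_weight * I)^{-1} X^T for the
    # batched contexts X. This yields the estimated mean reward theta_k^T x,
    # where theta_k = (A_k + tikhonov_weight * I)^{-1} b_k, and the confidence
    # width x^T (A_k + tikhonov_weight * I)^{-1} x.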
    for k in range(self._num_actions):
      current_observation = self._get_current_observation(
          global_observation, arm_observations, k)
      model_index = policy_utilities.get_model_index(
          k, self._accepts_per_arm_features)
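      # With a cached eigendecomposition A_k = Q diag(eig_vals) Q^T, the
      # regularized solve reduces to
      # Q diag(1 / (eig_vals + tikhonov_weight)) Q^T X^T.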
      if self._use_eigendecomp:
        q_t_b = tf.matmul(
            self._eig_matrix[model_index],
            tf.linalg.matrix_transpose(current_observation),
            transpose_a=True)
        lambda_inv = tf.divide(
            tf.ones_like(self._eig_vals[model_index]),
            self._eig_vals[model_index] + self._tikhonov_weight)
        a_inv_x = tf.matmul(self._eig_matrix[model_index],
                            tf.einsum('j,jk->jk', lambda_inv, q_t_b))
      else:
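        # Without the eigendecomposition, solve the regularized linear system
        # iteratively with conjugate gradient instead of forming an explicit
        # inverse.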
        a_inv_x = linalg.conjugate_gradient(
            self._cov_matrix[model_index] + self._tikhonov_weight *
            tf.eye(self._overall_context_dim, dtype=self._dtype),
            tf.linalg.matrix_transpose(current_observation))
      est_mean_reward = tf.einsum('j,jk->k', self._data_vector[model_index],
                                  a_inv_x)
      est_rewards.append(est_mean_reward)

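      # Confidence term for arm k: diag(X (A_k + tikhonov_weight * I)^{-1} X^T),
      # i.e. x^T (A_k + tikhonov_weight * I)^{-1} x for each context x in the
      # batch.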
      ci = tf.reshape(
          tf.linalg.tensor_diag_part(tf.matmul(current_observation, a_inv_x)),
          [-1, 1])
      confidence_intervals.append(ci)

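    # Optimistic (UCB): score = mean + alpha * sqrt(confidence).
    # Sampling (Thompson-style): score ~ Normal(mean, alpha * sqrt(confidence)).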
    if self._exploration_strategy == ExplorationStrategy.optimistic:
      optimistic_estimates = [
          tf.reshape(mean_reward, [-1, 1]) + self._alpha * tf.sqrt(confidence)
          for mean_reward, confidence in zip(est_rewards, confidence_intervals)
      ]
      # Keeping the batch dimension during the squeeze, even if batch_size == 1.
      rewards_for_argmax = tf.squeeze(
          tf.stack(optimistic_estimates, axis=-1), axis=[1])
    elif self._exploration_strategy == ExplorationStrategy.sampling:
      mu_sampler = tfd.Normal(
          loc=tf.stack(est_rewards, axis=-1),
          scale=self._alpha *
          tf.sqrt(tf.squeeze(tf.stack(confidence_intervals, axis=-1), axis=1)))
      rewards_for_argmax = mu_sampler.sample()
    else:
      raise ValueError('Exploration strategy %s not implemented.' %
                       self._exploration_strategy)

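    # If the observation carries an action mask, restrict the argmax over the
    # scores to the allowed actions.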
    mask = constraints.construct_mask_from_multiple_sources(
        time_step.observation, self._observation_and_action_constraint_splitter,
        (), self._num_actions)
    if mask is not None:
      chosen_actions = policy_utilities.masked_argmax(
          rewards_for_argmax,
          mask,
          output_type=tf.nest.flatten(self._action_spec)[0].dtype)
    else:
      chosen_actions = tf.argmax(
          rewards_for_argmax,
          axis=-1,
          output_type=tf.nest.flatten(self._action_spec)[0].dtype)

    action_distributions = tfp.distributions.Deterministic(loc=chosen_actions)

    policy_info = policy_utilities.populate_policy_info(
        arm_observations, chosen_actions, rewards_for_argmax,
        tf.stack(est_rewards, axis=-1), self._emit_policy_info,
        self._accepts_per_arm_features)

    return policy_step.PolicyStep(
        action_distributions, policy_state, policy_info)
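
For reference, the snippet below is a minimal NumPy sketch of the per-arm
optimistic (UCB) score that the loop above computes, for a single unbatched
context. The names linucb_scores, cov_matrices, data_vectors, tikhonov_weight
and alpha are illustrative and not part of the TF-Agents API; the real policy
additionally handles batching, per-arm features, the eigendecomposition path,
sampling-based exploration, and action masks.

import numpy as np

def linucb_scores(x, cov_matrices, data_vectors, tikhonov_weight, alpha):
  """Optimistic (UCB) score per arm for a single context vector x."""
  scores = []
  for a, b in zip(cov_matrices, data_vectors):
    a_reg = a + tikhonov_weight * np.eye(a.shape[0])
    a_inv_x = np.linalg.solve(a_reg, x)  # (A + lambda*I)^{-1} x
    mean = b @ a_inv_x                   # theta^T x, theta = (A + lambda*I)^{-1} b
    width = np.sqrt(x @ a_inv_x)         # sqrt(x^T (A + lambda*I)^{-1} x)
    scores.append(mean + alpha * width)
  return np.array(scores)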