def __init__()

in tf_agents/bandits/agents/linear_bandit_agent.py


  def __init__(
      self,
      exploration_policy,
      time_step_spec: types.TimeStep,
      action_spec: types.BoundedTensorSpec,
      variable_collection: Optional[LinearBanditVariableCollection] = None,
      alpha: float = 1.0,
      gamma: float = 1.0,
      use_eigendecomp: bool = False,
      tikhonov_weight: float = 1.0,
      add_bias: bool = False,
      emit_policy_info: Sequence[Text] = (),
      emit_log_probability: bool = False,
      observation_and_action_constraint_splitter: Optional[
          types.Splitter] = None,
      accepts_per_arm_features: bool = False,
      debug_summaries: bool = False,
      summarize_grads_and_vars: bool = False,
      enable_summaries: bool = True,
      dtype: tf.DType = tf.float32,
      name: Optional[Text] = None):
    """Initialize an instance of `LinearBanditAgent`.

    Args:
      exploration_policy: An Enum of type `ExplorationPolicy`. The kind of
        policy we use for exploration. Currently supported policies are
        `LinUCBPolicy` and `LinearThompsonSamplingPolicy`.
      time_step_spec: A `TimeStep` spec describing the expected `TimeStep`s.
      action_spec: A scalar `BoundedTensorSpec` with `int32` or `int64` dtype
        describing the number of actions for this agent.
      variable_collection: Instance of `LinearBanditVariableCollection`.
        Collection of variables to be updated by the agent. If `None`, a new
        instance of `LinearBanditVariableCollection` will be created.
      alpha: (float) positive scalar. This is the exploration parameter that
        multiplies the confidence intervals.
      gamma: a float forgetting factor in [0.0, 1.0]. When set to 1.0, the
        algorithm does not forget.
      use_eigendecomp: Whether to use eigendecomposition when solving the
        linear system; if `False`, the default conjugate gradient solver is
        used.
      tikhonov_weight: (float) Tikhonov regularization term.
      add_bias: If `True`, a bias term will be added to the linear reward
        estimation.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      emit_log_probability: Whether the policy emits log-probabilities or not.
        Since the policy is deterministic, the probability is just 1.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the bandit agent and
        policy, and 2) the boolean mask. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      accepts_per_arm_features: (bool) Whether the agent accepts per-arm
        features.
      debug_summaries: A Python bool, default False. When True, debug summaries
        are gathered.
      summarize_grads_and_vars: A Python bool, default False. When True,
        gradients and network variable summaries are written during training.
      enable_summaries: A Python bool, default True. When False, all summaries
        (debug or otherwise) should not be written.
      dtype: The type of the parameters stored and updated by the agent. Should
        be one of `tf.float32` and `tf.float64`. Defaults to `tf.float32`.
      name: a name for this instance of `LinearBanditAgent`.

    Raises:
      ValueError: if `dtype` is not one of `tf.float32` or `tf.float64`, if
        `gamma` is outside [0.0, 1.0], or if `exploration_policy` is not a
        supported `ExplorationPolicy`.
      TypeError: if `variable_collection` is not an instance of
        `LinearBanditVariableCollection`.
    """
    tf.Module.__init__(self, name=name)
    common.tf_agents_gauge.get_cell('TFABandit').set(True)
    self._num_actions = policy_utilities.get_num_actions_from_tensor_spec(
        action_spec)
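    # With per-arm features a single linear model is shared across all arms;
    # otherwise one model is kept per action.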
    self._num_models = 1 if accepts_per_arm_features else self._num_actions
    self._observation_and_action_constraint_splitter = (
        observation_and_action_constraint_splitter)
    self._time_step_spec = time_step_spec
    self._accepts_per_arm_features = accepts_per_arm_features
    self._add_bias = add_bias
    if observation_and_action_constraint_splitter is not None:
      context_spec, _ = observation_and_action_constraint_splitter(
          time_step_spec.observation)
    else:
      context_spec = time_step_spec.observation

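    # The global and per-arm context dimensions determine the size of the
    # sufficient statistics (covariance matrix and data vector) kept per model.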
    (self._global_context_dim,
     self._arm_context_dim) = bandit_spec_utils.get_context_dims_from_spec(
         context_spec, accepts_per_arm_features)
    if self._add_bias:
      # The bias is added via a constant 1 feature.
      self._global_context_dim += 1
    self._overall_context_dim = self._global_context_dim + self._arm_context_dim

    self._alpha = alpha
    if variable_collection is None:
      variable_collection = LinearBanditVariableCollection(
          context_dim=self._overall_context_dim,
          num_models=self._num_models,
          use_eigendecomp=use_eigendecomp,
          dtype=dtype)
    elif not isinstance(variable_collection, LinearBanditVariableCollection):
      raise TypeError('Parameter `variable_collection` should be '
                      'of type `LinearBanditVariableCollection`.')
    self._variable_collection = variable_collection
    self._cov_matrix_list = variable_collection.cov_matrix_list
    self._data_vector_list = variable_collection.data_vector_list
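    # The eigendecomposition variables are only handed to the policy when
    # `use_eigendecomp` is True (see the policy construction below).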
    self._eig_matrix_list = variable_collection.eig_matrix_list
    self._eig_vals_list = variable_collection.eig_vals_list
    # We keep track of the number of samples per arm.
    self._num_samples_list = variable_collection.num_samples_list
    self._gamma = gamma
    if self._gamma < 0.0 or self._gamma > 1.0:
      raise ValueError('Forgetting factor `gamma` must be in [0.0, 1.0].')
    self._dtype = dtype
    if dtype not in (tf.float32, tf.float64):
      raise ValueError(
          'Agent dtype should be either `tf.float32` or `tf.float64`.')
    self._use_eigendecomp = use_eigendecomp
    self._tikhonov_weight = tikhonov_weight

    if exploration_policy == ExplorationPolicy.linear_ucb_policy:
      exploration_strategy = lin_policy.ExplorationStrategy.optimistic
    elif exploration_policy == (
        ExplorationPolicy.linear_thompson_sampling_policy):
      exploration_strategy = lin_policy.ExplorationStrategy.sampling
    else:
      raise ValueError('Linear bandit agent with policy %s not implemented' %
                       exploration_policy)
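    # The policy is built directly on the variables held in the collection, so
    # it always reads the agent's most recent estimates.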
    policy = lin_policy.LinearBanditPolicy(
        action_spec=action_spec,
        cov_matrix=self._cov_matrix_list,
        data_vector=self._data_vector_list,
        num_samples=self._num_samples_list,
        time_step_spec=time_step_spec,
        exploration_strategy=exploration_strategy,
        alpha=alpha,
        eig_vals=self._eig_vals_list if self._use_eigendecomp else (),
        eig_matrix=self._eig_matrix_list if self._use_eigendecomp else (),
        tikhonov_weight=self._tikhonov_weight,
        add_bias=add_bias,
        emit_policy_info=emit_policy_info,
        emit_log_probability=emit_log_probability,
        accepts_per_arm_features=accepts_per_arm_features,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter))

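    # Per-arm agents are trained on trajectories whose observation has the
    # per-arm features removed (the chosen arm's features are carried in the
    # policy info).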
    training_data_spec = None
    if accepts_per_arm_features:
      training_data_spec = bandit_spec_utils.drop_arm_observation(
          policy.trajectory_spec)
    super(LinearBanditAgent, self).__init__(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        policy=policy,
        collect_policy=policy,
        training_data_spec=training_data_spec,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        enable_summaries=enable_summaries,
        train_sequence_length=None)
    self._as_trajectory = data_converter.AsTrajectory(
        self.data_context, sequence_length=None)
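
A minimal construction sketch (illustrative, not taken from the library's
documentation): the observation spec, action spec, and parameter values below
are assumptions, while the class, enum, and module path follow the code above.

import tensorflow as tf
from tf_agents.bandits.agents import linear_bandit_agent
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Hypothetical specs: a 5-dimensional global context and 3 arms.
observation_spec = tf.TensorSpec(shape=(5,), dtype=tf.float32)
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=0, maximum=2)

agent = linear_bandit_agent.LinearBanditAgent(
    exploration_policy=(
        linear_bandit_agent.ExplorationPolicy.linear_ucb_policy),
    time_step_spec=ts.time_step_spec(observation_spec),
    action_spec=action_spec,
    alpha=1.0,   # exploration strength multiplying the confidence intervals
    gamma=1.0,   # 1.0 means no forgetting
    dtype=tf.float32)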
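
When invalid actions must be masked per step, the
`observation_and_action_constraint_splitter` argument can be a small function
like the sketch below; the dict keys 'context' and 'mask' are hypothetical and
depend on how the environment packs its observations.

def observation_and_action_constraint_splitter(obs):
  # Returns (observation for the agent/policy, boolean action mask). As the
  # docstring notes, the same function must also work on a structure of
  # TensorSpecs.
  return obs['context'], obs['mask']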