tf_agents/bandits/agents/linear_bandit_agent.py
def __init__(
self,
exploration_policy,
time_step_spec: types.TimeStep,
action_spec: types.BoundedTensorSpec,
variable_collection: Optional[LinearBanditVariableCollection] = None,
alpha: float = 1.0,
gamma: float = 1.0,
use_eigendecomp: bool = False,
tikhonov_weight: float = 1.0,
add_bias: bool = False,
emit_policy_info: Sequence[Text] = (),
emit_log_probability: bool = False,
observation_and_action_constraint_splitter: Optional[
types.Splitter] = None,
accepts_per_arm_features: bool = False,
debug_summaries: bool = False,
summarize_grads_and_vars: bool = False,
enable_summaries: bool = True,
dtype: tf.DType = tf.float32,
name: Optional[Text] = None):
"""Initialize an instance of `LinearBanditAgent`.
Args:
exploration_policy: An enum of type `ExplorationPolicy` selecting the kind
of policy used for exploration. Currently supported values are
`ExplorationPolicy.linear_ucb_policy` and
`ExplorationPolicy.linear_thompson_sampling_policy`.
time_step_spec: A `TimeStep` spec describing the expected `TimeStep`s.
action_spec: A scalar `BoundedTensorSpec` with `int32` or `int64` dtype
describing the number of actions for this agent.
variable_collection: Instance of `LinearBanditVariableCollection`.
Collection of variables to be updated by the agent. If `None`, a new
instance of `LinearBanditVariableCollection` will be created.
alpha: (float) positive scalar. This is the exploration parameter that
multiplies the confidence intervals.
gamma: a float forgetting factor in [0.0, 1.0]. When set to 1.0, the
algorithm does not forget.
use_eigendecomp: whether to use eigendecomposition to solve the linear
system. If `False`, the default Conjugate Gradient solver is used.
tikhonov_weight: (float) Tikhonov regularization weight.
add_bias: If true, a bias term will be added to the linear reward
estimation.
emit_policy_info: (tuple of strings) what side information we want to get
as part of the policy info. Allowed values can be found in
`policy_utilities.PolicyInfo`.
emit_log_probability: Whether the policy emits log-probabilities or not.
Since the policy is deterministic, the probability is just 1.
observation_and_action_constraint_splitter: A function used for masking
valid/invalid actions with each state of the environment. The function
takes in a full observation and returns a tuple consisting of 1) the
part of the observation intended as input to the bandit agent and
policy, and 2) the boolean mask. This function should also work with a
`TensorSpec` as input, and should output `TensorSpec` objects for the
observation and mask.
accepts_per_arm_features: (bool) Whether the agent accepts per-arm
features.
debug_summaries: A Python bool, default False. When True, debug summaries
are gathered.
summarize_grads_and_vars: A Python bool, default False. When True,
gradients and network variable summaries are written during training.
enable_summaries: A Python bool, default True. When False, no summaries
(debug or otherwise) are written.
dtype: The type of the parameters stored and updated by the agent. Should
be one of `tf.float32` and `tf.float64`. Defaults to `tf.float32`.
name: a name for this instance of `LinearBanditAgent`.

Raises:
ValueError: If `dtype` is not one of `tf.float32` or `tf.float64`.
TypeError: If `variable_collection` is not an instance of
`LinearBanditVariableCollection`.
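
Example (an illustrative sketch; the observation spec, action range and
hyperparameters below are placeholder values, not requirements):

```python
import tensorflow as tf
from tf_agents.bandits.agents import linear_bandit_agent
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tf.TensorSpec([10], tf.float32)  # 10-dim context.
action_spec = tensor_spec.BoundedTensorSpec(
    (), tf.int32, minimum=0, maximum=4)  # 5 arms.
agent = linear_bandit_agent.LinearBanditAgent(
    exploration_policy=(
        linear_bandit_agent.ExplorationPolicy.linear_ucb_policy),
    time_step_spec=ts.time_step_spec(observation_spec),
    action_spec=action_spec,
    alpha=1.0,
    gamma=0.95)
```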
"""
tf.Module.__init__(self, name=name)
common.tf_agents_gauge.get_cell('TFABandit').set(True)
self._num_actions = policy_utilities.get_num_actions_from_tensor_spec(
action_spec)
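# With per-arm features, a single linear model is shared across all arms;
# otherwise one model is kept per action.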
self._num_models = 1 if accepts_per_arm_features else self._num_actions
self._observation_and_action_constraint_splitter = (
observation_and_action_constraint_splitter)
self._time_step_spec = time_step_spec
self._accepts_per_arm_features = accepts_per_arm_features
self._add_bias = add_bias
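# When a constraint splitter is provided, only the context part of the
# observation (not the action mask) determines the context dimensions below.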
if observation_and_action_constraint_splitter is not None:
context_spec, _ = observation_and_action_constraint_splitter(
time_step_spec.observation)
else:
context_spec = time_step_spec.observation
(self._global_context_dim,
self._arm_context_dim) = bandit_spec_utils.get_context_dims_from_spec(
context_spec, accepts_per_arm_features)
if self._add_bias:
# The bias is added via a constant 1 feature.
self._global_context_dim += 1
self._overall_context_dim = self._global_context_dim + self._arm_context_dim
self._alpha = alpha
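# The variable collection holds the per-model covariance matrices, data
# vectors, optional eigendecompositions, and sample counts; create a new one
# unless the caller supplied it.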
if variable_collection is None:
variable_collection = LinearBanditVariableCollection(
context_dim=self._overall_context_dim,
num_models=self._num_models,
use_eigendecomp=use_eigendecomp,
dtype=dtype)
elif not isinstance(variable_collection, LinearBanditVariableCollection):
raise TypeError('Parameter `variable_collection` should be '
'of type `LinearBanditVariableCollection`.')
self._variable_collection = variable_collection
self._cov_matrix_list = variable_collection.cov_matrix_list
self._data_vector_list = variable_collection.data_vector_list
self._eig_matrix_list = variable_collection.eig_matrix_list
self._eig_vals_list = variable_collection.eig_vals_list
# We keep track of the number of samples per arm.
self._num_samples_list = variable_collection.num_samples_list
self._gamma = gamma
if self._gamma < 0.0 or self._gamma > 1.0:
raise ValueError('Forgetting factor `gamma` must be in [0.0, 1.0].')
self._dtype = dtype
if dtype not in (tf.float32, tf.float64):
raise ValueError(
'Agent dtype should be either `tf.float32` or `tf.float64`.')
self._use_eigendecomp = use_eigendecomp
self._tikhonov_weight = tikhonov_weight
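# Map the agent-level `ExplorationPolicy` enum to the exploration strategy
# understood by `LinearBanditPolicy`.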
if exploration_policy == ExplorationPolicy.linear_ucb_policy:
exploration_strategy = lin_policy.ExplorationStrategy.optimistic
elif exploration_policy == (
ExplorationPolicy.linear_thompson_sampling_policy):
exploration_strategy = lin_policy.ExplorationStrategy.sampling
else:
raise ValueError('Linear bandit agent with policy %s not implemented' %
exploration_policy)
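# Build the policy on top of the agent's variables; the same policy instance
# serves as both the collect and the eval policy below.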
policy = lin_policy.LinearBanditPolicy(
action_spec=action_spec,
cov_matrix=self._cov_matrix_list,
data_vector=self._data_vector_list,
num_samples=self._num_samples_list,
time_step_spec=time_step_spec,
exploration_strategy=exploration_strategy,
alpha=alpha,
eig_vals=self._eig_vals_list if self._use_eigendecomp else (),
eig_matrix=self._eig_matrix_list if self._use_eigendecomp else (),
tikhonov_weight=self._tikhonov_weight,
add_bias=add_bias,
emit_policy_info=emit_policy_info,
emit_log_probability=emit_log_probability,
accepts_per_arm_features=accepts_per_arm_features,
observation_and_action_constraint_splitter=(
observation_and_action_constraint_splitter))
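# For per-arm features, full arm observations are not needed at training time
# (the chosen arm's features travel in the policy info), so they are dropped
# from the training data spec.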
training_data_spec = None
if accepts_per_arm_features:
training_data_spec = bandit_spec_utils.drop_arm_observation(
policy.trajectory_spec)
super(LinearBanditAgent, self).__init__(
time_step_spec=time_step_spec,
action_spec=action_spec,
policy=policy,
collect_policy=policy,
training_data_spec=training_data_spec,
debug_summaries=debug_summaries,
summarize_grads_and_vars=summarize_grads_and_vars,
enable_summaries=enable_summaries,
train_sequence_length=None)
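# Converter that validates incoming experience and reshapes it into the
# `Trajectory` format expected during training.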
self._as_trajectory = data_converter.AsTrajectory(
self.data_context, sequence_length=None)