in tf_agents/policies/samplers/cem_actions_sampler_continuous_and_one_hot.py [0:0]
def __init__(self, action_spec, sample_clippers=None,
             sub_actions_fields=None, sample_rejecters=None,
             max_rejection_iterations=10):
  """Builds a GaussianActionsSampler.

  The constructor validates the action spec and the sub-action grouping,
  then precomputes the bookkeeping (sub-action counts, the categorical key
  and per-sub-action masks) stored on the instance.

  Args:
    action_spec: A dict of BoundedTensorSpec representing the actions. Every
      flattened spec must have rank 1, and exactly one of them must have an
      integer dtype (the one_hot action).
    sample_clippers: A list of list of sample clipper functions. The function
      takes a dict of Tensors of actions and a dict of Tensors of the state,
      output a dict of Tensors of clipped actions. Required; its length must
      equal the length of `sub_actions_fields`.
    sub_actions_fields: A list of list of action keys to group
      fields into sub_actions. Required; taken together the keys must match
      the keys of the action spec, and the integer (one_hot) key must be the
      only key in its group.
    sample_rejecters: A list of callables that will reject samples and return
      a mask tensor. When the stored rejecters are None, a list holding one
      None per sub-action is used instead.
    max_rejection_iterations: Integer stored as a `tf.constant`; presumably
      the upper bound on rejection-sampling iterations (it is not used in
      the constructor itself -- confirm against the sampling code).

  Raises:
    ValueError: If any flattened action spec is not rank 1; if the number of
      integer (one_hot) action specs is not exactly 1; if `sample_clippers`
      or `sub_actions_fields` is None; if their lengths differ; if the keys
      in `sub_actions_fields` do not match the action spec keys; or if no
      sub-action consists solely of the integer action.
  """
  # `self._action_spec` and `self._sample_rejecters` are read below without
  # being assigned here, so they are presumably set by the base constructor.
  super(GaussianActionsSampler, self).__init__(
      action_spec, sample_clippers, sample_rejecters)

  num_one_hot_action = 0
  for flat_action_spec in tf.nest.flatten(action_spec):
    if flat_action_spec.shape.rank != 1:
      raise ValueError('Only 1d action is supported by this sampler. '
                       'The action_spec: \n{}\n contains action whose rank is'
                       ' not 1. Consider coverting it into multiple 1d '
                       'actions.'.format(action_spec))
    if flat_action_spec.dtype.is_integer:
      num_one_hot_action += 1
      # S
      self._num_mutually_exclusive_actions = (
          flat_action_spec.shape.as_list()[0])

  if num_one_hot_action != 1:
    raise ValueError('Only continuous action + 1 one_hot action is supported'
                     ' by this sampler. The action_spec: \n{}\n contains '
                     'either multiple one_hot actions or no one_hot '
                     'action'.format(action_spec))

  if sample_clippers is None:
    raise ValueError('Sampler clippers must be set!')

  if sub_actions_fields is None:
    raise ValueError('sub_actions_fields must be set!')

  if len(sample_clippers) != len(sub_actions_fields):
    raise ValueError('Number of sample_clippers must be the same as number of'
                     ' sub_actions_fields! sample_clippers: {}, '
                     'sub_actions_fields: {}'.format(
                         sample_clippers, sub_actions_fields))

  if self._sample_rejecters is None:
    self._sample_rejecters = [None] * len(sub_actions_fields)

  self._max_rejection_iterations = tf.constant(max_rejection_iterations)

  self._num_sub_actions = len(sample_clippers)
  self._sub_actions_fields = sub_actions_fields

  # The grouping must mention exactly the keys of the action spec; comparing
  # sorted lists also catches keys that are listed more than once.
  action_spec_keys = sorted(self._action_spec.keys())
  sub_actions_fields_keys = sorted(
      item for sublist in self._sub_actions_fields for item in sublist)  # pylint: disable=g-complex-comprehension
  if action_spec_keys != sub_actions_fields_keys:
    raise ValueError('sub_actions_fields must cover all keys in action_spec! '
                     'action_spec_keys: {}, sub_actions_fields_keys:'
                     ' {}'.format(action_spec_keys, sub_actions_fields_keys))

  # Locate the first sub-action made of a single integer field. Lengths were
  # validated above, so iterating the field groups visits every sub-action.
  self._categorical_index = -1
  for index, fields in enumerate(self._sub_actions_fields):
    if len(fields) == 1 and self._action_spec[fields[0]].dtype.is_integer:
      self._categorical_index = index
      break

  if self._categorical_index == -1:
    raise ValueError('Categorical action cannot be grouped together w/ '
                     'continuous action into a sub_action.')

  self._categorical_key = self._sub_actions_fields[self._categorical_index][0]

  # K
  self._num_sub_continuous_actions = self._num_sub_actions - 1
  # S-K
  # NOTE(review): nothing here guards against S < K, which would make this
  # count negative -- confirm whether callers can ever hit that case.
  self._num_sub_categorical_actions = (
      self._num_mutually_exclusive_actions -
      self._num_sub_continuous_actions)

  # Because the sampler will sample for all fields and there are actions
  # that are mutually exclusive. Therefore masks are needed to zero
  # out the fields that does not belong to the sub_action.
  self._masks = []
  for fields in self._sub_actions_fields:
    mask = {
        key: tf.ones([1]) if key in fields else tf.zeros([1])
        for key in self._action_spec.keys()
    }
    self._masks.append(mask)

  # Created empty here; nothing in the constructor populates them.
  self._index_range_min = {}
  self._index_range_max = {}