def __init__()

in tf_agents/policies/samplers/cem_actions_sampler_continuous_and_one_hot.py [0:0]


  def __init__(self, action_spec, sample_clippers=None,
               sub_actions_fields=None, sample_rejecters=None,
               max_rejection_iterations=10):
    """Builds a GaussianActionsSampler.

    The sampler requires every action in `action_spec` to be rank 1 and
    exactly one of them to have an integer dtype (the one_hot action); the
    action keys are partitioned into sub_actions by `sub_actions_fields`, and
    the one_hot action must form a sub_action on its own.

    Args:
      action_spec: A dict of BoundedTensorSpec representing the actions. Every
        spec must be rank 1 and exactly one spec must have an integer dtype.
      sample_clippers: A list of list of sample clipper functions. The function
        takes a dict of Tensors of actions and a dict of Tensors of the state,
        output a dict of Tensors of clipped actions. Must be set and have the
        same length as `sub_actions_fields`.
      sub_actions_fields: A list of list of action keys to group
        fields into sub_actions. Must be set, and the union of all groups must
        equal the keys of `action_spec`, each key appearing exactly once.
      sample_rejecters: A list of callables that will reject samples and return
        a mask tensor. If `self._sample_rejecters` is still None after the
        base-class constructor (presumably when None is passed here), it is
        replaced by a list holding one `None` per sub_action.
      max_rejection_iterations: Limit on the number of rejection iterations
        (presumably consumed by the rejection loop elsewhere in this class);
        stored via `tf.constant`.

    Raises:
      ValueError: If any action is not rank 1; if there is not exactly one
        integer (one_hot) action; if `sample_clippers` or `sub_actions_fields`
        is None; if their lengths differ; if `sub_actions_fields` does not
        cover exactly the keys of `action_spec`; or if the one_hot action is
        grouped together with other actions in a sub_action.
    """

    super(GaussianActionsSampler, self).__init__(
        action_spec, sample_clippers, sample_rejecters)

    num_one_hot_action = 0
    for flat_action_spec in tf.nest.flatten(action_spec):
      if flat_action_spec.shape.rank != 1:
        raise ValueError('Only 1d action is supported by this sampler. '
                         'The action_spec: \n{}\n contains action whose rank is'
                         ' not 1. Consider converting it into multiple 1d '
                         'actions.'.format(action_spec))
      if flat_action_spec.dtype.is_integer:
        num_one_hot_action += 1
        # S: length of the (rank-1) one_hot action, i.e. the number of
        # mutually exclusive actions.
        self._num_mutually_exclusive_actions = (
            flat_action_spec.shape.as_list()[0])

    if num_one_hot_action != 1:
      raise ValueError('Only continuous action + 1 one_hot action is supported'
                       ' by this sampler. The action_spec: \n{}\n contains '
                       'either multiple one_hot actions or no one_hot '
                       'action'.format(action_spec))

    if sample_clippers is None:
      raise ValueError('Sampler clippers must be set!')

    if sub_actions_fields is None:
      raise ValueError('sub_actions_fields must be set!')

    if len(sample_clippers) != len(sub_actions_fields):
      raise ValueError('Number of sample_clippers must be the same as number of'
                       ' sub_actions_fields! sample_clippers: {}, '
                       'sub_actions_fields: {}'.format(
                           sample_clippers, sub_actions_fields))

    # Default to "no rejecter" for every sub_action.
    if self._sample_rejecters is None:
      self._sample_rejecters = [None] * len(sub_actions_fields)

    self._max_rejection_iterations = tf.constant(max_rejection_iterations)

    self._num_sub_actions = len(sample_clippers)
    self._sub_actions_fields = sub_actions_fields

    # Sorted comparison: detects missing keys, unknown keys, and keys that
    # appear in more than one sub_action.
    action_spec_keys = sorted(self._action_spec.keys())
    sub_actions_fields_keys = sorted(
        key for fields in self._sub_actions_fields for key in fields  # pylint: disable=g-complex-comprehension
    )
    if action_spec_keys != sub_actions_fields_keys:
      raise ValueError('sub_actions_fields must cover all keys in action_spec!'
                       ' action_spec_keys: {}, sub_actions_fields_keys:'
                       ' {}'.format(action_spec_keys, sub_actions_fields_keys))

    # The categorical sub_action is the first one made of a single integer
    # action; if the one_hot action shares its group with other keys, no such
    # sub_action exists.
    self._categorical_index = -1
    for index, fields in enumerate(self._sub_actions_fields):
      if (len(fields) == 1 and
          self._action_spec[fields[0]].dtype.is_integer):
        self._categorical_index = index
        break

    if self._categorical_index == -1:
      raise ValueError('Categorical action cannot be grouped together w/ '
                       'continuous action into a sub_action.')
    self._categorical_key = self._sub_actions_fields[self._categorical_index][0]

    # K: number of continuous sub_actions (every sub_action except the single
    # categorical one).
    self._num_sub_continuous_actions = self._num_sub_actions - 1
    # S - K: one_hot entries not accounted for by the K continuous
    # sub_actions (presumably the purely categorical choices; confirm
    # against the sampling code).
    self._num_sub_categorical_actions = (
        self._num_mutually_exclusive_actions -
        self._num_sub_continuous_actions)

    # The sampler samples every field, but sub_actions are mutually exclusive,
    # so per-sub_action masks (ones for the sub_action's own keys, zeros for
    # all other keys) are needed to zero out fields that do not belong to it.
    self._masks = []
    for fields in self._sub_actions_fields:
      self._masks.append({
          k: tf.ones([1]) if k in fields else tf.zeros([1])
          for k in self._action_spec.keys()
      })

    # Start empty here; presumably filled in by other methods of this class
    # (not visible in this chunk).
    self._index_range_min = {}
    self._index_range_max = {}