tf_agents/policies/samplers/cem_actions_sampler_continuous.py [122:176]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
          samples_continuous = sample_clipper(samples_continuous, state_sample)

      samples_continuous = tf.nest.map_structure(
          common.clip_to_spec, samples_continuous, self._action_spec)
      return samples_continuous

    @tf.function
    def rejection_sampling(sample_rejector):
      valid_batch_samples = tf.nest.map_structure(
          lambda spec: tf.TensorArray(spec.dtype, size=batch_size),
          self._action_spec)

      for b_indx in tf.range(batch_size):
        k = tf.constant(0)
        # pylint: disable=cell-var-from-loop
        valid_samples = tf.nest.map_structure(
            lambda spec: tf.TensorArray(spec.dtype, size=num_samples),
            self._action_spec)

        count = tf.constant(0)
        while count < self._max_rejection_iterations:
          count += 1
          mean_sample = tf.nest.map_structure(
              lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), mean)
          var_sample = tf.nest.map_structure(
              lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), var)
          if state is not None:
            state_sample = tf.nest.map_structure(
                lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), state)
          else:
            state_sample = None

          samples = sample_fn(mean_sample, var_sample, state_sample)  # n, a

          mask = sample_rejector(samples, state_sample)

          mask = mask[0, ...]
          mask_index = tf.where(mask)[:, 0]

          num_mask = tf.shape(mask_index)[0]
          if num_mask == 0:
            continue

          good_samples = tf.nest.map_structure(
              lambda t: tf.gather(t, mask_index, axis=1)[0, ...], samples)

          for sample_idx in range(num_mask):
            if k >= num_samples:
              break
            valid_samples = tf.nest.map_structure(
                lambda gs, vs: vs.write(k, gs[sample_idx:sample_idx+1, ...]),
                good_samples, valid_samples)
            k += 1

        if k < num_samples:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



tf_agents/policies/samplers/cem_actions_sampler_continuous_and_one_hot.py [380:434]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
          samples_continuous = sample_clipper(samples_continuous, state_sample)

      samples_continuous = tf.nest.map_structure(
          common.clip_to_spec, samples_continuous, self._action_spec)
      return samples_continuous

    @tf.function
    def rejection_sampling(sample_rejector):
      valid_batch_samples = tf.nest.map_structure(
          lambda spec: tf.TensorArray(spec.dtype, size=batch_size),
          self._action_spec)

      for b_indx in tf.range(batch_size):
        k = tf.constant(0)
        # pylint: disable=cell-var-from-loop
        valid_samples = tf.nest.map_structure(
            lambda spec: tf.TensorArray(spec.dtype, size=num_samples),
            self._action_spec)

        count = tf.constant(0)
        while count < self._max_rejection_iterations:
          count += 1
          mean_sample = tf.nest.map_structure(
              lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), mean)
          var_sample = tf.nest.map_structure(
              lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), var)
          if state is not None:
            state_sample = tf.nest.map_structure(
                lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), state)
          else:
            state_sample = None

          samples = sample_fn(mean_sample, var_sample, state_sample)  # n, a

          mask = sample_rejector(samples, state_sample)

          mask = mask[0, ...]
          mask_index = tf.where(mask)[:, 0]

          num_mask = tf.shape(mask_index)[0]
          if num_mask == 0:
            continue

          good_samples = tf.nest.map_structure(
              lambda t: tf.gather(t, mask_index, axis=1)[0, ...], samples)

          for sample_idx in range(num_mask):
            if k >= num_samples:
              break
            valid_samples = tf.nest.map_structure(
                lambda gs, vs: vs.write(k, gs[sample_idx:sample_idx+1, ...]),
                good_samples, valid_samples)
            k += 1

        if k < num_samples:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



