def _sample_continuous_and_transpose()

in tf_agents/policies/samplers/cem_actions_sampler_continuous_and_one_hot.py [0:0]


  def _sample_continuous_and_transpose(
      self, mean, var, state, i, one_hot_index):
    """Draws `[B, N, ...]` action samples for sampling group `i`.

    Continuous action components are drawn from
    `Normal(loc=mean, scale=sqrt(var))` and multiplied by `self._masks[i]`.
    Integer components are set to `one_hot(one_hot_index)` for every sample.
    The samples are then passed through the group's sample clippers (if any)
    and clipped to `self._action_spec`.

    If `self._sample_rejecters[i]` is set, rejection sampling is run per batch
    element. Up to `self._max_rejection_iterations` rounds of sampling are
    tried, and only samples whose rejecter mask is true are kept. Any of the
    `N` slots still unfilled afterwards get zeros for continuous components
    and `one_hot(one_hot_index)` for integer components.

    Args:
      mean: Nest (matching `self._action_spec`) of means. Presumably shaped
        `[B, A]` (the integer branch indexes `tf.shape(mean)[1]`); TODO
        confirm.
      var: Nest of variances with the same structure as `mean`.
      state: Optional nest of tensors with a leading batch dim, forwarded to
        the sample clippers and rejecter; may be None.
      i: Index selecting the per-group sample count, masks, clippers and
        rejecter.
      one_hot_index: Index of the hot entry used for integer action specs.

    Returns:
      Nest matching `self._action_spec` of samples shaped `[B, N, ...]`, where
      `N = self._number_samples_all[i]`.
    """
    num_samples = self._number_samples_all[i]

    def sample_and_transpose(mean, var, spec, mask):
      if spec.dtype.is_integer:
        # Integer specs carry the mutually exclusive (one-hot) action; every
        # sample uses the same one-hot vector and the mask is not applied.
        sample = tf.one_hot(
            one_hot_index, self._num_mutually_exclusive_actions)
        sample = tf.broadcast_to(
            sample,
            [tf.shape(mean)[0],
             tf.constant(num_samples),  # pylint: disable=cell-var-from-loop
             tf.shape(mean)[1]])
      else:
        dist = tfp.distributions.Normal(loc=mean, scale=tf.sqrt(var))
        # dist.sample yields [N, B, A]; transpose to [B, N, A].
        sample = tf.transpose(
            dist.sample(num_samples), [1, 0, 2])  # pylint: disable=cell-var-from-loop
        sample = sample * mask
      return tf.cast(sample, spec.dtype)

    batch_size = tf.shape(tf.nest.flatten(mean)[0])[0]

    def sample_fn(mean_sample, var_sample, state_sample):
      """Samples, applies the group's clippers, then clips to the spec."""
      # [B, N, A]
      samples_continuous = tf.nest.map_structure(sample_and_transpose,
                                                 mean_sample, var_sample,
                                                 self._action_spec,
                                                 self._masks[i])

      if self._sample_clippers[i]:
        for sample_clipper in self._sample_clippers[i]:
          samples_continuous = sample_clipper(samples_continuous, state_sample)

      samples_continuous = tf.nest.map_structure(
          common.clip_to_spec, samples_continuous, self._action_spec)
      return samples_continuous

    @tf.function
    def rejection_sampling(sample_rejector):
      """Collects `num_samples` accepted samples for every batch element."""
      valid_batch_samples = tf.nest.map_structure(
          lambda spec: tf.TensorArray(spec.dtype, size=batch_size),
          self._action_spec)

      for b_indx in tf.range(batch_size):
        # pylint: disable=cell-var-from-loop
        # Slice this batch element once (keeping a batch dim of 1); the
        # slices do not change across rejection iterations.
        mean_sample = tf.nest.map_structure(
            lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), mean)
        var_sample = tf.nest.map_structure(
            lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), var)
        if state is not None:
          state_sample = tf.nest.map_structure(
              lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), state)
        else:
          state_sample = None

        # Next slot to write in `valid_samples` (= number of samples kept).
        k = tf.constant(0)
        valid_samples = tf.nest.map_structure(
            lambda spec: tf.TensorArray(spec.dtype, size=num_samples),
            self._action_spec)

        count = tf.constant(0)
        # Stop as soon as all `num_samples` slots are filled; further rounds
        # could not write anything.
        while count < self._max_rejection_iterations and k < num_samples:
          count += 1

          samples = sample_fn(mean_sample, var_sample, state_sample)  # 1, n, a

          mask = sample_rejector(samples, state_sample)

          mask = mask[0, ...]
          mask_index = tf.where(mask)[:, 0]

          num_mask = tf.shape(mask_index)[0]
          if num_mask == 0:
            continue

          # Accepted samples for this batch element, shape [num_mask, ...].
          good_samples = tf.nest.map_structure(
              lambda t: tf.gather(t, mask_index, axis=1)[0, ...], samples)

          for sample_idx in range(num_mask):
            if k >= num_samples:
              break
            valid_samples = tf.nest.map_structure(
                lambda gs, vs: vs.write(k, gs[sample_idx:sample_idx+1, ...]),
                good_samples, valid_samples)
            k += 1

        if k < num_samples:
          def sample_zero_and_one_hot(spec):
            """Filler sample: one-hot for integer specs, zeros otherwise."""
            if spec.dtype.is_integer:
              sample = tf.one_hot(
                  one_hot_index, self._num_mutually_exclusive_actions)
            else:
              sample = tf.zeros(spec.shape, spec.dtype)

            sample = tf.broadcast_to(
                sample,
                tf.TensorShape([num_samples] + sample.shape.dims))
            return tf.cast(sample, spec.dtype)

          zero_samples = tf.nest.map_structure(
              sample_zero_and_one_hot, self._action_spec)
          for sample_idx in range(num_samples - k):
            valid_samples = tf.nest.map_structure(
                lambda gs, vs: vs.write(k, gs[sample_idx:sample_idx+1, ...]),
                zero_samples, valid_samples)
            # Advance the write slot. Without this, every filler lands in the
            # same slot and the remaining slots never get the filler value.
            k += 1

        valid_samples = tf.nest.map_structure(lambda vs: vs.concat(),
                                              valid_samples)

        valid_batch_samples = tf.nest.map_structure(
            lambda vbs, vs: vbs.write(b_indx, vs), valid_batch_samples,
            valid_samples)

      samples_continuous = tf.nest.map_structure(
          lambda a: a.stack(), valid_batch_samples)
      return samples_continuous

    if self._sample_rejecters[i]:
      samples_continuous = rejection_sampling(self._sample_rejecters[i])
      def set_b_n_shape(t):
        # The tf.function output loses the static sample dim; restore it.
        t.set_shape(tf.TensorShape([None, num_samples] + t.shape[2:].dims))

      tf.nest.map_structure(set_b_n_shape, samples_continuous)
    else:
      samples_continuous = sample_fn(mean, var, state)

    return samples_continuous