in tf_agents/policies/samplers/cem_actions_sampler_continuous_and_one_hot.py [0:0]
def _sample_continuous_and_transpose(
    self, mean, var, state, i, one_hot_index):
  """Draws `self._number_samples_all[i]` action samples per batch entry.

  Float-typed action components are sampled from `Normal(mean, sqrt(var))`,
  transposed so the sample axis comes second, and multiplied by the matching
  entry of `self._masks[i]`. Integer-typed components are not sampled; they
  are filled with `tf.one_hot(one_hot_index,
  self._num_mutually_exclusive_actions)` broadcast over batch and samples.
  The result is passed through every clipper in `self._sample_clippers[i]`
  and finally through `common.clip_to_spec`.

  If `self._sample_rejecters[i]` is truthy, samples are instead produced one
  batch entry at a time by rejection sampling: candidates are redrawn up to
  `self._max_rejection_iterations` times, only candidates the rejecter's mask
  marks as true are kept, and any slots still empty afterwards are padded
  with zeros (float components) or the one-hot vector (integer components).

  Args:
    mean: Nest of mean tensors, structured like `self._action_spec`. Each
      leaf is indexed as rank 2 (`[B, A]`).
    var: Nest of variance tensors with the same structure as `mean`.
    state: Nest of tensors forwarded to the clippers and the rejecter, or
      `None`. In the rejection path each leaf is gathered along axis 0.
    i: Index selecting the sample count, masks, clippers and rejecter.
    one_hot_index: Index handed to `tf.one_hot` for integer components.

  Returns:
    Nest structured like `self._action_spec`; each leaf is cast to its spec
    dtype and laid out as `[B, N, ...]` with `N = self._number_samples_all[i]`.
  """
  num_samples = self._number_samples_all[i]

  def sample_and_transpose(mean, var, spec, mask):
    """Samples a single action component according to its spec dtype."""
    if spec.dtype.is_integer:
      # Integer components carry the fixed one-hot choice; `mean` is only
      # consulted for the batch and action dimensions of the output.
      sample = tf.one_hot(
          one_hot_index, self._num_mutually_exclusive_actions)
      sample = tf.broadcast_to(
          sample,
          [tf.shape(mean)[0],
           tf.constant(num_samples),  # pylint: disable=cell-var-from-loop
           tf.shape(mean)[1]])
    else:
      dist = tfp.distributions.Normal(loc=mean, scale=tf.sqrt(var))
      # Transpose to [B, N, A]
      sample = tf.transpose(
          dist.sample(num_samples), [1, 0, 2])  # pylint: disable=cell-var-from-loop
      sample = sample * mask

    return tf.cast(sample, spec.dtype)

  batch_size = tf.shape(tf.nest.flatten(mean)[0])[0]

  def sample_fn(mean_sample, var_sample, state_sample):
    """Samples every action component, then applies clippers and spec bounds."""
    # [B, N, A]
    samples_continuous = tf.nest.map_structure(sample_and_transpose,
                                               mean_sample, var_sample,
                                               self._action_spec,
                                               self._masks[i])

    if self._sample_clippers[i]:
      for sample_clipper in self._sample_clippers[i]:
        samples_continuous = sample_clipper(samples_continuous, state_sample)

    samples_continuous = tf.nest.map_structure(
        common.clip_to_spec, samples_continuous, self._action_spec)
    return samples_continuous

  @tf.function
  def rejection_sampling(sample_rejector):
    """Collects `num_samples` accepted samples for each batch entry."""
    # One slot per batch entry; each slot receives that entry's samples.
    valid_batch_samples = tf.nest.map_structure(
        lambda spec: tf.TensorArray(spec.dtype, size=batch_size),
        self._action_spec)

    for b_indx in tf.range(batch_size):
      # `k` is the next free slot in `valid_samples` for this batch entry.
      k = tf.constant(0)
      # pylint: disable=cell-var-from-loop
      valid_samples = tf.nest.map_structure(
          lambda spec: tf.TensorArray(spec.dtype, size=num_samples),
          self._action_spec)

      count = tf.constant(0)
      while count < self._max_rejection_iterations:
        count += 1
        # Slice out the current batch entry but keep a leading axis of 1 so
        # `sample_fn` still sees batched inputs.
        mean_sample = tf.nest.map_structure(
            lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), mean)
        var_sample = tf.nest.map_structure(
            lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), var)
        if state is not None:
          state_sample = tf.nest.map_structure(
              lambda t: tf.expand_dims(tf.gather(t, b_indx), axis=0), state)
        else:
          state_sample = None

        samples = sample_fn(mean_sample, var_sample, state_sample)  # n, a

        mask = sample_rejector(samples, state_sample)

        # Drop the singleton batch axis, then find the accepted sample ids.
        mask = mask[0, ...]
        mask_index = tf.where(mask)[:, 0]

        num_mask = tf.shape(mask_index)[0]
        if num_mask == 0:
          continue

        good_samples = tf.nest.map_structure(
            lambda t: tf.gather(t, mask_index, axis=1)[0, ...], samples)

        for sample_idx in range(num_mask):
          if k >= num_samples:
            break
          valid_samples = tf.nest.map_structure(
              lambda gs, vs: vs.write(k, gs[sample_idx:sample_idx+1, ...]),
              good_samples, valid_samples)
          k += 1

      if k < num_samples:
        # The rejection budget ran out before every slot was filled; pad the
        # remaining slots with a fallback sample.
        def sample_zero_and_one_hot(spec):
          """Builds fallback rows: one-hot for integer specs, zeros otherwise."""
          if spec.dtype.is_integer:
            sample = tf.one_hot(
                one_hot_index, self._num_mutually_exclusive_actions)
          else:
            sample = tf.zeros(spec.shape, spec.dtype)

          sample = tf.broadcast_to(
              sample,
              tf.TensorShape([num_samples] + sample.shape.dims))
          return tf.cast(sample, spec.dtype)

        zero_samples = tf.nest.map_structure(
            sample_zero_and_one_hot, self._action_spec)
        for sample_idx in range(num_samples-k):
          valid_samples = tf.nest.map_structure(
              lambda gs, vs: vs.write(k, gs[sample_idx:sample_idx+1, ...]),
              zero_samples, valid_samples)
          # Advance the write cursor so each padding row lands in its own
          # slot; without this every write would target the same index and
          # the remaining slots would stay unwritten.
          k += 1

      valid_samples = tf.nest.map_structure(lambda vs: vs.concat(),
                                            valid_samples)

      valid_batch_samples = tf.nest.map_structure(
          lambda vbs, vs: vbs.write(b_indx, vs), valid_batch_samples,
          valid_samples)

    samples_continuous = tf.nest.map_structure(
        lambda a: a.stack(), valid_batch_samples)
    return samples_continuous

  if self._sample_rejecters[i]:
    samples_continuous = rejection_sampling(self._sample_rejecters[i])

    def set_b_n_shape(t):
      # Pin the static sample dimension to `num_samples` on the stacked
      # result; the batch dimension is left unknown.
      t.set_shape(tf.TensorShape([None, num_samples] + t.shape[2:].dims))

    tf.nest.map_structure(set_b_n_shape, samples_continuous)
  else:
    samples_continuous = sample_fn(mean, var, state)

  return samples_continuous