in tensorflow_probability/python/distributions/joint_distribution.py [0:0]
def _execute_model(self,
                   sample_shape=(),
                   seed=None,
                   value=None,
                   stop_index=None,
                   sample_and_trace_fn=trace_distributions_and_values):
  """Executes `model`, creating both samples and distributions.

  Drives the coroutine returned by `self._model_coroutine()`. For every
  distribution it yields (unwrapping `Root` if present), this calls
  `sample_and_trace_fn` and sends the resulting `next_value` back into the
  coroutine. The traced output of each call is collected.

  Seeds: a stateful seed (as judged by `samplers.is_stateful_seed`) is wrapped
  in a `SeedStream`. When `self._stateful_to_stateless` is set, a stateless
  seed is also derived via `samplers.sanitize_seed`, unless the seed is `None`
  under JAX. A fresh stateless seed is split off for every distribution, even
  when a value is supplied, so results stay reproducible. The stateless seed
  is preferred over the stateful one when both exist.

  Args:
    sample_shape: Sample shape passed to `sample_and_trace_fn` for
      distributions wrapped in `Root`; non-root distributions receive `()`.
    seed: Stateful or stateless seed used to derive per-distribution seeds.
    value: Optional indexable sequence of per-distribution values. Entry `i`
      is passed as `value=` for the `i`-th yielded distribution. Entries that
      are `None`, or indices beyond `len(value)`, pass `value=None`.
    stop_index: If not `None`, stop once this many distributions have been
      processed. The check runs after each distribution, so values `<= 0`
      never truncate the run.
    sample_and_trace_fn: Callable taking `(distribution, sample_shape=,
      seed=, value=)` and returning `(next_value, traced_values)`.
      `next_value` is sent back into the coroutine; `traced_values` is
      collected into the result.

  Returns:
    A list with one `traced_values` entry per processed distribution. The list
    is empty if the model yields no distributions.

  Raises:
    ValueError: If `self._require_root` is set, `sample_shape` may be
      nontrivial, and the first yielded distribution is not wrapped in `Root`.
    TypeError: Re-raised from `sample_and_trace_fn`, unless it signals that a
      stateless seed is unsupported and a stateful fallback seed is available.
  """
  values_out = []
  if samplers.is_stateful_seed(seed):
    seed_stream = SeedStream(seed, salt='JointDistribution')
    if not self._stateful_to_stateless:
      seed = None
  else:
    seed_stream = None  # We got a stateless seed for seed=.

  # TODO(b/166658748): Make _stateful_to_stateless always True (eliminate it).
  if self._stateful_to_stateless and (seed is not None or not JAX_MODE):
    seed = samplers.sanitize_seed(seed, salt='JointDistribution')

  gen = self._model_coroutine()
  index = 0
  try:
    d = next(gen)
  except StopIteration:
    # The model yields no distributions: there is nothing to execute. Do not
    # let a bare StopIteration escape this ordinary function, where it could
    # silently terminate an enclosing iterator.
    return values_out

  if self._require_root:
    if distribution_util.shape_may_be_nontrivial(
        sample_shape) and not isinstance(d, self.Root):
      raise ValueError('First distribution yielded by coroutine must '
                       'be wrapped in `Root` when requesting a nontrivial '
                       f'sample_shape = {sample_shape}.')

  while True:
    is_root = isinstance(d, self.Root)
    actual_distribution = d.distribution if is_root else d
    # Only root distributions are sampled with the requested sample_shape;
    # downstream ones inherit batch structure from their parents' values.
    dist_sample_shape = sample_shape if is_root else ()

    # Ensure reproducibility even when xs are (partially) set. Always split.
    stateful_sample_seed = None if seed_stream is None else seed_stream()
    if seed is None:
      stateless_sample_seed = None
    else:
      stateless_sample_seed, seed = samplers.split_seed(seed)

    value_at_index = None
    if (value is not None and len(value) > index and
        value[index] is not None):
      value_at_index = value[index]

    try:
      next_value, traced_values = sample_and_trace_fn(
          actual_distribution,
          sample_shape=dist_sample_shape,
          seed=(stateful_sample_seed if stateless_sample_seed is None
                else stateless_sample_seed),
          value=value_at_index)
    except TypeError as e:
      # Only fall back when the error looks like "stateless seed unsupported"
      # AND there is a stateful seed to fall back to; otherwise propagate.
      if ('Expected int for argument' not in str(e) and
          TENSOR_SEED_MSG_PREFIX not in str(e)) or (
              stateful_sample_seed is None):
        raise
      msg = (
          'Falling back to stateful sampling for distribution #{index} '
          '(0-based) of type `{dist_cls}` with component name '
          '{component_name} and `dist.name` "{dist_name}". Please '
          'update to use `tf.random.stateless_*` RNGs. This fallback may '
          'be removed after 20-Dec-2020. ({exc})')
      component_name = get_explicit_name_for_component(actual_distribution)
      if component_name is None:
        component_name = '[None specified]'
      else:
        component_name = '"{}"'.format(component_name)
      warnings.warn(msg.format(
          index=index,
          component_name=component_name,
          dist_name=actual_distribution.name,
          dist_cls=type(actual_distribution),
          exc=str(e)))
      next_value, traced_values = sample_and_trace_fn(
          actual_distribution,
          sample_shape=dist_sample_shape,
          seed=stateful_sample_seed,
          value=value_at_index)

    if self._validate_args:
      # Tie the traced tensors to the shape assertions via identity ops so the
      # checks run whenever the outputs are used.
      with tf.control_dependencies(
          itertools.chain.from_iterable(
              self._assert_compatible_shape(index, sample_shape, value_part)
              for value_part in tf.nest.flatten(next_value))):
        values_out.append(
            tf.nest.map_structure(
                lambda x: tf.identity(x) if tf.is_tensor(x) else x,
                traced_values))
    else:
      values_out.append(traced_values)

    index += 1
    if stop_index is not None and index == stop_index:
      break
    # Only the coroutine itself may end the loop via StopIteration. Errors
    # raised while sampling or validating above must propagate, not silently
    # truncate the result.
    try:
      d = gen.send(next_value)
    except StopIteration:
      break
  return values_out