def _execute_model()

in tensorflow_probability/python/distributions/joint_distribution.py [0:0]


  def _execute_model(self,
                     sample_shape=(),
                     seed=None,
                     value=None,
                     stop_index=None,
                     sample_and_trace_fn=trace_distributions_and_values):
    """Runs the model coroutine, sampling and tracing each distribution.

    Steps through `self._model_coroutine()`: every item it yields is handed to
    `sample_and_trace_fn`, and the first element of that call's result is sent
    back into the coroutine so later distributions can condition on it.

    Args:
      sample_shape: Sample shape applied to items the coroutine wraps in
        `self.Root`; all other distributions are called with `()`.
      seed: A stateful seed (which feeds a `SeedStream`), a stateless seed, or
        `None`. When `self._stateful_to_stateless` is set, the seed is
        sanitized to a stateless one (except for `None` under `JAX_MODE`) and
        split once per step.
      value: Optional sized, indexable collection of per-position values.
        Positions past its length, or holding `None`, are passed on as
        `value=None`.
      stop_index: If not `None`, iteration ends once the number of processed
        distributions equals it. The equality check runs after each step, so
        a value below 1 never triggers an early stop.
      sample_and_trace_fn: Callable invoked as
        `fn(dist, sample_shape=..., seed=..., value=...)` that returns a pair
        `(next_value, traced_values)`.

    Returns:
      A list holding the `traced_values` of every processed distribution, in
      the order the coroutine yielded them.

    Raises:
      ValueError: If `self._require_root` is set, `sample_shape` may be
        nontrivial, and the first yielded item is not a `self.Root`.
      StopIteration: Propagated from the initial `next()` if the coroutine
        yields nothing at all.
    """
    seed_stream = None  # Not a stateful seed: no stream to draw from.
    if samplers.is_stateful_seed(seed):
      seed_stream = SeedStream(seed, salt='JointDistribution')
      if not self._stateful_to_stateless:
        # Legacy mode: sample purely from the stream, never split statelessly.
        seed = None

    # TODO(b/166658748): Make _stateful_to_stateless always True (eliminate it).
    if self._stateful_to_stateless and not (seed is None and JAX_MODE):
      seed = samplers.sanitize_seed(seed, salt='JointDistribution')

    traces = []
    model_gen = self._model_coroutine()
    position = 0
    yielded = next(model_gen)
    if (self._require_root and
        distribution_util.shape_may_be_nontrivial(sample_shape) and
        not isinstance(yielded, self.Root)):
      raise ValueError('First distribution yielded by coroutine must '
                       'be wrapped in `Root` when requesting a nontrivial '
                       f'sample_shape = {sample_shape}.')

    def _identity_if_tensor(x):
      # Re-emit tensors so they pick up the enclosing control dependencies.
      return tf.identity(x) if tf.is_tensor(x) else x

    fallback_template = (
        'Falling back to stateful sampling for distribution #{index} '
        '(0-based) of type `{dist_cls}` with component name '
        '{component_name} and `dist.name` "{dist_name}". Please '
        'update to use `tf.random.stateless_*` RNGs. This fallback may '
        'be removed after 20-Dec-2020. ({exc})')

    try:
      while True:
        is_root = isinstance(yielded, self.Root)
        dist = yielded.distribution if is_root else yielded
        dist_sample_shape = sample_shape if is_root else ()

        # Draw/split seeds on every step, even for positions pinned by
        # `value`, so later positions see the same seeds regardless of which
        # values were supplied.
        stateful_sample_seed = None if seed_stream is None else seed_stream()
        if seed is not None:
          stateless_sample_seed, seed = samplers.split_seed(seed)
        else:
          stateless_sample_seed = None
        preferred_seed = (stateful_sample_seed if stateless_sample_seed is None
                          else stateless_sample_seed)

        pinned = None
        if value is not None and position < len(value):
          pinned = value[position]

        try:
          next_value, traced_values = sample_and_trace_fn(
              dist,
              sample_shape=dist_sample_shape,
              seed=preferred_seed,
              value=pinned)
        except TypeError as e:
          err_text = str(e)
          looks_like_seed_error = ('Expected int for argument' in err_text or
                                   TENSOR_SEED_MSG_PREFIX in err_text)
          # Only retry when the failure looks seed-related AND a stateful
          # seed exists to retry with; anything else is a genuine error.
          if not looks_like_seed_error or stateful_sample_seed is None:
            raise
          explicit_name = get_explicit_name_for_component(dist)
          shown_name = ('[None specified]' if explicit_name is None
                        else '"{}"'.format(explicit_name))
          warnings.warn(fallback_template.format(
              index=position,
              component_name=shown_name,
              dist_name=dist.name,
              dist_cls=type(dist),
              exc=err_text))
          next_value, traced_values = sample_and_trace_fn(
              dist,
              sample_shape=dist_sample_shape,
              seed=stateful_sample_seed,
              value=pinned)

        if self._validate_args:
          # Built lazily; `tf.control_dependencies` consumes the chain.
          shape_assertions = itertools.chain.from_iterable(
              self._assert_compatible_shape(position, sample_shape, part)
              for part in tf.nest.flatten(next_value))
          with tf.control_dependencies(shape_assertions):
            traced_values = tf.nest.map_structure(_identity_if_tensor,
                                                  traced_values)
        traces.append(traced_values)

        position += 1
        if stop_index is not None and position == stop_index:
          break
        yielded = model_gen.send(next_value)
    except StopIteration:
      # The coroutine finished: every distribution has been processed.
      pass
    return traces