def _build_sampler_loop_body()

in tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler.py [0:0]


def _build_sampler_loop_body(model,
                             observed_time_series,
                             is_missing=None):
  """Builds a Gibbs sampler for the given model and observed data.

  Args:
    model: A `tfp.sts.StructuralTimeSeries` model instance. This must be of the
      form constructed by `build_model_for_gibbs_sampling`: its first component
      is a `tfp.sts.LocalLevel` or `tfp.sts.LocalLinearTrend`, its second
      component is a `tfp.sts.LinearRegression` or
      `SpikeAndSlabSparseLinearRegression`, and its first parameter is the
      observation noise scale.
    observed_time_series: Float `Tensor` time series of shape
      `[..., num_timesteps]`.
    is_missing: Optional `bool` `Tensor` of shape `[..., num_timesteps]`. A
      `True` value indicates that the observation for that timestep is missing.
  Returns:
    sampler_loop_body: Python callable that performs a single cycle of Gibbs
      sampling. Its first argument is a `GibbsSamplerState`, and it returns a
      new `GibbsSamplerState`. The second argument (passed by `tf.scan`) is
      ignored.
  Raises:
    ValueError: if the model's first component is not a level/trend component,
      its second component is not a supported regression component, or its
      first parameter is not the observation noise.
  """
  level_component = model.components[0]
  if not isinstance(level_component, (sts.LocalLevel, sts.LocalLinearTrend)):
    raise ValueError('Expected the first model component to be an instance of '
                     '`tfp.sts.LocalLevel` or `tfp.sts.LocalLinearTrend`; '
                     'instead saw {}'.format(level_component))
  model_has_slope = isinstance(level_component, sts.LocalLinearTrend)

  regression_component = model.components[1]
  if not isinstance(regression_component,
                    (sts.LinearRegression, SpikeAndSlabSparseLinearRegression)):
    raise ValueError('Expected the second model component to be an instance of '
                     '`tfp.sts.LinearRegression` or '
                     '`SpikeAndSlabSparseLinearRegression`; '
                     'instead saw {}'.format(regression_component))
  model_has_spike_slab_regression = isinstance(
      regression_component, SpikeAndSlabSparseLinearRegression)

  if is_missing is not None:  # Ensure series does not contain NaNs.
    observed_time_series = tf.where(is_missing,
                                    tf.zeros_like(observed_time_series),
                                    observed_time_series)

  num_observed_steps = prefer_static.shape(observed_time_series)[-1]
  # Trim the design matrix to the observed window (it may also cover
  # forecast steps).
  design_matrix = _get_design_matrix(model).to_dense()[:num_observed_steps]
  if is_missing is not None:
    # Replace design matrix with zeros at unobserved timesteps. This ensures
    # they will not affect the posterior on weights.
    design_matrix = tf.where(is_missing[..., tf.newaxis],
                             tf.zeros_like(design_matrix),
                             design_matrix)

  # Untransform scale priors -> variance priors by reaching thru Sqrt bijector.
  observation_noise_param = model.parameters[0]
  if 'observation_noise' not in observation_noise_param.name:
    raise ValueError('Model parameters {} do not match the expected sampler '
                     'state.'.format(model.parameters))
  observation_noise_variance_prior = observation_noise_param.prior.distribution
  if model_has_slope:
    level_scale_variance_prior, slope_scale_variance_prior = [
        p.prior.distribution for p in level_component.parameters]
  else:
    level_scale_variance_prior = (
        level_component.parameters[0].prior.distribution)

  if model_has_spike_slab_regression:
    spike_and_slab_sampler = spike_and_slab.SpikeSlabSampler(
        design_matrix,
        weights_prior_precision=regression_component._weights_prior_precision,  # pylint: disable=protected-access
        nonzero_prior_prob=regression_component._sparse_weights_nonzero_prob,  # pylint: disable=protected-access
        observation_noise_variance_prior_concentration=(
            observation_noise_variance_prior.concentration),
        observation_noise_variance_prior_scale=(
            observation_noise_variance_prior.scale),
        observation_noise_variance_upper_bound=(
            observation_noise_variance_prior.upper_bound
            if hasattr(observation_noise_variance_prior, 'upper_bound')
            else None))
  else:
    weights_prior_scale = (
        regression_component.parameters[0].prior.scale)

  def sampler_loop_body(previous_sample, _):
    """Runs one sampler iteration, resampling all model variables."""

    (weights_seed,
     level_seed,
     observation_noise_scale_seed,
     level_scale_seed,
     loop_seed) = samplers.split_seed(
         previous_sample.seed, n=5, salt='sampler_loop_body')
    # Preserve backward-compatible seed behavior by splitting slope separately.
    slope_scale_seed, = samplers.split_seed(
        previous_sample.seed, n=1, salt='sampler_loop_body_slope')

    # Regression targets: the series with the current level removed. At
    # missing timesteps the observation is a placeholder zero, so the raw
    # difference would be a fabricated residual of `-level`. Zeroing the
    # design matrix keeps those steps out of the weights fit, but the
    # spike-and-slab sampler also draws the observation noise variance from
    # these targets, so zero them as well.
    # NOTE(review): `SpikeSlabSampler` may still count missing rows toward the
    # noise-variance posterior's degrees of freedom; confirm.
    regression_targets = observed_time_series - previous_sample.level
    if is_missing is not None:
      regression_targets = tf.where(is_missing,
                                    tf.zeros_like(regression_targets),
                                    regression_targets)

    # We encourage a reasonable initialization by sampling the weights first,
    # so at the first step they are regressed directly against the observed
    # time series. If we instead sampled the level first it might 'explain away'
    # some observed variation that we would ultimately prefer to explain through
    # the regression weights, because the level can represent arbitrary
    # variation, while the weights are limited to representing variation in the
    # subspace given by the design matrix.
    if model_has_spike_slab_regression:
      # Noise variance and weights are drawn jointly here, so the noise scale
      # is not resampled again below.
      (observation_noise_variance,
       weights) = spike_and_slab_sampler.sample_noise_variance_and_weights(
           initial_nonzeros=tf.not_equal(previous_sample.weights, 0.),
           targets=regression_targets,
           seed=weights_seed)
      observation_noise_scale = tf.sqrt(observation_noise_variance)
    else:
      weights = _resample_weights(
          design_matrix=design_matrix,
          target_residuals=regression_targets,
          observation_noise_scale=previous_sample.observation_noise_scale,
          weights_prior_scale=weights_prior_scale,
          seed=weights_seed)
      # Noise scale will be resampled below.
      observation_noise_scale = previous_sample.observation_noise_scale

    regression_residuals = observed_time_series - tf.linalg.matvec(
        design_matrix, weights)
    latents = _resample_latents(
        observed_residuals=regression_residuals,
        level_scale=previous_sample.level_scale,
        slope_scale=previous_sample.slope_scale if model_has_slope else None,
        observation_noise_scale=observation_noise_scale,
        initial_state_prior=level_component.initial_state_prior,
        is_missing=is_missing,
        seed=level_seed)
    level = latents[..., 0]
    level_residuals = level[..., 1:] - level[..., :-1]
    if model_has_slope:
      # Under a local linear trend, level[t+1] = level[t] + slope[t] + noise,
      # so the level innovation excludes the slope contribution.
      slope = latents[..., 1]
      level_residuals -= slope[..., :-1]
      slope_residuals = slope[..., 1:] - slope[..., :-1]

    # Estimate level scale from the empirical changes in level.
    level_scale = _resample_scale(
        prior=level_scale_variance_prior,
        observed_residuals=level_residuals,
        is_missing=None,
        seed=level_scale_seed)
    if model_has_slope:
      slope_scale = _resample_scale(
          prior=slope_scale_variance_prior,
          observed_residuals=slope_residuals,
          is_missing=None,
          seed=slope_scale_seed)
    if not model_has_spike_slab_regression:
      # Estimate noise scale from the residuals.
      observation_noise_scale = _resample_scale(
          prior=observation_noise_variance_prior,
          observed_residuals=regression_residuals - level,
          is_missing=is_missing,
          seed=observation_noise_scale_seed)

    return GibbsSamplerState(
        observation_noise_scale=observation_noise_scale,
        level_scale=level_scale,
        slope_scale=(slope_scale if model_has_slope
                     else previous_sample.slope_scale),
        weights=weights,
        level=level,
        slope=(slope if model_has_slope
               else previous_sample.slope),
        seed=loop_seed)
  return sampler_loop_body