in tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler.py [0:0]
def _build_sampler_loop_body(model,
                             observed_time_series,
                             is_missing=None):
  """Constructs the per-iteration Gibbs update for a structural time series.

  The returned closure performs one sweep of Gibbs sampling, in this order:
  1. The regression weights. For spike-and-slab regression, the observation
     noise variance is drawn jointly with the weights.
  2. The latent level, plus the slope when the trend is a local linear trend.
  3. The level and slope drift scales.
  4. The observation noise scale, for plain linear regression only.

  Args:
    model: A `tf.sts.StructuralTimeSeries` model instance. This must be of the
      form constructed by `build_model_for_gibbs_sampling`:
      - its first component is a `tfp.sts.LocalLevel` or
        `tfp.sts.LocalLinearTrend`;
      - its second component is a `tfp.sts.LinearRegression` or
        `SpikeAndSlabSparseLinearRegression`;
      - the name of its first parameter contains `'observation_noise'`.
    observed_time_series: Float `Tensor` time series of shape
      `[..., num_timesteps]`.
    is_missing: Optional `bool` `Tensor` of shape `[..., num_timesteps]`. A
      `True` value indicates that the observation for that timestep is missing.

  Returns:
    sampler_loop_body: Python callable that performs a single cycle of Gibbs
      sampling. Its first argument is a `GibbsSamplerState`, and it returns a
      new `GibbsSamplerState`. The second argument (passed by `tf.scan`) is
      ignored.

  Raises:
    ValueError: if the model's components or its first parameter do not match
      the layout described above.
  """
  trend = model.components[0]
  if not isinstance(trend, (sts.LocalLevel, sts.LocalLinearTrend)):
    raise ValueError('Expected the first model component to be an instance of '
                     '`tfp.sts.LocalLevel` or `tfp.sts.LocalLinearTrend`; '
                     'instead saw {}'.format(trend))
  has_slope = isinstance(trend, sts.LocalLinearTrend)

  regression = model.components[1]
  if not isinstance(regression, (sts.LinearRegression,
                                 SpikeAndSlabSparseLinearRegression)):
    raise ValueError('Expected the second model component to be an instance of '
                     '`tfp.sts.LinearRegression` or '
                     '`SpikeAndSlabSparseLinearRegression`; '
                     'instead saw {}'.format(regression))
  use_spike_slab = isinstance(regression, SpikeAndSlabSparseLinearRegression)

  has_missing = is_missing is not None
  if has_missing:
    # Overwrite missing entries with zeros so that NaN placeholders cannot
    # leak into the arithmetic below.
    observed_time_series = tf.where(is_missing,
                                    tf.zeros_like(observed_time_series),
                                    observed_time_series)

  num_steps = prefer_static.shape(observed_time_series)[-1]
  design_matrix = _get_design_matrix(model).to_dense()[:num_steps]
  if has_missing:
    # Zeroed design rows at unobserved timesteps keep those steps from
    # influencing the posterior over the regression weights.
    design_matrix = tf.where(is_missing[..., tf.newaxis],
                             tf.zeros_like(design_matrix),
                             design_matrix)

  # The scale priors sit behind a Sqrt bijector; `.distribution` reaches
  # through it to the underlying variance priors.
  noise_param = model.parameters[0]
  if 'observation_noise' not in noise_param.name:
    raise ValueError('Model parameters {} do not match the expected sampler '
                     'state.'.format(model.parameters))
  noise_variance_prior = noise_param.prior.distribution
  if has_slope:
    level_scale_variance_prior, slope_scale_variance_prior = [
        param.prior.distribution for param in trend.parameters]
  else:
    level_scale_variance_prior = trend.parameters[0].prior.distribution

  if use_spike_slab:
    spike_slab_sampler = spike_and_slab.SpikeSlabSampler(
        design_matrix,
        weights_prior_precision=regression._weights_prior_precision,  # pylint: disable=protected-access
        nonzero_prior_prob=regression._sparse_weights_nonzero_prob,  # pylint: disable=protected-access
        observation_noise_variance_prior_concentration=(
            noise_variance_prior.concentration),
        observation_noise_variance_prior_scale=noise_variance_prior.scale,
        # Only some variance priors define an upper bound; default to None.
        observation_noise_variance_upper_bound=getattr(
            noise_variance_prior, 'upper_bound', None))
  else:
    weights_prior_scale = regression.parameters[0].prior.scale

  def sampler_loop_body(previous_sample, _):
    """Performs one full Gibbs sweep, resampling every model variable."""
    (weights_seed, latents_seed, noise_scale_seed, level_scale_seed,
     next_seed) = samplers.split_seed(
         previous_sample.seed, n=5, salt='sampler_loop_body')
    # The slope seed comes from its own split (with a distinct salt) to keep
    # seed behavior backward compatible.
    slope_scale_seed, = samplers.split_seed(
        previous_sample.seed, n=1, salt='sampler_loop_body_slope')

    # Draw the weights before the level. On the very first sweep the weights
    # are then regressed directly against the observed series. If the level
    # went first, it could 'explain away' variation better attributed to the
    # regression: the level can follow arbitrary variation, while the weights
    # only span the subspace given by the design matrix.
    residuals_for_weights = observed_time_series - previous_sample.level
    if use_spike_slab:
      noise_variance, weights = (
          spike_slab_sampler.sample_noise_variance_and_weights(
              initial_nonzeros=tf.not_equal(previous_sample.weights, 0.),
              targets=residuals_for_weights,
              seed=weights_seed))
      noise_scale = tf.sqrt(noise_variance)
    else:
      weights = _resample_weights(
          design_matrix=design_matrix,
          target_residuals=residuals_for_weights,
          observation_noise_scale=previous_sample.observation_noise_scale,
          weights_prior_scale=weights_prior_scale,
          seed=weights_seed)
      # Carried forward for now; redrawn after the latents below.
      noise_scale = previous_sample.observation_noise_scale

    regression_residuals = observed_time_series - tf.linalg.matvec(
        design_matrix, weights)
    latents = _resample_latents(
        observed_residuals=regression_residuals,
        level_scale=previous_sample.level_scale,
        slope_scale=previous_sample.slope_scale if has_slope else None,
        observation_noise_scale=noise_scale,
        initial_state_prior=trend.initial_state_prior,
        is_missing=is_missing,
        seed=latents_seed)
    level = latents[..., 0]
    level_residuals = level[..., 1:] - level[..., :-1]
    if has_slope:
      slope = latents[..., 1]
      # Under a local linear trend each level step includes the previous
      # slope, so remove it to isolate the level innovation.
      level_residuals = level_residuals - slope[..., :-1]
      slope_residuals = slope[..., 1:] - slope[..., :-1]

    # Drift scales are estimated from the empirical latent increments. They
    # are passed `is_missing=None`, so every increment is used.
    level_scale = _resample_scale(
        prior=level_scale_variance_prior,
        observed_residuals=level_residuals,
        is_missing=None,
        seed=level_scale_seed)
    if has_slope:
      slope_scale = _resample_scale(
          prior=slope_scale_variance_prior,
          observed_residuals=slope_residuals,
          is_missing=None,
          seed=slope_scale_seed)
    else:
      slope, slope_scale = previous_sample.slope, previous_sample.slope_scale

    if not use_spike_slab:
      # Redraw the noise scale from what neither the regression nor the level
      # explains. `is_missing` is forwarded to the helper.
      noise_scale = _resample_scale(
          prior=noise_variance_prior,
          observed_residuals=regression_residuals - level,
          is_missing=is_missing,
          seed=noise_scale_seed)

    return GibbsSamplerState(
        observation_noise_scale=noise_scale,
        level_scale=level_scale,
        slope_scale=slope_scale,
        weights=weights,
        level=level,
        slope=slope,
        seed=next_seed)

  return sampler_loop_body