in tensorflow_probability/python/sts/components/sum.py [0:0]
def __init__(self,
             component_ssms,
             constant_offset=0.,
             observation_noise_scale=None,
             initial_state_prior=None,
             initial_step=0,
             validate_args=False,
             name=None,
             **linear_gaussian_ssm_kwargs):
  """Construct the state space model for a sum of component models.

  Args:
    component_ssms: Python `list` of one or more
      `tfd.LinearGaussianStateSpaceModel` instances. Components may
      implement different time-series models and may differ in
      `latent_size`, but they must share a `dtype` and event shape
      (`num_timesteps` and `observation_size`), and their batch shapes must
      broadcast to a compatible batch shape.
    constant_offset: `float` `Tensor` of shape broadcasting to
      `concat([batch_shape, [num_timesteps]])`, giving a constant value that
      is added to the summed output of the components. The components can
      therefore model the shifted series
      `observed_time_series - constant_offset`.
      Default value: `0.`
    observation_noise_scale: Optional scalar `float` `Tensor` giving the
      standard deviation of the observation noise. It may carry extra batch
      dimensions, which must broadcast with the batch shape of the elements
      of `component_ssms`. When given, the observation noise scales of the
      component models are ignored. When `None`, the scale is obtained by
      summing the component noise variances, i.e.,
      `observation_noise_scale = sqrt(sum(
        [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
      Default value: `None`.
    initial_state_prior: Optional instance of `tfd.MultivariateNormal`
      giving the prior on the latent state at time `initial_step`. When
      `None`, the independent priors of the components are used, i.e.,
      `[component.initial_state_prior for component in component_ssms]`.
      Default value: `None`.
    initial_step: Optional scalar `int` `Tensor` specifying the starting
      timestep.
      Default value: 0.
    validate_args: Python `bool`. Whether to validate input with asserts.
      If `validate_args` is `False` and the inputs are invalid, correct
      behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this class.
      Default value: "AdditiveStateSpaceModel".
    **linear_gaussian_ssm_kwargs: Optional additional keyword arguments
      forwarded to the base `tfd.LinearGaussianStateSpaceModel`
      constructor.

  Raises:
    ValueError: if the statically known `num_timesteps` values of the
      components are not all equal.
  """
  # Snapshot the constructor arguments before any other local is bound, and
  # flatten the extra keyword arguments into the same dict.
  parameters = dict(locals())
  parameters.update(linear_gaussian_ssm_kwargs)
  del parameters['linear_gaussian_ssm_kwargs']

  with tf.name_scope(name or 'AdditiveStateSpaceModel') as name:
    # Every component has to use one common floating-point dtype.
    dtype = tf.debugging.assert_same_float_dtype(component_ssms)

    # Give scalar offsets the canonical shape `[..., num_timesteps]` by
    # broadcasting against a length-1 vector of ones.
    offset_tensor = tf.convert_to_tensor(
        value=constant_offset, name='constant_offset', dtype=dtype)
    constant_offset = offset_tensor * tf.ones([1], dtype=dtype)
    offset_length = ps.shape(constant_offset)[-1]

    # Without a user-supplied prior, combine the component priors into a
    # single joint (block-diagonal) prior.
    if initial_state_prior is None:
      component_priors = [ssm.initial_state_prior for ssm in component_ssms]
      initial_state_prior = sts_util.factored_joint_mvn(component_priors)
    dtype = initial_state_prior.dtype

    # Collect the `num_timesteps` values that are known statically.
    maybe_static_lengths = (
        tf.get_static_value(ssm.num_timesteps) for ssm in component_ssms)
    known_lengths = [
        length for length in maybe_static_lengths if length is not None]

    # A statically known length (if there is one) becomes the length of the
    # additive model; all other statically known lengths must agree with it.
    if known_lengths:
      num_timesteps = known_lengths[0]
      if any(length != num_timesteps for length in known_lengths):
        raise ValueError('Additive model components must all have the same '
                         'number of timesteps '
                         '(saw: {})'.format(known_lengths))
    else:
      num_timesteps = component_ssms[0].num_timesteps

    # If some length could not be checked statically, fall back to runtime
    # assertions (only when validation was requested).
    assertions = []
    if validate_args and len(known_lengths) != len(component_ssms):
      assertions.extend(
          tf.debugging.assert_equal(
              num_timesteps,
              ssm.num_timesteps,
              message='Additive model components must all have '
              'the same number of timesteps.')
          for ssm in component_ssms)

    # The transition and observation models are supplied as callables of the
    # timestep so that components with time-varying dynamics are supported.
    def transition_matrix_fn(t):
      component_matrices = [
          ssm.get_transition_matrix_for_timestep(t) for ssm in component_ssms]
      return tfl.LinearOperatorBlockDiag(component_matrices)

    def transition_noise_fn(t):
      component_noises = [
          ssm.get_transition_noise_for_timestep(t) for ssm in component_ssms]
      return sts_util.factored_joint_mvn(component_noises)

    # A `[..., 1, 1]` tensor of ones carrying the broadcast batch shape of the
    # component observation matrices; multiplying by it broadcasts each
    # component's matrix before concatenation.
    initial_obs_matrices = [
        ssm.get_observation_matrix_for_timestep(initial_step)
        for ssm in component_ssms]
    batch_shape = ps.cast(
        sts_util.broadcast_batch_shape(initial_obs_matrices), dtype=tf.int32)
    batch_ones = tf.ones(
        ps.concat([batch_shape, [1, 1]], axis=0), dtype=dtype)
    # Attach any dynamic assertions generated above to this tensor, which
    # every observation matrix is built from.
    if assertions:
      with tf.control_dependencies(assertions):
        batch_ones = tf.identity(batch_ones)

    def observation_matrix_fn(t):
      dense_blocks = [
          ssm.get_observation_matrix_for_timestep(t).to_dense() * batch_ones
          for ssm in component_ssms]
      return tfl.LinearOperatorFullMatrix(tf.concat(dense_blocks, axis=-1))

    def offset_at_step(t):
      # A length-1 offset applies to every timestep as-is; otherwise pick
      # the entry for step `t`, clamping `t` to the last available index.
      if offset_length == 1:
        return constant_offset
      clamped_step = tf.minimum(t, offset_length - 1)
      return tf.gather(
          constant_offset, clamped_step, axis=-1)[..., tf.newaxis]

    if observation_noise_scale is None:
      # No explicit scale: sum the component noise distributions together
      # with a zero-scale distribution located at the offset.
      def observation_noise_fn(t):
        offset = offset_at_step(t)
        offset_mvn = tfd.MultivariateNormalDiag(
            loc=offset, scale_diag=tf.zeros_like(offset))
        component_noises = [
            ssm.get_observation_noise_for_timestep(t)
            for ssm in component_ssms]
        return sts_util.sum_mvns([offset_mvn] + component_noises)
    else:
      # Explicit scale: keep only the means of the component noise and use
      # the supplied scale in place of the component scales.
      observation_noise_scale = tf.convert_to_tensor(
          value=observation_noise_scale,
          name='observation_noise_scale',
          dtype=dtype)

      def observation_noise_fn(t):
        component_means = [
            ssm.get_observation_noise_for_timestep(t).mean()
            for ssm in component_ssms]
        return tfd.MultivariateNormalDiag(
            loc=(sum(component_means) + offset_at_step(t)),
            scale_diag=observation_noise_scale[..., tf.newaxis])

    # NOTE: the explicit two-argument `super` is kept so that no implicit
    # `__class__` cell shows up in the `locals()` snapshot taken above.
    super(AdditiveStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=transition_matrix_fn,
        transition_noise=transition_noise_fn,
        observation_matrix=observation_matrix_fn,
        observation_noise=observation_noise_fn,
        initial_state_prior=initial_state_prior,
        initial_step=initial_step,
        validate_args=validate_args,
        name=name,
        **linear_gaussian_ssm_kwargs)
    self._parameters = parameters