in pyro/distributions/hmm.py [0:0]
def __init__(self, initial_dist, transition_matrix, transition_dist,
             observation_matrix, observation_dist,
             validate_args=None, duration=None):
    """
    Validate, broadcast and store the components of the model.

    The batch shapes of all five components are broadcast together (the
    initial distribution first gets a trailing singleton dimension). The
    rightmost dimension of the broadcast shape is used as the time
    dimension and the remaining leading dimensions as the batch shape.

    :param initial_dist: Distribution with ``has_rsample`` set,
        ``event_dim == 1`` and ``event_shape == (hidden_dim,)``.
    :param torch.Tensor transition_matrix: Tensor with at least two
        dimensions whose trailing shape is ``(hidden_dim, hidden_dim)``.
    :param transition_dist: Distribution with ``has_rsample`` set,
        ``event_dim == 1`` and ``event_shape == (hidden_dim,)``.
    :param torch.Tensor observation_matrix: Tensor with at least two
        dimensions; its trailing shape defines ``(hidden_dim, obs_dim)``.
    :param observation_dist: Distribution with ``has_rsample`` set,
        ``event_dim == 1`` and ``event_shape == (obs_dim,)``. Any
        ``Independent`` / ``TransformedDistribution`` wrappers are peeled
        off; the collected transforms are stored in ``self.transforms``.
    :param validate_args: Forwarded unchanged to the parent constructor.
    :param duration: Forwarded unchanged to the parent constructor.
    :raises AssertionError: If any of the checks above fails.
    """
    # Per-argument checks, evaluated in argument order.
    assert initial_dist.has_rsample
    assert initial_dist.event_dim == 1
    assert isinstance(transition_matrix, torch.Tensor)
    assert transition_matrix.dim() >= 2
    assert transition_dist.has_rsample
    assert transition_dist.event_dim == 1
    assert isinstance(observation_matrix, torch.Tensor)
    assert observation_matrix.dim() >= 2
    assert observation_dist.has_rsample
    assert observation_dist.event_dim == 1

    # The observation matrix determines both sizes; every other
    # component must agree with it.
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    state_event = (hidden_dim,)
    trans_mat_event = (hidden_dim, hidden_dim)
    obs_mat_event = (hidden_dim, obs_dim)
    assert initial_dist.event_shape == state_event
    assert transition_matrix.shape[-2:] == trans_mat_event
    assert transition_dist.event_shape == state_event
    assert observation_dist.event_shape == (obs_dim,)

    # Broadcast all batch shapes; the initial distribution has no time
    # dimension of its own, so it is given a singleton one here.
    full_shape = broadcast_shape(
        initial_dist.batch_shape + (1,),
        transition_matrix.shape[:-2],
        transition_dist.batch_shape,
        observation_matrix.shape[:-2],
        observation_dist.batch_shape,
    )
    batch_shape = full_shape[:-1]
    time_shape = full_shape[-1:]
    super().__init__(duration, batch_shape, time_shape + (obs_dim,),
                     validate_args=validate_args)

    # Eagerly expand every component that does not already have the
    # broadcast shape; components that already match are left untouched.
    batch_time_shape = batch_shape + time_shape
    init = (initial_dist.expand(batch_shape)
            if initial_dist.batch_shape != batch_shape
            else initial_dist)
    trans_mat = (transition_matrix.expand(batch_time_shape + trans_mat_event)
                 if transition_matrix.shape[:-2] != batch_time_shape
                 else transition_matrix)
    trans_noise = (transition_dist.expand(batch_time_shape)
                   if transition_dist.batch_shape != batch_time_shape
                   else transition_dist)
    obs_mat = (observation_matrix.expand(batch_time_shape + obs_mat_event)
               if observation_matrix.shape[:-2] != batch_time_shape
               else observation_matrix)
    obs_noise = (observation_dist.expand(batch_time_shape)
                 if observation_dist.batch_shape != batch_time_shape
                 else observation_dist)

    # Peel wrappers off the observation distribution from the outside in.
    # Transforms of inner wrappers are prepended so that the final list
    # runs from innermost to outermost.
    wrapper_types = (torch.distributions.Independent,
                     torch.distributions.TransformedDistribution)
    transforms = []
    while isinstance(obs_noise, wrapper_types):
        if not isinstance(obs_noise, torch.distributions.Independent):
            transforms = obs_noise.transforms + transforms
        obs_noise = obs_noise.base_dist
    # A base distribution with an empty event_shape is wrapped again
    # with one reinterpreted batch dimension.
    if not obs_noise.event_shape:
        obs_noise = Independent(obs_noise, 1)

    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
    self.initial_dist = init
    self.transition_matrix = trans_mat
    self.transition_dist = trans_noise
    self.observation_matrix = obs_mat
    self.observation_dist = obs_noise
    self.transforms = transforms