in pyro/distributions/hmm.py [0:0]
def _sequential_gaussian_filter_sample(init, trans, sample_shape):
    """
    Draw a reparameterized joint sample of the latent states of a Gaussian
    Markov chain, using a parallel-scan forward filter followed by a
    backward sampling pass.

    :param Gaussian init: Gaussian over the initial state.
    :param Gaussian trans: Gaussian over pairs of consecutive states, so
        that ``trans.dim() == 2 * init.dim()``. Its rightmost batch
        dimension is treated as time; the remaining batch dimensions may be
        unbroadcasted relative to ``init.batch_shape``.
    :param sample_shape: Leading shape of the draw; it is concatenated in
        front of ``init.batch_shape``.
    :return: Tensor of shape
        ``sample_shape + init.batch_shape + (num_steps, state_dim)`` holding
        the sampled states ``z1, z2, ...``; the sampled initial state ``z0``
        is dropped.
    :rtype: torch.Tensor
    """
    assert isinstance(init, Gaussian)
    assert isinstance(trans, Gaussian)
    assert trans.dim() == 2 * init.dim()
    assert _is_subshape(trans.batch_shape[:-1], init.batch_shape)

    state_dim = trans.dim() // 2
    device = trans.precision.device

    # Block permutation over three state-sized blocks: swap the first two
    # blocks and leave the third one where it is.
    index = torch.arange(3 * state_dim, device=device)
    perm = torch.cat(
        [
            index[state_dim : 2 * state_dim],
            index[:state_dim],
            index[2 * state_dim :],
        ]
    )

    # ---- Forward filter, similar to _sequential_gaussian_tensordot(). ----
    # Every round fuses adjacent pairs of factors, halving the time axis,
    # and records the three-block joint so the backward pass can later
    # re-sample the state that was marginalized out.
    tape = []
    lead_shape = trans.batch_shape[:-1]  # Note trans may be unbroadcasted.
    chain = trans
    num_steps = chain.batch_shape[-1]
    while num_steps > 1:
        num_pairs = num_steps // 2
        paired = chain[..., : 2 * num_pairs].reshape(lead_shape + (num_pairs, 2))
        first = paired[..., 0].event_pad(right=state_dim)
        second = paired[..., 1].event_pad(left=state_dim)
        joint = (first + second).event_permute(perm)
        tape.append(joint)
        merged = joint.marginalize(left=state_dim)
        if num_steps % 2 == 1:
            # An odd trailing factor has no partner; carry it over as is.
            merged = Gaussian.cat((merged, chain[..., -1:]), dim=-1)
        chain = merged
        num_steps = chain.batch_shape[-1]
    chain = chain[..., 0] + init.event_pad(right=state_dim)

    # ---- Backward sample. ----
    out_shape = sample_shape + init.batch_shape
    # The fully contracted factor yields the two endpoint states at once.
    result = chain.rsample(sample_shape).reshape(out_shape + (2, state_dim))
    for joint in reversed(tape):
        # Illustrations below ignore sample_shape and batch_shape; each zn
        # stands for one sampled event of shape (state_dim,).
        #   EVEN round: result = [z0, z2, z4], joint covers (z1, z3).
        #   ODD round:  result = [z0, z2, z3], joint covers (z1,) and z3 is
        #               the carried-over tail.
        num_known = result.size(-2)
        num_missing = joint.batch_shape[-1]
        has_tail = num_missing != num_known - 1
        if has_tail:
            assert num_missing == num_known - 2
            anchors = result[..., :-1, :]  # ODD: [z0, z2]
        else:
            anchors = result  # EVEN: [z0, z2, z4]

        # Pair each anchor with its successor:
        #   EVEN: [z0, z0, z2, z2, z4, z4] -> [z0, z2, z2, z4] -> [z0z2, z2z4]
        #   ODD:  [z0, z0, z2, z2]         -> [z0, z2]         -> [z0z2]
        neighbors = anchors.repeat_interleave(2, dim=-2)[..., 1:-1, :]
        neighbors = neighbors.reshape(out_shape + (-1, 2 * state_dim))
        middles = joint.condition(neighbors).rsample()  # EVEN [z1, z3]; ODD [z1]

        if has_tail:
            # Append the tail so the rows line up with the anchors.
            middles = torch.cat([middles, result[..., -1:, :]], dim=-2)  # [z1, z3]
        else:
            # Append a dummy zero row so the rows line up with the anchors.
            middles = torch.nn.functional.pad(middles, (0, 0, 0, 1))  # [z1, z3, 0]

        # Interleave anchors with the freshly sampled states:
        #   EVEN: [[z0, z1], [z2, z3], [z4, 0]] -> [z0, z1, z2, z3, z4, 0]
        #   ODD:  [[z0, z1], [z2, z3]]          -> [z0, z1, z2, z3]
        woven = torch.stack([anchors, middles], dim=-2)
        woven = woven.reshape(out_shape + (-1, state_dim))
        # EVEN rounds drop the dummy row again: [z0, z1, z2, z3, z4]
        result = woven if has_tail else woven[..., :-1, :]

    return result[..., 1:, :]  # [z1, z2, z3, ...]