def _sequential_gaussian_filter_sample()

in pyro/distributions/hmm.py


def _sequential_gaussian_filter_sample(init, trans, sample_shape):
    """
    Draws a reparameterized sample from a Markov product of Gaussians via
    a parallel-scan forward-filter backward-sample algorithm.
    """
    assert isinstance(init, Gaussian)
    assert isinstance(trans, Gaussian)
    assert trans.dim() == 2 * init.dim()  # each trans factor is over a (z_prev, z_curr) pair
    assert _is_subshape(trans.batch_shape[:-1], init.batch_shape)  # rightmost batch dim is time
    state_dim = trans.dim() // 2
    device = trans.precision.device
    perm = torch.cat([torch.arange(1 * state_dim, 2 * state_dim, device=device),
                      torch.arange(0 * state_dim, 1 * state_dim, device=device),
                      torch.arange(2 * state_dim, 3 * state_dim, device=device)])
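    # `perm` reorders the event dims of a padded pairwise sum below from
    # (z_left, z_mid, z_right) to (z_mid, z_left, z_right): the midpoint can
    # then be marginalized out while filtering and, because .condition()
    # binds the rightmost dims, sampled given (z_left, z_right) afterwards.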

    # Forward filter, similar to _sequential_gaussian_tensordot().
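    # Each pass below pairs adjacent factors, saves their joint on the tape,
    # and contracts out the shared midpoint, roughly halving the length of
    # the time dimension until a single factor over (z0, zT) remains.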
    tape = []
    shape = trans.batch_shape[:-1]  # Note trans may be unbroadcasted.
    gaussian = trans
    while gaussian.batch_shape[-1] > 1:
        time = gaussian.batch_shape[-1]
        even_time = time // 2 * 2
        even_part = gaussian[..., :even_time]
        x_y = even_part.reshape(shape + (even_time // 2, 2))
        x, y = x_y[..., 0], x_y[..., 1]
        x = x.event_pad(right=state_dim)
        y = y.event_pad(left=state_dim)
        joint = (x + y).event_permute(perm)
        tape.append(joint)
        contracted = joint.marginalize(left=state_dim)
        if time > even_time:
            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
        gaussian = contracted
    gaussian = gaussian[..., 0] + init.event_pad(right=state_dim)
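    # `gaussian` is now an (unnormalized) joint over the first and last
    # states (z0, zT), with the initial distribution folded in.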

    # Backward sample.
    shape = sample_shape + init.batch_shape
    result = gaussian.rsample(sample_shape).reshape(shape + (2, state_dim))
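    # `result` holds endpoint samples [z0, zT]; replaying the tape in reverse
    # fills in the midpoints, roughly doubling the time length each step.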
    for joint in reversed(tape):
        # The following comments demonstrate two example computations, one
        # EVEN, one ODD.  Ignoring sample_shape and batch_shape, let each zn be
        # a single sampled event of shape (state_dim,).
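        # Two cases arise because odd-length levels of the forward filter
        # passed their trailing factor through uncontracted, leaving one
        # extra already-sampled state at the end of `result`.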
        if joint.batch_shape[-1] == result.size(-2) - 1:  # EVEN case.
            # Suppose e.g. result = [z0, z2, z4]
            cond = result.repeat_interleave(2, dim=-2)  # [z0, z0, z2, z2, z4, z4]
            cond = cond[..., 1:-1, :]  # [z0, z2, z2, z4]
            cond = cond.reshape(shape + (-1, 2 * state_dim))  # [z0z2, z2z4]
            sample = joint.condition(cond).rsample()  # [z1, z3]
            sample = torch.nn.functional.pad(sample, (0, 0, 0, 1))  # [z1, z3, 0]
            result = torch.stack([
                result,  # [z0, z2, z4]
                sample,  # [z1, z3, 0]
            ], dim=-2)  # [[z0, z1], [z2, z3], [z4, 0]]
            result = result.reshape(shape + (-1, state_dim))  # [z0, z1, z2, z3, z4, 0]
            result = result[..., :-1, :]  # [z0, z1, z2, z3, z4]
        else:  # ODD case.
            assert joint.batch_shape[-1] == result.size(-2) - 2
            # Suppose e.g. result = [z0, z2, z3]
            cond = result[..., :-1, :].repeat_interleave(2, dim=-2)  # [z0, z0, z2, z2]
            cond = cond[..., 1:-1, :]  # [z0, z2]
            cond = cond.reshape(shape + (-1, 2 * state_dim))  # [z0z2]
            sample = joint.condition(cond).rsample()  # [z1]
            sample = torch.cat([sample, result[..., -1:, :]], dim=-2)  # [z1, z3]
            result = torch.stack([
                result[..., :-1, :],  # [z0, z2]
                sample,  # [z1, z3]
            ], dim=-2)  # [[z0, z1], [z2, z3]]
            result = result.reshape(shape + (-1, state_dim))  # [z0, z1, z2, z3]

    return result[..., 1:, :]  # [z1, z2, z3, ...]
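
For context, here is a minimal, hypothetical usage sketch. It assumes the
converters mvn_to_gaussian() and matrix_and_mvn_to_gaussian() from
pyro.ops.gaussian keep their current signatures; in practice this private
helper is reached through GaussianHMM.rsample() rather than called directly,
and all dimensions and coefficients below are illustrative.

import torch

import pyro.distributions as dist
from pyro.distributions.hmm import _sequential_gaussian_filter_sample
from pyro.ops.gaussian import matrix_and_mvn_to_gaussian, mvn_to_gaussian

state_dim, num_steps = 3, 7

# Initial state distribution p(z0), converted to the internal Gaussian type.
init = mvn_to_gaussian(
    dist.MultivariateNormal(torch.zeros(state_dim), torch.eye(state_dim))
)

# Transition factors p(z_t | z_{t-1}) for t = 1..num_steps, packed into a
# single Gaussian with batch shape (num_steps,) and event dim 2 * state_dim.
trans_matrix = (0.9 * torch.eye(state_dim)).expand(num_steps, state_dim, state_dim)
trans_mvn = dist.MultivariateNormal(
    torch.zeros(num_steps, state_dim), 0.1 * torch.eye(state_dim)
)
trans = matrix_and_mvn_to_gaussian(trans_matrix, trans_mvn)

sample = _sequential_gaussian_filter_sample(init, trans, sample_shape=(10,))
assert sample.shape == (10, num_steps, state_dim)  # draws of z1, ..., z_T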