def __init__()

in pyro/distributions/hmm.py [0:0]


    def __init__(self, initial_dist, transition_matrix, transition_dist,
                 observation_matrix, observation_dist,
                 validate_args=None, duration=None):
        """
        Validate, broadcast and eagerly expand the model components.

        Checks are written as ``assert`` statements, so they are skipped when
        Python runs with ``-O``. The batch shapes of all components are
        broadcast together. The last broadcast dim is treated as the time dim;
        the remaining dims form ``batch_shape``. The distribution's
        ``event_shape`` is ``time_shape + (obs_dim,)``.

        :param initial_dist: initial distribution; must support ``rsample``,
            have ``event_dim == 1`` and ``event_shape == (hidden_dim,)``.
            Its batch shape gets a trailing singleton time dim before
            broadcasting, and it is expanded to ``batch_shape`` only.
        :param torch.Tensor transition_matrix: tensor with ``dim() >= 2`` whose
            trailing two dims are ``(hidden_dim, hidden_dim)``.
        :param transition_dist: must support ``rsample``, have
            ``event_dim == 1`` and ``event_shape == (hidden_dim,)``.
        :param torch.Tensor observation_matrix: tensor with ``dim() >= 2`` whose
            trailing two dims ``(hidden_dim, obs_dim)`` define both sizes.
        :param observation_dist: must support ``rsample``, have
            ``event_dim == 1`` and ``event_shape == (obs_dim,)``. Any
            ``Independent`` / ``TransformedDistribution`` wrappers are peeled
            off. Their transforms are stored in ``self.transforms``, with
            the innermost wrapper's transforms first.
        :param validate_args: forwarded to the base-class constructor.
        :param duration: forwarded as the first argument of the base-class
            constructor (presumably an optional number of time steps --
            confirm against the base class).
        :raises AssertionError: if any component fails a type or shape check
            (only while assertions are enabled).
        """
        # Each conjunction short-circuits in the same order that two
        # separate asserts would evaluate.
        assert initial_dist.has_rsample and initial_dist.event_dim == 1
        assert isinstance(transition_matrix, torch.Tensor) and transition_matrix.dim() >= 2
        assert transition_dist.has_rsample and transition_dist.event_dim == 1
        assert isinstance(observation_matrix, torch.Tensor) and observation_matrix.dim() >= 2
        assert observation_dist.has_rsample and observation_dist.event_dim == 1

        # The observation matrix fixes both the hidden and observed sizes.
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        hidden_event = (hidden_dim,)
        assert initial_dist.event_shape == hidden_event
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_dist.event_shape == hidden_event
        assert observation_dist.event_shape == (obs_dim,)

        # initial_dist has no time dim, so give it a singleton one before
        # broadcasting against the per-step components.
        broadcast = broadcast_shape(initial_dist.batch_shape + (1,),
                                    transition_matrix.shape[:-2],
                                    transition_dist.batch_shape,
                                    observation_matrix.shape[:-2],
                                    observation_dist.batch_shape)
        batch_shape = broadcast[:-1]
        time_shape = broadcast[-1:]
        super().__init__(duration, batch_shape, time_shape + (obs_dim,),
                         validate_args=validate_args)

        # Expand eagerly: initial_dist to batch_shape, every per-step
        # component to batch_shape + time_shape.
        step_shape = batch_shape + time_shape
        if initial_dist.batch_shape != batch_shape:
            initial_dist = initial_dist.expand(batch_shape)
        if transition_matrix.shape[:-2] != step_shape:
            transition_matrix = transition_matrix.expand(step_shape + (hidden_dim, hidden_dim))
        if transition_dist.batch_shape != step_shape:
            transition_dist = transition_dist.expand(step_shape)
        if observation_matrix.shape[:-2] != step_shape:
            observation_matrix = observation_matrix.expand(step_shape + (hidden_dim, obs_dim))
        if observation_dist.batch_shape != step_shape:
            observation_dist = observation_dist.expand(step_shape)

        # Peel wrappers off the observation distribution. Independent is
        # checked first, so a wrapper that is an Independent only has its
        # base_dist taken. TransformedDistribution layers prepend their
        # transforms, leaving the innermost layer's transforms at the front.
        tdist = torch.distributions
        transforms = []
        while isinstance(observation_dist, (tdist.Independent, tdist.TransformedDistribution)):
            if not isinstance(observation_dist, tdist.Independent):
                transforms = observation_dist.transforms + transforms
            observation_dist = observation_dist.base_dist
        if not observation_dist.event_shape:
            # Unwrapping may leave a base with scalar events; reinterpret one
            # batch dim as the event dim again.
            observation_dist = Independent(observation_dist, 1)

        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
        self.initial_dist = initial_dist
        self.transition_matrix = transition_matrix
        self.transition_dist = transition_dist
        self.observation_matrix = observation_matrix
        self.observation_dist = observation_dist
        self.transforms = transforms