pyro/poutine/enum_messenger.py [16:40]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    dist, num_samples = msg["fn"], msg["infer"].get("num_samples")

    # find batch dims that aren't plate dims
    batch_shape = [1] * len(dist.batch_shape)
    for f in msg["cond_indep_stack"]:
        if f.vectorized:
            batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim]
    batch_shape = tuple(batch_shape)

    # sample a batch
    sample_shape = (num_samples,)
    fat_sample = dist(sample_shape=sample_shape)  # TODO thin before sampling
    assert fat_sample.shape == sample_shape + dist.batch_shape + dist.event_shape
    assert any(d > 1 for d in fat_sample.shape)

    target_shape = (num_samples,) + batch_shape + dist.event_shape

    # if this site has any possible ancestors, sample ancestor indices uniformly
    thin_sample = fat_sample
    if thin_sample.shape != target_shape:

        index = [Ellipsis] + [slice(None)] * (len(thin_sample.shape) - 1)
        squashed_dims = []
        for squashed_dim, squashed_size in zip(range(1, len(thin_sample.shape)), thin_sample.shape[1:]):
            if squashed_size > 1 and (target_shape[squashed_dim] == 1 or squashed_dim == 0):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



pyro/poutine/enum_messenger.py [56:80]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    dist, num_samples = msg["fn"], msg["infer"].get("num_samples")

    # find batch dims that aren't plate dims
    batch_shape = [1] * len(dist.batch_shape)
    for f in msg["cond_indep_stack"]:
        if f.vectorized:
            batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim]
    batch_shape = tuple(batch_shape)

    # sample a batch
    sample_shape = (num_samples,)
    fat_sample = dist(sample_shape=sample_shape)  # TODO thin before sampling
    assert fat_sample.shape == sample_shape + dist.batch_shape + dist.event_shape
    assert any(d > 1 for d in fat_sample.shape)

    target_shape = (num_samples,) + batch_shape + dist.event_shape

    # if this site has any ancestors, choose ancestors from diagonal approximation
    thin_sample = fat_sample
    if thin_sample.shape != target_shape:

        index = [Ellipsis] + [slice(None)] * (len(thin_sample.shape) - 1)
        squashed_dims = []
        for squashed_dim, squashed_size in zip(range(1, len(thin_sample.shape)), thin_sample.shape[1:]):
            if squashed_size > 1 and (target_shape[squashed_dim] == 1 or squashed_dim == 0):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



