# check_site_shape()
#
# Excerpt from pyro/util.py.


def check_site_shape(site, max_plate_nesting):
    """Validate a sample site's ``log_prob`` shape against its enclosing plates.

    Derives the batch shape implied by the site's plate frames (each frame
    with a ``dim`` claims that dimension, counted from the right, with its
    ``size``). It then compares that shape with the rightmost
    ``max_plate_nesting`` dimensions of ``site["log_prob"].shape``. Finally,
    if ``site["infer"]`` records an ``"_enumerate_dim"``, it checks that the
    distribution's batch shape is a singleton (or absent) at that dimension.

    :param dict site: site dict; reads ``"log_prob"``, ``"cond_indep_stack"``,
        ``"infer"`` and ``"fn"``, plus ``"name"`` only when reporting an error.
    :param max_plate_nesting: number of rightmost dims reserved for plates.
    :raises ValueError: on a plate dim collision, a plate stack deeper than
        ``max_plate_nesting``, a ``log_prob`` shape that disagrees with the
        plates, or an enumeration dim conflict.
    :raises AssertionError: (via ``assert``) if a plate frame's ``dim`` is
        not negative.
    """
    log_prob_dims = list(site["log_prob"].shape)

    # Expected plate sizes, right-aligned. Slots no plate claims stay None
    # while building and become -1 ("any size") afterwards.
    plate_sizes = []
    for frame in site["cond_indep_stack"]:
        if frame.dim is None:
            # Frames without a dim are not shape-checked here.
            continue
        # Plate dims are negative offsets from the right.
        assert frame.dim < 0
        pad = -frame.dim - len(plate_sizes)
        if pad > 0:
            plate_sizes = [None] * pad + plate_sizes
        if plate_sizes[frame.dim] is not None:
            raise ValueError('\n  '.join([
                'at site "{}" within plate("{}", dim={}), dim collision'.format(site["name"], frame.name, frame.dim),
                'Try setting dim arg in other plates.']))
        plate_sizes[frame.dim] = frame.size
    plate_sizes = [size if size is not None else -1 for size in plate_sizes]

    # The plates must fit inside the reserved rightmost dims.
    depth = len(plate_sizes)
    if depth > max_plate_nesting:
        raise ValueError('\n  '.join([
            'at site "{}", plate stack overflow'.format(site["name"]),
            'Try increasing max_plate_nesting to at least {}'.format(depth)]))

    # Only the rightmost max_plate_nesting dims of log_prob are compared;
    # anything further left is dropped.
    surplus = len(log_prob_dims) - max_plate_nesting
    if surplus > 0:
        log_prob_dims = log_prob_dims[surplus:]

    # Walk both shapes right-to-left. A dim missing on either side counts
    # as size 1, and a -1 expected size accepts anything.
    mismatched = any(
        expected != -1 and expected != actual
        for actual, expected in zip_longest(reversed(log_prob_dims), reversed(plate_sizes), fillvalue=1))
    if mismatched:
        raise ValueError('\n  '.join([
            'at site "{}", invalid log_prob shape'.format(site["name"]),
            'Expected {}, actual {}'.format(plate_sizes, log_prob_dims),
            'Try one of the following fixes:',
            '- enclose the batched tensor in a with plate(...): context',
            '- .to_event(...) the distribution being sampled',
            '- .permute() data dimensions']))

    # If an enumeration dim is recorded, the distribution's batch shape must
    # either not reach that dim or have size 1 there.
    enum_dim = site["infer"].get("_enumerate_dim")
    if enum_dim is None:
        return
    batch_shape = site["fn"].batch_shape
    if len(batch_shape) >= -enum_dim and batch_shape[enum_dim] != 1:
        raise ValueError('\n  '.join([
            'Enumeration dim conflict at site "{}"'.format(site["name"]),
            'Try increasing pyro.markov history size']))