in pyro/util.py [0:0]
def check_site_shape(site, max_plate_nesting):
    """
    Check that a site's ``log_prob`` shape is consistent with its enclosing plates.

    The expected shape is built from the frames in ``site["cond_indep_stack"]``.
    Each frame with a non-None ``dim`` claims position ``dim`` (counted from
    the right) with size ``size``. Frames whose ``dim`` is None are skipped.
    Positions inside the expected shape that no frame claims are unconstrained.

    Only the rightmost ``max_plate_nesting`` dims of the actual ``log_prob``
    shape are compared; dims further left are ignored. Within that window,
    any dim not covered by the expected shape must have size 1.

    :param dict site: A trace site. Reads the keys ``"name"``, ``"log_prob"``
        (anything with a ``.shape``), ``"cond_indep_stack"`` (frames with
        ``.name``, ``.dim`` and ``.size``), ``"infer"`` (a dict) and ``"fn"``
        (anything with a ``.batch_shape``).
    :param int max_plate_nesting: Number of rightmost dims reserved for plates.
    :raises ValueError: if a plate ``dim`` is not negative; if two plates
        claim the same dim; if the plates need more than ``max_plate_nesting``
        dims; if the ``log_prob`` shape does not match the plates; or if
        ``site["fn"].batch_shape`` has a size other than 1 at the dim given by
        ``site["infer"]["_enumerate_dim"]``.
    """
    actual_shape = list(site["log_prob"].shape)

    # Compute expected shape. None marks a position no plate has claimed yet.
    expected_shape = []
    for f in site["cond_indep_stack"]:
        if f.dim is None:
            continue
        # Use the specified plate dimension, which counts from the right.
        # Raise instead of assert so this check survives ``python -O``: a
        # non-negative dim would otherwise index expected_shape from the left.
        if f.dim >= 0:
            raise ValueError('\n '.join([
                'at site "{}" within plate("{}", dim={}), invalid dim'.format(site["name"], f.name, f.dim),
                'Plate dims must be negative, counting from the right.']))
        # Left-pad with unclaimed positions so that index f.dim exists.
        if len(expected_shape) < -f.dim:
            expected_shape = [None] * (-f.dim - len(expected_shape)) + expected_shape
        if expected_shape[f.dim] is not None:
            raise ValueError('\n '.join([
                'at site "{}" within plate("{}", dim={}), dim collision'.format(site["name"], f.name, f.dim),
                'Try setting dim arg in other plates.']))
        expected_shape[f.dim] = f.size
    # -1 marks an unconstrained position in the comparison below.
    expected_shape = [-1 if e is None else e for e in expected_shape]

    # Check for plate stack overflow.
    if len(expected_shape) > max_plate_nesting:
        raise ValueError('\n '.join([
            'at site "{}", plate stack overflow'.format(site["name"]),
            'Try increasing max_plate_nesting to at least {}'.format(len(expected_shape))]))

    # Ignore dimensions left of max_plate_nesting.
    if max_plate_nesting < len(actual_shape):
        actual_shape = actual_shape[len(actual_shape) - max_plate_nesting:]

    # Check for incorrect plate placement on the right of max_plate_nesting.
    # fillvalue=1: a missing actual dim counts as size 1, and an actual dim
    # beyond the expected shape must be size 1.
    for actual_size, expected_size in zip_longest(reversed(actual_shape), reversed(expected_shape), fillvalue=1):
        if expected_size != -1 and expected_size != actual_size:
            raise ValueError('\n '.join([
                'at site "{}", invalid log_prob shape'.format(site["name"]),
                'Expected {}, actual {}'.format(expected_shape, actual_shape),
                'Try one of the following fixes:',
                '- enclose the batched tensor in a with plate(...): context',
                '- .to_event(...) the distribution being sampled',
                '- .permute() data dimensions']))

    # Check parallel dimensions on the left of max_plate_nesting.
    enum_dim = site["infer"].get("_enumerate_dim")
    if enum_dim is not None:
        # Only checked when fn's batch_shape is long enough to reach enum_dim.
        if len(site["fn"].batch_shape) >= -enum_dim and site["fn"].batch_shape[enum_dim] != 1:
            raise ValueError('\n '.join([
                'Enumeration dim conflict at site "{}"'.format(site["name"]),
                'Try increasing pyro.markov history size']))