in pyro/util.py [0:0]
def check_traceenum_requirements(model_trace, guide_trace):
    """
    Warn if user could easily rewrite the model or guide in a way that would
    clearly avoid invalid dependencies on enumerated variables.

    :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO` enumerates over
    synchronized products rather than full cartesian products. Therefore models
    must ensure that no variable outside of a plate depends on an enumerated
    variable inside that plate. Since full dependency checking is impossible,
    this function aims to warn only in cases where models can be easily
    rewritten to be obviously correct.

    Concretely, the model trace and then the guide trace are walked in node
    order. A :class:`RuntimeWarning` is emitted at a sample site whose set of
    vectorized plates is a strict subset of the vectorized plates of an
    earlier enumerated sample site. Earlier sites that ``_are_independent``
    reports as independent, based on their sequential (non-vectorized) plate
    counters, are ignored.

    :param model_trace: Model trace; only its ``nodes`` name -> site mapping
        is read.
    :param guide_trace: Guide trace; only its ``nodes`` mapping is read. Its
        sample sites with a truthy ``site["infer"].get("enumerate")`` define
        which site names count as enumerated in *both* traces.
    """
    enumerated_sites = {
        name
        for name, site in guide_trace.nodes.items()
        if site["type"] == "sample" and site["infer"].get("enumerate")
    }
    for role, trace in [('model', model_trace), ('guide', guide_trace)]:
        # Sequential-plate counters of every sample site seen so far
        # (vectorized plates are tracked separately via ``context``).
        plate_counters = {}
        # Vectorized-plate context -> names of enumerated sites seen in it.
        enumerated_contexts = defaultdict(set)
        for name, site in trace.nodes.items():
            if site["type"] != "sample":
                continue
            stack = site["cond_indep_stack"]
            plate_counter = {f.name: f.counter for f in stack if not f.vectorized}
            context = frozenset(f for f in stack if f.vectorized)

            # Sites outside an independence context should precede enumerated
            # sites inside that context; a later outer site may depend on them.
            for enumerated_context, enumerated_names in enumerated_contexts.items():
                if not context < enumerated_context:
                    continue
                dependent = sorted(
                    n for n in enumerated_names
                    if not _are_independent(plate_counter, plate_counters[n])
                )
                if not dependent:
                    continue
                diff = sorted(f.name for f in enumerated_context - context)
                warnings.warn('\n '.join([
                    'at {} site "{}", possibly invalid dependency.'.format(role, name),
                    'Expected site "{}" to precede sites "{}"'.format(name, '", "'.join(dependent)),
                    'to avoid breaking independence of plates "{}"'.format('", "'.join(diff)),
                ]), RuntimeWarning)

            plate_counters[name] = plate_counter
            if name in enumerated_sites:
                enumerated_contexts[context].add(name)