def check_traceenum_requirements()

in pyro/util.py [0:0]


def check_traceenum_requirements(model_trace, guide_trace):
    """
    Warn if user could easily rewrite the model or guide in a way that would
    clearly avoid invalid dependencies on enumerated variables.

    :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO` enumerates over
    synchronized products rather than full cartesian products. Therefore models
    must ensure that no variable outside of a plate depends on an enumerated
    variable inside that plate. Since full dependency checking is impossible,
    this function aims to warn only in cases where models can be easily
    rewritten to be obviously correct.

    Both traces are scanned in the iteration order of their ``nodes`` mapping.
    A :class:`RuntimeWarning` is issued for each (sample site, earlier
    enumerated context) pair where the enumerated context strictly contains the
    site's vectorized plate context, unless every earlier enumerated site in
    that context is judged independent of the current site by
    ``_are_independent`` on their sequential plate counters.

    :param model_trace: Trace of the model; expected to expose a ``nodes``
        mapping of site name to site dict with ``"type"`` and
        ``"cond_indep_stack"`` entries.
    :param guide_trace: Trace of the guide, with the same structure. Its sample
        sites whose ``site["infer"].get("enumerate")`` is truthy define the set
        of enumerated site names used when checking both traces.
    """
    enumerated_sites = set(name for name, site in guide_trace.nodes.items()
                           if site["type"] == "sample" and site["infer"].get("enumerate"))
    for role, trace in [('model', model_trace), ('guide', guide_trace)]:
        # Sequential (non-vectorized) plate counters of each already-visited sample site.
        plate_counters = {}
        # Vectorized plate context -> names of already-visited enumerated sites in that context.
        enumerated_contexts = defaultdict(set)
        for name, site in trace.nodes.items():
            if site["type"] != "sample":
                continue
            plate_counter = {f.name: f.counter for f in site["cond_indep_stack"] if not f.vectorized}
            context = frozenset(f for f in site["cond_indep_stack"] if f.vectorized)

            # Check that sites outside each independence context precede enumerated sites inside that context.
            for enumerated_context, enumerated_names in enumerated_contexts.items():
                # Only a context that strictly contains this site's context can be broken by this site.
                if not (context < enumerated_context):
                    continue
                # Keep only earlier enumerated sites that `_are_independent` (defined elsewhere)
                # does not consider independent of this site, given their sequential plate counters.
                dependent_names = sorted(n for n in enumerated_names
                                         if not _are_independent(plate_counter, plate_counters[n]))
                if not dependent_names:
                    continue
                diff = sorted(f.name for f in enumerated_context - context)
                warnings.warn('\n  '.join([
                    'at {} site "{}", possibly invalid dependency.'.format(role, name),
                    'Expected site "{}" to precede sites "{}"'.format(name, '", "'.join(dependent_names)),
                    'to avoid breaking independence of plates "{}"'.format('", "'.join(diff)),
                ]), RuntimeWarning)

            # Record this site only after checking it, so it is compared against earlier sites only.
            plate_counters[name] = plate_counter
            if name in enumerated_sites:
                enumerated_contexts[context].add(name)