def check_model_guide_match()

in pyro/util.py [0:0]


def _trim_to_plate_nesting(shape, fn, max_plate_nesting):
    """
    Keep only the rightmost ``max_plate_nesting + fn.event_dim`` entries of
    ``shape`` so that batch dims outside the declared plate nesting are
    ignored when comparing model and guide shapes.

    :param shape: a sample shape, as returned by ``fn.shape(...)``.
    :param fn: the site's distribution; only its ``event_dim`` is read, and
        only when ``len(shape) > max_plate_nesting``.
    :param max_plate_nesting: number of rightmost batch dims (left of the
        event dims) to keep; may be ``math.inf``, in which case nothing is
        trimmed.
    :returns: ``shape`` itself or a suffix slice of it.
    """
    if len(shape) <= max_plate_nesting:
        return shape
    keep = max_plate_nesting + fn.event_dim
    # Clamp the start index at 0: when len(shape) <= keep nothing should be
    # dropped, but a negative slice start would wrap around and silently
    # discard leading (event) dims, hiding genuine shape mismatches.
    return shape[max(0, len(shape) - keep):]


def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf):
    """
    Check that a model trace and a guide trace are structurally compatible.

    :param pyro.poutine.Trace model_trace: Trace object of the model
    :param pyro.poutine.Trace guide_trace: Trace object of the guide
    :param max_plate_nesting: number of rightmost batch dims (left of the
        event dims) within which model and guide sample shapes must agree;
        batch dims further left are ignored. Defaults to ``math.inf``
        (compare all dims).
    :raises ValueError: if the model and guide disagree on ``event_dim`` or
        on sample shape at a site that appears in both.

    Structural mismatches (missing or extra sites, misused auxiliary sites,
    extra plates) are reported via :func:`warnings.warn` (``UserWarning``)
    rather than raised.

    Checks the following assumptions:
    1. Each sample site in the model also appears in the guide (unless the
        model site is enumerated) and is not marked auxiliary.
    2. Each sample site in the guide either appears in the model or is marked
        auxiliary via ``infer={'is_auxiliary': True}``.
    3. Each :class:`~pyro.plate` statement in the guide also appears in the
        model.
    4. At each sample site that appears in both the model and guide, the model
        and guide agree on ``event_dim``.
    5. At each sample site that appears in both the model and guide, the model
        and guide agree on sample shape (up to broadcasting outside of
        ``max_plate_nesting``).
    """
    # Check ordinary sample sites (plate subsample sites are checked below).
    guide_vars = set(name for name, site in guide_trace.nodes.items()
                     if site["type"] == "sample"
                     if type(site["fn"]).__name__ != "_Subsample")
    aux_vars = set(name for name, site in guide_trace.nodes.items()
                   if site["type"] == "sample"
                   if site["infer"].get("is_auxiliary"))
    model_vars = set(name for name, site in model_trace.nodes.items()
                     if site["type"] == "sample" and not site["is_observed"]
                     if type(site["fn"]).__name__ != "_Subsample")
    # Model sites enumerated in the model need no guide counterpart.
    enum_vars = set(name for name, site in model_trace.nodes.items()
                    if site["type"] == "sample" and not site["is_observed"]
                    if type(site["fn"]).__name__ != "_Subsample"
                    if site["infer"].get("_enumerate_dim") is not None
                    if name not in guide_vars)
    if aux_vars & model_vars:
        warnings.warn("Found auxiliary vars in the model: {}".format(aux_vars & model_vars))
    if not (guide_vars <= model_vars | aux_vars):
        warnings.warn("Found non-auxiliary vars in guide but not model, "
                      "consider marking these infer={{'is_auxiliary': True}}:\n{}".format(
                          guide_vars - aux_vars - model_vars))
    if not (model_vars <= guide_vars | enum_vars):
        warnings.warn("Found vars in model but not guide: {}".format(model_vars - guide_vars - enum_vars))

    # Check shapes agree.
    for name in model_vars & guide_vars:
        model_site = model_trace.nodes[name]
        guide_site = guide_trace.nodes[name]

        if hasattr(model_site["fn"], "event_dim") and hasattr(guide_site["fn"], "event_dim"):
            if model_site["fn"].event_dim != guide_site["fn"].event_dim:
                raise ValueError("Model and guide event_dims disagree at site '{}': {} vs {}".format(
                    name, model_site["fn"].event_dim, guide_site["fn"].event_dim))

        if hasattr(model_site["fn"], "shape") and hasattr(guide_site["fn"], "shape"):
            model_shape = model_site["fn"].shape(*model_site["args"], **model_site["kwargs"])
            guide_shape = guide_site["fn"].shape(*guide_site["args"], **guide_site["kwargs"])
            if model_shape == guide_shape:
                continue

            # Allow broadcasting outside of max_plate_nesting.
            model_shape = _trim_to_plate_nesting(model_shape, model_site["fn"], max_plate_nesting)
            guide_shape = _trim_to_plate_nesting(guide_shape, guide_site["fn"], max_plate_nesting)
            if model_shape == guide_shape:
                continue
            # Compare right-aligned; a dim missing on one side is treated as 1,
            # so only extra leading singleton dims are tolerated.
            for model_size, guide_size in zip_longest(reversed(model_shape), reversed(guide_shape), fillvalue=1):
                if model_size != guide_size:
                    raise ValueError("Model and guide shapes disagree at site '{}': {} vs {}".format(
                        name, model_shape, guide_shape))

    # Check subsample sites introduced by plate.
    model_vars = set(name for name, site in model_trace.nodes.items()
                     if site["type"] == "sample" and not site["is_observed"]
                     if type(site["fn"]).__name__ == "_Subsample")
    guide_vars = set(name for name, site in guide_trace.nodes.items()
                     if site["type"] == "sample"
                     if type(site["fn"]).__name__ == "_Subsample")
    if not (guide_vars <= model_vars):
        warnings.warn("Found plate statements in guide but not model: {}".format(guide_vars - model_vars))