in pyro/util.py [0:0]
def _trim_outside_plate_nesting(shape, fn, max_plate_nesting):
    """
    Drop the leading dims of ``shape`` that lie outside of ``max_plate_nesting``.

    :param tuple shape: Full shape reported by ``fn.shape(...)``.
    :param fn: The site's ``fn``; ``fn.event_dim`` is only read when ``shape``
        is longer than ``max_plate_nesting``.
    :param max_plate_nesting: Number of rightmost batch dims to keep.
    :returns: At most the rightmost ``max_plate_nesting + fn.event_dim`` dims
        of ``shape``; ``shape`` itself when nothing needs to be dropped.
    """
    if len(shape) <= max_plate_nesting:
        return shape
    # Clamp at zero: when the shape is shorter than max_plate_nesting + event_dim
    # the raw start index is negative, and a negative slice start would count
    # from the right and silently discard dims that must be compared.
    start = max(0, len(shape) - max_plate_nesting - fn.event_dim)
    return shape[start:]


def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf):
    """
    Check that a model trace and a guide trace are compatible.

    :param pyro.poutine.Trace model_trace: Trace object of the model
    :param pyro.poutine.Trace guide_trace: Trace object of the guide
    :param max_plate_nesting: Number of rightmost batch dims that take part in
        the shape comparison. Dims to the left of the rightmost
        ``max_plate_nesting`` batch dims (plus event dims) are ignored, which
        allows broadcasting there. The default ``math.inf`` compares all dims.
    :raises ValueError: if the model and guide disagree on ``event_dim`` or on
        sample shape at a site that appears in both.

    Violations of assumptions 1-3 below are reported through
    :func:`warnings.warn` (default warning category); violations of
    assumption 4 raise ``ValueError``.

    Checks the following assumptions:

    1. Each sample site in the model also appears in the guide and is not
       marked auxiliary.
    2. Each sample site in the guide either appears in the model or is marked
       auxiliary via ``infer={'is_auxiliary': True}``.
    3. Each :class:`~pyro.plate` statement in the guide also appears in the
       model.
    4. At each sample site that appears in both the model and guide, the model
       and guide agree on sample shape.
    """
    # Check ordinary sample sites.
    guide_vars = {name for name, site in guide_trace.nodes.items()
                  if site["type"] == "sample"
                  if type(site["fn"]).__name__ != "_Subsample"}
    aux_vars = {name for name, site in guide_trace.nodes.items()
                if site["type"] == "sample"
                if site["infer"].get("is_auxiliary")}
    model_vars = {name for name, site in model_trace.nodes.items()
                  if site["type"] == "sample" and not site["is_observed"]
                  if type(site["fn"]).__name__ != "_Subsample"}
    # Model sites that carry an ``_enumerate_dim`` and are absent from the guide
    # are exempt from the "in model but not guide" warning below.
    enum_vars = {name for name, site in model_trace.nodes.items()
                 if site["type"] == "sample" and not site["is_observed"]
                 if type(site["fn"]).__name__ != "_Subsample"
                 if site["infer"].get("_enumerate_dim") is not None
                 if name not in guide_vars}
    if aux_vars & model_vars:
        warnings.warn("Found auxiliary vars in the model: {}".format(aux_vars & model_vars))
    if not (guide_vars <= model_vars | aux_vars):
        warnings.warn("Found non-auxiliary vars in guide but not model, "
                      "consider marking these infer={{'is_auxiliary': True}}:\n{}".format(
                          guide_vars - aux_vars - model_vars))
    if not (model_vars <= guide_vars | enum_vars):
        warnings.warn("Found vars in model but not guide: {}".format(model_vars - guide_vars - enum_vars))

    # Check shapes agree.
    for name in model_vars & guide_vars:
        model_site = model_trace.nodes[name]
        guide_site = guide_trace.nodes[name]
        model_fn = model_site["fn"]
        guide_fn = guide_site["fn"]

        if hasattr(model_fn, "event_dim") and hasattr(guide_fn, "event_dim"):
            if model_fn.event_dim != guide_fn.event_dim:
                raise ValueError("Model and guide event_dims disagree at site '{}': {} vs {}".format(
                    name, model_fn.event_dim, guide_fn.event_dim))

        if hasattr(model_fn, "shape") and hasattr(guide_fn, "shape"):
            model_shape = model_fn.shape(*model_site["args"], **model_site["kwargs"])
            guide_shape = guide_fn.shape(*guide_site["args"], **guide_site["kwargs"])
            if model_shape == guide_shape:
                continue

            # Allow broadcasting outside of max_plate_nesting.
            model_shape = _trim_outside_plate_nesting(model_shape, model_fn, max_plate_nesting)
            guide_shape = _trim_outside_plate_nesting(guide_shape, guide_fn, max_plate_nesting)
            if model_shape == guide_shape:
                continue
            # Compare right-aligned; a dim missing on one side is treated as size 1.
            for model_size, guide_size in zip_longest(reversed(model_shape), reversed(guide_shape), fillvalue=1):
                if model_size != guide_size:
                    raise ValueError("Model and guide shapes disagree at site '{}': {} vs {}".format(
                        name, model_shape, guide_shape))

    # Check subsample sites introduced by plate.
    model_vars = {name for name, site in model_trace.nodes.items()
                  if site["type"] == "sample" and not site["is_observed"]
                  if type(site["fn"]).__name__ == "_Subsample"}
    guide_vars = {name for name, site in guide_trace.nodes.items()
                  if site["type"] == "sample"
                  if type(site["fn"]).__name__ == "_Subsample"}
    if not (guide_vars <= model_vars):
        warnings.warn("Found plate statements in guide but not model: {}".format(guide_vars - model_vars))