in pyro/contrib/easyguide/easyguide.py [0:0]
def __init__(self, guide, sites):
    """Set up a group of sample sites that share one flattened latent vector.

    :param guide: the owning guide; only a weak reference to it is stored.
    :param list sites: non-empty list of site dicts. Each site is read for its
        ``"name"``, its ``"fn"`` (``batch_shape`` and ``event_shape``) and its
        ``"cond_indep_stack"`` (only vectorized frames are considered).

    The group's common frames are the vectorized frames shared by *every*
    site, so the group can be subsampled only along plates that all of its
    sites belong to.

    Each site's batch shape is split into two parts:

    * the common part, which can change each SVI step due to subsampling;
    * a per-site remainder, which must remain a constant size.

    The per-site remainder, multiplied by the flattened event size, gives that
    site's share of ``self.event_shape``. That shape is a 1-D size equal to
    the sum of all per-site shares.

    :raises ValueError: if some site has a per-site (non-common) plate lying
        to the left of the rightmost common plate.
    """
    assert isinstance(sites, list)
    assert sites
    # Only a weak reference to the guide is kept (presumably so the group
    # does not keep its guide alive -- NOTE(review): confirm intent).
    self._guide = weakref.ref(guide)
    self.prototype_sites = sites
    self._site_sizes = {}
    self._site_batch_shapes = {}

    # A frame is common only if it appears (vectorized) at every site.
    vectorized_frames_per_site = [
        frozenset(frame for frame in site["cond_indep_stack"] if frame.vectorized)
        for site in sites
    ]
    self.common_frames = frozenset.intersection(*vectorized_frames_per_site)

    # With no common frames, the -inf sentinel makes the limit below +inf,
    # so no site can violate it.
    if self.common_frames:
        rightmost_common_dim = max(frame.dim for frame in self.common_frames)
    else:
        rightmost_common_dim = -float('inf')
    max_per_site_ndim = -rightmost_common_dim

    for site in sites:
        fn = site["fn"]
        event_numel = torch.Size(fn.event_shape).numel()
        dims = list(fn.batch_shape)
        # Common dims belong to the shared batch shape; mask them out here.
        for frame in self.common_frames:
            dims[frame.dim] = 1
        # Drop leading singleton dims so only genuine per-site dims remain.
        start = 0
        while start < len(dims) and dims[start] == 1:
            start += 1
        dims = dims[start:]
        # Anything longer than this reaches left of the rightmost common plate.
        if len(dims) > max_per_site_ndim:
            raise ValueError(
                "Group expects all per-site plates to be right of all common plates, "
                "but found a per-site plate {} on left at site {}"
                .format(-len(dims), repr(site["name"])))
        per_site_shape = torch.Size(dims)
        name = site["name"]
        self._site_batch_shapes[name] = per_site_shape
        self._site_sizes[name] = per_site_shape.numel() * event_numel

    total_size = sum(self._site_sizes.values())
    self.event_shape = torch.Size([total_size])