def __init__()

in pyro/contrib/easyguide/easyguide.py [0:0]


    def __init__(self, guide, sites):
        """Set up a group that jointly handles several sample sites.

        The sites' values are laid out as one flattened, concatenated vector
        per element of a shared (common) batch shape. This constructor works
        out which plates are shared by every site and the remaining per-site
        batch shape of each site.

        :param guide: The guide owning this group. Only a weak reference is
            kept, so the group does not keep the guide alive.
        :param list sites: A non-empty list of site dicts. Each must provide
            ``"name"``, ``"fn"`` (an object exposing ``batch_shape`` and
            ``event_shape``) and ``"cond_indep_stack"`` (frames exposing
            ``dim`` and ``vectorized``). Presumably these are sample-site
            messages recorded from a prototype trace of the model.
        :raises AssertionError: if ``sites`` is not a non-empty list.
        :raises ValueError: if, after removing common dims and leading
            singleton dims, some site still has more batch dims than fit to
            the right of the rightmost common plate.
        """
        assert isinstance(sites, list)
        assert sites
        self._guide = weakref.ref(guide)
        self.prototype_sites = sites
        self._site_sizes = {}
        self._site_batch_shapes = {}

        # Only vectorized frames that *every* site belongs to count as common:
        # the group as a whole can be subsampled along a frame only when each
        # of its sites can be subsampled along it.
        vectorized_frame_sets = [
            frozenset(
                frame for frame in proto["cond_indep_stack"] if frame.vectorized
            )
            for proto in sites
        ]
        self.common_frames = frozenset.intersection(*vectorized_frame_sets)
        # With no common frames the bound is -inf, so the layout check below
        # can never fire.
        rightmost_common_dim = max(
            (frame.dim for frame in self.common_frames), default=-float('inf'))

        for proto in sites:
            event_numel = torch.Size(proto["fn"].event_shape).numel()

            # Collapse the common dims to size 1: they belong to the shared
            # batch shape, which may change between SVI steps when subsampling,
            # whereas the per-site part must keep a constant size.
            dims = list(proto["fn"].batch_shape)
            for frame in self.common_frames:
                dims[frame.dim] = 1

            # Drop leading singleton dims so only the per-site part remains.
            first_kept = 0
            while first_kept < len(dims) and dims[first_kept] == 1:
                first_kept += 1
            dims = dims[first_kept:]

            name = proto["name"]
            if len(dims) > -rightmost_common_dim:
                raise ValueError(
                    "Group expects all per-site plates to be right of all common plates, "
                    "but found a per-site plate {} on left at site {}"
                    .format(-len(dims), repr(name)))

            per_site_shape = torch.Size(dims)
            self._site_batch_shapes[name] = per_site_shape
            self._site_sizes[name] = per_site_shape.numel() * event_numel

        # All sites are flattened and concatenated into a single event vector.
        self.event_shape = torch.Size([sum(self._site_sizes.values())])