def _pyro_sample()

in pyro/poutine/broadcast_messenger.py [0:0]


    def _pyro_sample(msg):
        """
        Expand the distribution at a sample site so that its batch shape
        agrees with the sizes of the plates listed in
        ``msg["cond_indep_stack"]``.

        Sites that are already done, that are not of type ``"sample"``, or
        whose ``msg["fn"]`` has no ``batch_shape`` attribute are left
        untouched. Otherwise ``msg["fn"]`` is replaced by the result of
        ``fn.expand(...)``.

        :param msg: current message at a trace site.
        :raises ValueError: if a batch dimension of the distribution has a
            size other than 1 that differs from the size of the plate
            occupying that dimension.
        """
        if msg["done"] or msg["type"] != "sample":
            return

        site_dist = msg["fn"]
        current_shape = getattr(site_dist, "batch_shape", None)
        if current_shape is None:
            # Nothing to broadcast when there is no batch shape to inspect.
            return

        # ``None`` marks a dimension whose size is still unconstrained;
        # size-1 dimensions start out unconstrained.
        wanted = [size if size != 1 else None for size in current_shape]

        for frame in msg["cond_indep_stack"]:
            dim = frame.dim
            if dim is None or frame.size == -1:
                continue
            assert dim < 0
            size = frame.size

            # Grow the shape on the left until index ``dim`` exists.
            pad = -dim - len(wanted)
            if pad > 0:
                wanted = [None] * pad + wanted

            existing = wanted[dim]
            if existing is not None and existing != size:
                raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
                    frame.name, msg['name'], dim, size, existing))
            wanted[dim] = size

        # Resolve the dimensions that are still unconstrained: take the
        # distribution's own size where it has that dimension, else 1.
        # The indices visited are -(len - 1), ..., -1 and finally 0, where
        # index 0 addresses the leftmost entry.
        for idx in range(1 - len(wanted), 1):
            if wanted[idx] is not None:
                continue
            if len(current_shape) >= -idx:
                wanted[idx] = current_shape[idx]
            else:
                wanted[idx] = 1

        expanded = site_dist.expand(wanted)
        msg["fn"] = expanded
        if msg["fn"].has_rsample != site_dist.has_rsample:
            msg["fn"].has_rsample = site_dist.has_rsample  # copy custom attribute