def handle_sample()

in src/beanmachine/ppl/compiler/runtime.py [0:0]


    def handle_sample(self, operand: Any) -> bn.SampleNode:  # noqa
        """Add a sample node for the value a random variable returned.

        Invoked while the lifted program executes, each time a model function
        decorated with @bm.random_variable returns. The returned value must be
        either a distribution node that is already in the graph, or a torch
        distribution of a supported type. In the latter case the matching
        graph distribution node is built via the corresponding ``handle_*``
        method. Either way, a sample node over that distribution is added to
        the graph and returned.

        Raises:
            TypeError: if ``operand`` is not a distribution at all, or is a
                torch distribution type that is not supported here.
        """
        # A graph distribution node needs no conversion; sample it directly.
        if isinstance(operand, bn.DistributionNode):
            return self._bmg.add_sample(operand)

        if not isinstance(operand, torch.distributions.Distribution):
            # TODO: Better error
            raise TypeError("A random_variable is required to return a distribution.")

        # Ordered (torch type, builder) pairs, kept in alphabetical order.
        # Matching uses isinstance and the first hit wins. Chi2 must therefore
        # precede Gamma, because torch's Chi2 is a subclass of Gamma;
        # alphabetical order happens to guarantee this.
        builders = (
            (dist.Bernoulli, lambda d: self.handle_bernoulli(d.probs)),
            (dist.Beta, lambda d: self.handle_beta(d.concentration1, d.concentration0)),
            (dist.Binomial, lambda d: self.handle_binomial(d.total_count, d.probs)),
            (dist.Categorical, lambda d: self.handle_categorical(d.probs)),
            (dist.Chi2, lambda d: self.handle_chi2(d.df)),
            (dist.Dirichlet, lambda d: self.handle_dirichlet(d.concentration)),
            (dist.Gamma, lambda d: self.handle_gamma(d.concentration, d.rate)),
            (dist.HalfCauchy, lambda d: self.handle_halfcauchy(d.scale)),
            (dist.HalfNormal, lambda d: self.handle_halfnormal(d.scale)),
            (dist.Normal, lambda d: self.handle_normal(d.mean, d.stddev)),
            (dist.Poisson, lambda d: self.handle_poisson(d.rate)),
            (dist.StudentT, lambda d: self.handle_studentt(d.df, d.loc, d.scale)),
            (dist.Uniform, lambda d: self.handle_uniform(d.low, d.high)),
        )
        for torch_type, build in builders:
            if isinstance(operand, torch_type):
                return self._bmg.add_sample(build(operand))

        # TODO: Better error
        type_name = type(operand).__name__
        raise TypeError(f"Distribution '{type_name}' is not supported by Bean Machine Graph.")