in src/beanmachine/ppl/compiler/runtime.py [0:0]
def handle_sample(self, operand: Any) -> bn.SampleNode:  # noqa
    """Add a sample node for the value a random variable returned.

    As the lifted program executes, this is invoked each time a model
    function decorated with @bm.random_variable returns. The returned value
    must be either an existing graph distribution node or a supported torch
    distribution. For a torch distribution, the matching ``handle_*`` method
    first converts it into a distribution node. In both cases the node is
    passed to ``self._bmg.add_sample`` and that result is returned.

    Raises:
        TypeError: if ``operand`` is not a distribution at all, or is a torch
            distribution type with no entry in the table below.
    """
    # Already a graph distribution node: sample from it as-is.
    if isinstance(operand, bn.DistributionNode):
        return self._bmg.add_sample(operand)
    if not isinstance(operand, torch.distributions.Distribution):
        # TODO: Better error
        raise TypeError(
            "A random_variable is required to return a distribution."
        )

    # Each entry: (torch distribution class, name of the handler method on
    # self, names of the distribution attributes passed to it positionally).
    # Entries are matched in order with isinstance, so a subclass must be
    # listed before its base: torch's Chi2 derives from Gamma, hence Chi2
    # appears first.
    # TODO: Get this into alpha order
    supported = (
        (dist.Bernoulli, "handle_bernoulli", ("probs",)),
        (dist.Binomial, "handle_binomial", ("total_count", "probs")),
        (dist.Categorical, "handle_categorical", ("probs",)),
        (dist.Dirichlet, "handle_dirichlet", ("concentration",)),
        (dist.Chi2, "handle_chi2", ("df",)),
        (dist.Gamma, "handle_gamma", ("concentration", "rate")),
        (dist.HalfCauchy, "handle_halfcauchy", ("scale",)),
        (dist.Normal, "handle_normal", ("mean", "stddev")),
        (dist.HalfNormal, "handle_halfnormal", ("scale",)),
        (dist.StudentT, "handle_studentt", ("df", "loc", "scale")),
        (dist.Uniform, "handle_uniform", ("low", "high")),
        (dist.Beta, "handle_beta", ("concentration1", "concentration0")),
        (dist.Poisson, "handle_poisson", ("rate",)),
    )
    for dist_type, handler_name, arg_names in supported:
        if isinstance(operand, dist_type):
            # Resolve the handler before reading the arguments, matching the
            # evaluation order of a direct `self.handle_x(operand.a, ...)` call.
            handler = getattr(self, handler_name)
            node = handler(*[getattr(operand, name) for name in arg_names])
            return self._bmg.add_sample(node)

    # TODO: Better error
    type_name = type(operand).__name__
    raise TypeError(
        f"Distribution '{type_name}' is not supported by Bean Machine Graph."
    )