in src/beanmachine/ppl/compiler/runtime.py [0:0]
def _rv_to_node(self, rv: RVIdentifier) -> BMGNode:
    key = MemoizationKey(rv.wrapper, rv.arguments)
    if key not in self.rv_map:
        if key in self.in_flight:
            # TODO: Better error message
            raise RecursionError()
        self.in_flight.add(key)
        try:
            # Under what circumstances does a random variable NOT have source code?
            # When it is nested inside another rv that has already been compiled!
            # See the note in _handle_ordinary_call for details.
            if _has_source_code(rv.function):
                rewritten_function = self._function_to_bmg_function(rv.function)
            else:
                rewritten_function = rv.function

            # Here we deal with an issue caused by how Python produces the source
            # code of a function.
            #
            # We started with a function that produced a random variable when
            # called, and then we made a transformation based on the *source code*
            # of that original function. The *source code* of that original function
            # might or might not have been decorated with a random_variable or
            # functional decorator. For example, if we have:
            #
            # @random_variable
            # def foo():
            #     return Normal(0., 1.)
            #
            # and we have a query on foo(), then that is the exact code that
            # we rewrite, and therefore the rewritten function that comes back
            # is *also* run through the random_variable decorator. But if instead
            # we have
            #
            # def foo():
            #     return Normal(0., 1.)
            #
            # bar = random_variable(foo)
            #
            # and a query on bar(), then when we ask Python for the source code of
            # bar, it hands us back the *undecorated* source code for foo, and
            # therefore the rewriter produces an undecorated rewritten function.
            # (See the standalone sketch after this method for an illustration.)
            #
            # How can we tell which situation we're in? Well, if we're in the first
            # situation then when we call the rewritten function, we'll get back an
            # RVID, and if we're in the second situation, we will not.

            value = rewritten_function(*rv.arguments)
            if isinstance(value, RVIdentifier):
                # We have a rewritten function with a decorator already applied.
                # Therefore the rewritten form of the *undecorated* function is
                # stored in the rv. Call *that* function with the given arguments.
                value = value.function(*rv.arguments)

            # We now have the value returned by the undecorated random variable
            # regardless of whether the source code was decorated or not.
            #
            # If we are calling a random_variable then we must have gotten
            # back a distribution. This is the first time we have called this
            # rv with these arguments -- because we had a cache miss -- and
            # therefore we should generate a new sample node. If by contrast
            # we are calling a functional then we check below that we got
            # back either a graph node or a tensor that we can make into a constant.
            if rv.is_random_variable:
                value = self.handle_sample(value)
        finally:
            self.in_flight.remove(key)
        if isinstance(value, torch.Tensor):
            value = self._bmg.add_constant_tensor(value)
        if not isinstance(value, BMGNode):
            raise TypeError("A functional must return a tensor.")
        self.rv_map[key] = value
        return value
    return self.rv_map[key]
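
The long comment in the method turns on a detail of how Python reports source code: inspect.getsource on a function defined with decorator syntax includes the decorator line, while a function decorated by assignment shows no trace of the decoration. The sketch below is not part of runtime.py; identity_decorator, decorated_foo, and plain_foo are hypothetical stand-ins (random_variable itself returns a wrapper object, but only the getsource behavior matters here).

# Minimal standalone sketch of the inspect.getsource behavior described in the
# comment above. All names here are hypothetical stand-ins, not beanmachine code.
import inspect


def identity_decorator(f):
    return f


@identity_decorator
def decorated_foo():
    return 1


def plain_foo():
    return 1


bar = identity_decorator(plain_foo)

# getsource on a function defined with decorator syntax includes the
# "@identity_decorator" line, so a source-to-source rewrite of this text
# would reproduce the decoration.
print(inspect.getsource(decorated_foo))

# getsource on bar returns only the undecorated body of plain_foo; the
# decoration applied by assignment leaves no trace in the source, so a
# rewrite of this text comes back undecorated.
print(inspect.getsource(bar))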
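
The control flow of _rv_to_node is a memoize-with-recursion-guard pattern: results are cached in rv_map, and in_flight detects a random variable that ends up depending on itself with the same arguments. The sketch below shows that pattern in isolation; it is not part of runtime.py, and the names Memoizer, _cache, and _in_flight are hypothetical.

# Standalone sketch of the cache-plus-in-flight pattern used by _rv_to_node.
# All names here are hypothetical and only for illustration.
from typing import Callable, Dict, Set, Tuple


class Memoizer:
    def __init__(self) -> None:
        self._cache: Dict[Tuple, object] = {}
        self._in_flight: Set[Tuple] = set()

    def call(self, fn: Callable, *args) -> object:
        key = (fn, args)
        if key in self._cache:
            return self._cache[key]
        if key in self._in_flight:
            # Re-entering the same key before it has finished means the
            # computation depends on itself: report that rather than loop.
            raise RecursionError(f"cyclic dependency on {key}")
        self._in_flight.add(key)
        try:
            value = fn(*args)
        finally:
            # Always clear the guard, even if fn raised.
            self._in_flight.discard(key)
        self._cache[key] = value
        return value


# Usage: repeated calls with the same arguments are served from the cache.
m = Memoizer()
calls = []


def square(n: int) -> int:
    calls.append(n)
    return n * n


assert m.call(square, 3) == 9
assert m.call(square, 3) == 9
assert calls == [3]  # computed only once


# A self-dependent computation trips the in-flight guard instead of
# overflowing the call stack.
def cyclic():
    return m.call(cyclic)


try:
    m.call(cyclic)
except RecursionError:
    pass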