in src/beanmachine/ppl/compiler/fix_multiary_ops.py [0:0]
def _needs_fixing(self, n: bn.BMGNode) -> bool:
    """Decide whether ``n`` is the node at which a chain of binary
    operators should be collapsed into one multiary operator.

    The answer is True only when all of the following hold:

    * ``n`` is an instance of ``self._operator`` with exactly two inputs.
    * ``self._single_output_is_operator(n)`` is false. Per the rationale
      in the comments below, that means ``n`` has more than one output,
      or its single output is NOT the same operation.
    * ``self._addition_single_output_is_operator`` accepts at least one
      of the two inputs, i.e. an input is itself a binary operator with
      only one output.
    """
    # Taking addition as the example operator, the shape we want to find is
    #
    #   A   B
    #    \ /
    #     +   C
    #      \ /
    #       +   D
    #        \ /
    #         +
    #
    # so that it can be rewritten as
    #
    #   A B C D
    #   \ | | /
    #     sum
    #
    # Rationale for the conditions:
    #
    # * The (A + B) + C node above should not be fixed. It has exactly one
    #   output and that output is an addition, so the output is itself a
    #   candidate for fixing. We skip this node and fix the output
    #   instead, rather than doing work that would be thrown away.
    #
    # * An addition with two or more outputs has been deduplicated.
    #   Eliminating it would cause the shared work to be done twice.
    #   Given
    #
    #       A   B
    #        \ /
    #         +   C
    #          \ /
    #       E   +   D
    #        \ / \ /
    #         *   +
    #
    #   the bottom addition is NOT fixable, but the A + B + C addition is.
    #   The graph we want to end up with is
    #
    #        A  B  C
    #         \ | /
    #       E  sum  D
    #        \ / \ /
    #         *   +
    #
    #   and NOT
    #
    #        A  B  C
    #         \ | /
    #       E  sum    A B C D
    #        \ /      \ | | /
    #         *         sum
    #
    #   The metrics that matter are graph size and the amount of
    #   arithmetic performed when the graph is evaluated in BMG:
    #
    #   * Original graph: eight edges, nine nodes, three additions
    #     (t1 = A + B, t2 = t1 + C, t3 = t2 + D).
    #   * Desired graph: seven edges, eight nodes, three additions
    #     (t1 = sum(A, B, C) costs two, t2 = t1 + D costs one more).
    #   * Bad graph: nine edges, eight nodes, five additions
    #     (sum(A, B, C) costs two and sum(A, B, C, D) costs three).
    #
    #   The desired graph has fewer edges and nodes than the original
    #   without doing any more math; the bad graph is worse than the
    #   desired graph in every respect.

    # Only two-input instances of the target operator are candidates.
    if not isinstance(n, self._operator):
        return False
    if len(n.inputs) != 2:
        return False

    # If the sole consumer of n is the same operation, that consumer is
    # the one that gets fixed; leave n alone.
    if self._single_output_is_operator(n):
        return False

    # There must be something to merge: at least one operand has to be a
    # collapsible operator. The second operand is examined only when the
    # first one does not qualify.
    is_mergeable_input = self._addition_single_output_is_operator
    return is_mergeable_input(n.inputs[0]) or is_mergeable_input(n.inputs[1])