def _needs_fixing()

in src/beanmachine/ppl/compiler/fix_multiary_ops.py [0:0]


    def _needs_fixing(self, n: bn.BMGNode) -> bool:
        """Decide whether ``n`` heads a chain of binary operators worth fusing.

        ``n`` qualifies only when all of the following hold, checked in order:

        * ``n`` is an instance of ``self._operator`` and has exactly two inputs;
        * ``self._single_output_is_operator(n)`` is false, i.e. per the design
          notes below, ``n`` does not feed exactly one node of the same operator;
        * ``self._addition_single_output_is_operator`` holds for at least one of
          the two inputs. Per the design notes, that means the input is itself
          such an operator node with only one output.

        Returns ``False`` as soon as one of the first two bullets fails;
        otherwise returns the result of the input checks.
        """
        # Design notes, using addition as the operator. We look for shapes like
        #
        #  A   B
        #   \ /
        #    +   C
        #     \ /
        #      +   D
        #       \ /
        #        +
        #
        # and want to turn them into
        #
        #  A  B C  D
        #   \ | | /
        #     sum
        #
        # Skipping a node whose only output is the same operator:
        #   Take the (A + B) + C node above. It has exactly one output, and that
        #   output is an addition. The consumer is itself a candidate for fixing,
        #   so we skip this node and fix the consumer instead. Fixing this node
        #   first would be work we would just throw away.
        #
        # Refusing to fold inputs that have several outputs:
        #   An addition with two or more outputs has been deduplicated. Folding it
        #   into a consumer would make the shared work happen twice. Given
        #
        #  A   B
        #   \ /
        #    +   C
        #     \ /
        #  E   +   D
        #   \ / \ /
        #    *   +
        #
        #   the bottom addition is NOT fixable, but the A + B + C addition is.
        #   The desired result is
        #
        #   A  B  C
        #    \ | /
        #  E  sum  D
        #   \ / \ /
        #    *   +
        #
        #   and NOT
        #
        #   A  B  C
        #    \ | /
        #  E  sum    A  B C  D
        #   \ /       \ | | /
        #    *          sum
        #
        #   Our metrics are graph size and the amount of arithmetic BMG performs
        #   when evaluating the graph:
        #   * Original: nine nodes, eight edges, three additions
        #     (t1 = A + B, t2 = t1 + C, t3 = t2 + D).
        #   * Desired: eight nodes, seven edges, three additions
        #     (sum(A, B, C) costs two, and t2 = t1 + D costs one more).
        #   * Bad: eight nodes, nine edges, five additions
        #     (sum(A, B, C) costs two, and sum(A, B, C, D) costs three).
        #   The desired graph is smaller and does no extra math. The bad graph is
        #   worse on every metric.
        if not isinstance(n, self._operator):
            return False
        if len(n.inputs) != 2:
            return False
        # If n's sole output is the same operator, that consumer is the better
        # candidate; fixing n now would only be redone later.
        if self._single_output_is_operator(n):
            return False
        # Fix n only if at least one operand can be folded into it. The right
        # operand is consulted only when the left one does not qualify.
        first_operand_foldable = self._addition_single_output_is_operator(n.inputs[0])
        return first_operand_foldable or self._addition_single_output_is_operator(
            n.inputs[1]
        )