def gen_binary_nnc()

in fx/nnc_compile.py [0:0]


def gen_binary_nnc(op):
    """Return an NNC lowering function for the elementwise binary operator ``op``.

    ``op`` is called with two NNC expressions and its result is returned as the
    element expression (presumably an arithmetic operator such as
    ``operator.add`` -- confirm against callers).

    The returned ``gen_op_nnc(inp_shapes, args)`` takes:
        inp_shapes: per-argument ``(shape, dtype)`` pairs; only the shapes of
            operands that are NNC tensors are read.
        args: the two operands. Each is either an NNC ``Placeholder``/``Tensor``
            or a scalar value that is converted with ``to_expr``.
    It returns ``f(*idxs)``, a function of the output indices that builds the
    expression for the output element at ``idxs``.

    Raises (from ``gen_op_nnc``):
        TypeError: if neither operand is an NNC ``Placeholder``/``Tensor``.
    """

    def is_nnc_obj(x):
        # Only NNC tensor-like operands are indexed with ``.load``; anything
        # else is treated as a scalar and lifted with ``to_expr``.
        return isinstance(x, (te.Placeholder, te.Tensor))

    def gen_op_nnc(inp_shapes, args):
        lhs, rhs = args[0], args[1]
        lhs_is_tensor = is_nnc_obj(lhs)
        rhs_is_tensor = is_nnc_obj(rhs)

        if lhs_is_tensor and rhs_is_tensor:
            A_shape, _ = inp_shapes[0]
            B_shape, _ = inp_shapes[1]

            def index_or_broadcast(shape, idxs):
                # Broadcasting aligns shapes from the trailing dimension, so an
                # operand of lower rank than the output is indexed by the
                # *last* ``len(shape)`` output indices.
                start = max(len(idxs) - len(shape), 0)
                aligned = idxs[start:]
                dims = shape[len(shape) - len(aligned):]
                # A size-1 dimension is broadcast: always read element 0.
                return [to_expr(0) if dim == 1 else idx for dim, idx in zip(dims, aligned)]

            def f(*idxs):
                return op(
                    lhs.load(index_or_broadcast(A_shape, idxs)),
                    rhs.load(index_or_broadcast(B_shape, idxs)),
                )

            return f

        if lhs_is_tensor:
            def f(*idxs):
                return op(lhs.load(idxs), to_expr(rhs))

            return f

        if rhs_is_tensor:
            def f(*idxs):
                return op(to_expr(lhs), rhs.load(idxs))

            return f

        # Previously this fell through and failed later with an opaque
        # AttributeError on ``.load``; fail early with a clear message instead.
        raise TypeError(
            "gen_binary_nnc expects at least one NNC Placeholder/Tensor operand, "
            f"got {type(lhs).__name__} and {type(rhs).__name__}"
        )

    return gen_op_nnc