in fx/nnc_compile.py [0:0]
def gen_binary_nnc(op):
    """Build an NNC lowering for the elementwise binary operator ``op``.

    ``op`` is called with two NNC expressions (the left and right operand
    values for one output element) and must return the combined expression.

    Returns ``gen_op_nnc(inp_shapes, args)``, which returns an index function
    ``f(*idxs)`` that produces the expression for the output element at
    ``idxs``. ``inp_shapes`` holds one ``(shape, dtype)`` pair per argument.
    ``args`` holds the two operands. Each is either an NNC object
    (``te.Placeholder`` / ``te.Tensor``) or a value passed through
    ``to_expr``.

    NOTE(review): ``te`` / ``to_expr`` are presumably the ``torch._C._te``
    bindings imported elsewhere in this file, and ``f`` is presumably handed
    to a ``te.Compute``-style builder -- confirm against the callers.
    """

    def is_nnc_obj(x):
        # True for NNC buffers/tensors that must be indexed via ``.load``.
        return isinstance(x, (te.Placeholder, te.Tensor))

    def index_or_broadcast(shape, idxs):
        """Map output indices ``idxs`` to indices into an operand of ``shape``.

        Follows PyTorch/NumPy broadcasting:
        - Dimensions are aligned from the right (trailing dims).
        - Leading output indices beyond the operand's rank are dropped.
        - A size-1 dimension is always indexed at 0.

        Raises:
            ValueError: if the operand has more dimensions than the output,
                which no broadcast can satisfy.
        """
        offset = len(idxs) - len(shape)
        if offset < 0:
            raise ValueError(
                f"operand of rank {len(shape)} cannot be broadcast to an "
                f"output of rank {len(idxs)}"
            )
        return [
            to_expr(0) if size == 1 else idxs[offset + dim]
            for dim, size in enumerate(shape)
        ]

    def gen_op_nnc(inp_shapes, args):
        if is_nnc_obj(args[0]) and is_nnc_obj(args[1]):
            # Tensor (op) tensor: each side may need broadcasting to the
            # output index space.
            A_shape, _ = inp_shapes[0]
            B_shape, _ = inp_shapes[1]
            A, B = args

            def f(*idxs):
                return op(
                    A.load(index_or_broadcast(A_shape, idxs)),
                    B.load(index_or_broadcast(B_shape, idxs)),
                )

            return f

        # Exactly one side is expected to be an NNC object from here on; the
        # other is converted to a constant expression.
        # NOTE(review): the tensor side is loaded with the output indices
        # unchanged, i.e. its shape is assumed to equal the output shape.
        if is_nnc_obj(args[0]):

            def f(*idxs):
                return op(args[0].load(idxs), to_expr(args[1]))

            return f

        def f(*idxs):
            return op(to_expr(args[0]), args[1].load(idxs))

        return f

    return gen_op_nnc