in nestedtensor/csrc/scripts/binaryops.py [0:0]
def create_template_map(ops):
    """Build a mapping from binary-op registration names to rendered templates.

    Each element of ``ops`` is indexed as ``op[0]`` and ``op[1]``:

    * ``op[1]`` is a registration string whose text before the first ``"("``
      is either ``name`` or ``name.variant`` (for example ``add.Tensor(...``).
    * ``op[0]`` is only searched for the literal ``"Scalar & alpha"``
      (presumably the C++ signature of the op -- verify against the caller).

    An op is mapped when its ``name`` equals one of the names produced by
    ``get_binary_functions()`` (regular form) or one of those names followed
    by ``"_"`` (in-place form).  The template is chosen from the variant:

    * no variant  -> ``BINARY_OP_DEFAULT`` / ``BINARY_INPLACE_OP_DEFAULT``
    * ``Tensor``  -> ``BINARY_OP_SCALAR`` / ``BINARY_INPLACE_OP_SCALAR`` when
      ``op[0]`` contains ``"Scalar & alpha"``, otherwise ``BINARY_OP`` /
      ``BINARY_INPLACE_OP``
    * ``out``     -> ``BINARY_OUT_OP_SCALAR`` or ``BINARY_OUT_OP`` (regular
      form only; the in-place form has no ``out`` template)

    Ops with any other name or variant are left out of the result.

    Args:
        ops: iterable of indexable entries as described above.

    Returns:
        dict keyed by the registration name (including any ``.variant``
        suffix) whose values are the selected template formatted with
        ``op=<binary function name>``.

    Raises:
        ValueError: if an ``op[1]`` contains no ``"("``.
    """
    # Loop-invariant, so fetched once instead of once per op.
    # NOTE(review): assumes get_binary_functions() returns the same names on
    # every call and has no side effects -- confirm. tuple() keeps the names
    # re-iterable even if the helper ever hands back a one-shot iterator.
    binary_functions = tuple(get_binary_functions())

    template_map = {}
    for op in ops:
        # Only the text before the first "(" is needed; the two-target
        # unpack still rejects strings that contain no "(" at all.
        op_reg, _ = op[1].split("(", 1)
        if "." in op_reg:
            op_name, variant = op_reg.split(".", 1)
        else:
            op_name, variant = op_reg, None

        for b in binary_functions:
            template = None
            if op_name == b:
                if variant is None:
                    template = BINARY_OP_DEFAULT
                elif variant == "Tensor":
                    if "Scalar & alpha" in op[0]:
                        template = BINARY_OP_SCALAR
                    else:
                        template = BINARY_OP
                elif variant == "out":
                    if "Scalar & alpha" in op[0]:
                        template = BINARY_OUT_OP_SCALAR
                    else:
                        template = BINARY_OUT_OP
            elif op_name == b + "_":
                if variant is None:
                    template = BINARY_INPLACE_OP_DEFAULT
                elif variant == "Tensor":
                    if "Scalar & alpha" in op[0]:
                        template = BINARY_INPLACE_OP_SCALAR
                    else:
                        template = BINARY_INPLACE_OP
            # Unsupported name/variant combinations leave template unset
            # and therefore add nothing to the map.
            if template is not None:
                template_map[op_reg] = template.format(op=b)
    return template_map