in Sources/x10/swift_bindings/generate_ops.py [0:0]
def c_function_define(op):
args = op["args"]
tensor_args = [
arg for arg in args if arg[1] == "Tensor" or arg[1] == "[Tensor]"
]
tensor_names = [arg[0] for arg in tensor_args]
first_tensor = None
if "result_dtype" in op and op["result_dtype"] in tensor_names:
first_tensor = op["result_dtype"]
if "shape_fn" in op and op["shape_fn"] in tensor_names:
first_tensor = op["shape_fn"]
if not first_tensor:
if tensor_args[0][1] == "[Tensor]":
first_tensor = f"swift_xla::FirstTensor({tensor_args[0][0]})"
else:
first_tensor = tensor_args[0][0]
def listify(l):
    """Return ``l`` unchanged when it is a list, else wrap it in a one-item list.

    ``isinstance`` is used instead of an exact ``type(...) is list`` check so
    that list subclasses are passed through as-is rather than being nested
    inside a new list.
    """
    if isinstance(l, list):
        return l
    return [l]
dtypes = (([None] * op["n_results"])
if "result_dtype" not in op else listify(op["result_dtype"]))
def format_arg_def(arg):
    """Render one op argument as a C parameter declaration (``<type> <name>``).

    ``arg`` is a 3-item sequence: the argument name, its type string (for
    example ``"Tensor"`` or ``"[Int64]"``), and a third item that is ignored
    here.

    Raises:
        ValueError: if the type string is neither handled directly nor a key
            of ``builtin_types``.
    """
    name, stype, _ = arg

    # Type strings with a fixed C spelling.
    c_types = {
        "Tensor": "OpaqueXLATensor*",
        "[Tensor]": "OpaqueXLATensorArrayRef",
        "Int64": "int64_t",
        "Float": "float",
        "Bool": "bool",
        "ScalarType?": "Optional_XLAScalarType",
        "ScalarType": "XLATensorScalarType",
        "AnyScalar": "XLAScalar",
        "[Int64]": "Int64ArrayRef",
    }
    c_type = c_types.get(stype)
    if c_type is not None:
        return f"{c_type} {name}"

    # Everything else must be described by the builtin_types table, whose
    # entries carry the C type name in slot 0.
    if stype in builtin_types:
        return f"{builtin_types[stype][0]} {name}"
    raise ValueError("problem unknown type: " + stype)
def format_arg_ref(arg):
    """Render the C++ expression used for ``arg`` in the IR node constructor.

    ``arg`` is a 3-item sequence (name, type string, ignored). The first
    matching rule wins, in this order:

    1. Tensor and tensor-list arguments refer to ``<name>_ir_value``.
    2. An argument named ``shape`` becomes a ``MakeArrayShapeFromDimensions``
       call whose element type comes from the first tensor, or from the
       argument named by the first result dtype when there is one.
    3. Types present in ``builtin_types`` use that table's formatter (slot 1).
    4. Arguments named by a ``canonicalize`` entry in ``op["extras"]`` are
       wrapped in the matching canonicalization helper.
    5. Remaining scalar/list types get their fixed conversions; anything else
       is passed through by name.
    """
    name, stype, _ = arg

    if stype in ("Tensor", "[Tensor]"):
        return f"{name}_ir_value"

    if name == "shape":
        wanted = dtypes[0]
        dtype_arg = None
        if wanted and first_tensor != wanted:
            # Deliberately no early exit: if several args carry the name,
            # the last one is the one used.
            for candidate in args:
                if candidate[0] == wanted:
                    dtype_arg = candidate
        if dtype_arg:
            element_type = (
                "swift_xla::MakeXlaPrimitiveType("
                f"{format_arg_ref(dtype_arg)}, /*device=*/nullptr)")
        else:
            element_type = f"{first_tensor}->shape().get().element_type()"
        # The first piece is a plain (non-f) string so `{}` is emitted as-is.
        return ("swift_xla::MakeArrayShapeFromDimensions(shape.slice(), {}, "
                + f"{element_type}, {first_tensor}->GetDevice().hw_type)")

    if stype in builtin_types:
        return builtin_types[stype][1](name)

    for extra in op["extras"]:
        if extra[0] != "canonicalize" or extra[1] != name:
            continue
        ref = f"{extra[2]}_ir_value"
        takes_list = stype == "[Int64]"
        if len(extra) == 4:
            # A fourth entry names a custom helper under swift_xla::ir::ops.
            helper = f"swift_xla::ir::ops::{extra[3]}"
            if takes_list:
                return f"{helper}({ref}.shape(), {name}.slice())"
            return f"{helper}({ref}, {name})"
        rank = f"{ref}.shape().rank()"
        if takes_list:
            return ("swift_xla::XlaHelpers::GetCanonicalDimensionIndices("
                    f"{name}.slice(), {rank})")
        return ("swift_xla::XlaHelpers::GetCanonicalDimensionIndex("
                f"{name}, {rank})")

    conversions = {
        "ScalarType?": f"{name}.value()",
        "ScalarType": f"ToScalarType({name})",
        "AnyScalar": f"atScalar({name})",
        "[Int64]": f"swift_xla::XlaHelpers::I64List({name}.slice())",
    }
    return conversions.get(stype, name)
def unpack_arg(arg):
    """Return the C++ statement declaring ``<name>_ir_value`` for ``arg``.

    Only ``Tensor`` and ``[Tensor]`` arguments produce a statement; every
    other type yields an empty string.
    """
    name, stype, _ = arg
    if stype == "Tensor":
        initializer = f"{name}->GetIrValue()"
    elif stype == "[Tensor]":
        initializer = f"swift_xla::UnpackIrValues({name})"
    else:
        return ""
    return f" auto {name}_ir_value = {initializer};\n"
node_ctor = f"""swift_xla::ir::MakeNode<swift_xla::ir::ops::{op["op_node_name"]}>({", ".join(format_arg_ref(arg) for arg in op["args"])})"""
result_type = None
if op["n_results"] == 1:
result_type = "OpaqueXLATensor*"
elif op["n_results"] == 2:
result_type = "OpaqueXLATensor_pair"
elif op["n_results"] == 3:
result_type = "OpaqueXLATensor_tuple_3"
else:
raise ValueError(
f"""{op["c_name"]} has unsupported number of return values {op["n_results"]}"""
)
def format_result(result_i=0, dtype=None):
    """Render the C++ expression that builds result number ``result_i``.

    ``dtype`` defaults to ``dtypes[result_i]`` when not given. The emitted
    ``new swift_xla::XLATensor(...)`` expression depends on it:

    * no dtype: created from the first tensor with no explicit scalar type;
    * dtype names a tensor argument: created from that tensor;
    * dtype names another argument: created from the first tensor, with that
      argument's formatted reference as the scalar type;
    * otherwise: created from the first tensor with ``at::ScalarType::<dtype>``.
    """
    dtype = dtype or dtypes[result_i]
    value = f"swift_xla::ir::Value(result_node, {result_i})"

    if not dtype:
        return f"new swift_xla::XLATensor({first_tensor}->CreateFrom({value}))"
    if dtype in tensor_names:
        return f"new swift_xla::XLATensor({dtype}->CreateFrom({value}))"

    # Deliberately no early exit: the last argument with this name is used.
    dtype_arg = None
    for candidate in args:
        if candidate[0] == dtype:
            dtype_arg = candidate

    if dtype_arg:
        scalar_type = format_arg_ref(dtype_arg)
    else:
        scalar_type = f"at::ScalarType::{dtype}"
    return (f"new swift_xla::XLATensor({first_tensor}->CreateFrom({value},"
            f" {scalar_type}))")
prelude = f"""
{result_type} XLATensor_{op["c_name"]}({", ".join(format_arg_def(arg) for arg in op["args"])}) {{
{"".join(unpack_arg(arg) for arg in op["args"])}
auto result_node = {node_ctor};"""
if op["n_results"] != 1:
tuple_names = []
if op["n_results"] == 2:
tuple_names = ["x", "y"]
else:
tuple_names = [f"v{i}" for i in range(op["n_results"])]
out = f"""{prelude}
{result_type} result;
"""
for i in range(op["n_results"]):
out += f""" result.{tuple_names[i]} = {format_result(i)};
"""
out += """ return result;
}
"""
return out
else:
return f"""{prelude}