in Sources/x10/swift_bindings/generate_ops.py
def node_type_define(op):
  tensor_args = []
  attr_args = []
  has_tensor_list_arg = False
  for arg in op["args"]:
    if arg[1] == "Tensor":
      if has_tensor_list_arg:
        raise ValueError("[Tensor] must be the last argument")
      tensor_args.append(arg)
    elif arg[1] == "[Tensor]":
      if has_tensor_list_arg:
        raise ValueError("[Tensor] must be the last argument")
      tensor_args.append(arg)
      has_tensor_list_arg = True
    else:
      attr_args.append(arg)
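
  # format_pretty_print emits one OpFieldToString call per attribute for the
  # generated ToString(); the synthetic "shape" argument is not printed.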
  def format_pretty_print(arg):
    if arg[0] == "shape":
      return ""
    return f" OpFieldToString(ss, \"{arg[0]}\", {arg[0]}_);\n"
  def format_ctor_arg(arg):
    name, stype, _ = arg
    if name == "shape":
      return f"xla::Shape {name}"
    if stype == "Tensor":
      return f"const Value& {name}"
    if stype == "[Tensor]":
      return f"absl::Span<const Value> {name}"
    if stype == "Int64":
      return f"xla::int64 {name}"
    if stype == "Bool":
      return f"bool {name}"
    if stype == "Float":
      return f"float {name}"
    if stype == "[Int64]":
      return f"std::vector<xla::int64> {name}"
    if stype == "ScalarType?":
      return f"c10::optional<at::ScalarType> {name}"
    if stype == "ScalarType":
      return f"at::ScalarType {name}"
    if stype == "AnyScalar":
      return f"at::Scalar {name}"
    if stype in builtin_types:
      return f"{builtin_types[stype][2]} {name}"
    raise ValueError(f"Problem: no such type: {stype}")
  lower_arg_i = 0
  def format_lower_arg(arg):
    nonlocal lower_arg_i
    name, stype, _ = arg
    if name == "shape":
      return "shape()"
    if stype == "Tensor":
      i = lower_arg_i
      lower_arg_i += 1
      return f"loctx->GetOutputOp(operand({i}))"
    if stype == "[Tensor]":
      i = lower_arg_i
      lower_arg_i += 1
      return f"GetArrayOperands(loctx, operands(), {i})"
    return f"{name}_"
  clone_arg_i = 0
  def format_clone_arg(arg):
    nonlocal clone_arg_i
    name, stype, _ = arg
    if name == "shape":
      return "shape()"
    if stype == "Tensor":
      i = clone_arg_i
      clone_arg_i += 1
      return f"operands.at({i})"
    if stype == "[Tensor]":
      i = clone_arg_i
      clone_arg_i += 1
      return f"operands.subspan({i})"
    return f"{name}_"
  def format_attr_define(arg):
    name, stype, _ = arg
    if name == "shape":
      return ""
    if stype == "Int64":
      return f" xla::int64 {name}_;\n"
    if stype == "Bool":
      return f" bool {name}_;\n"
    if stype == "Float":
      return f" float {name}_;\n"
    if stype == "ScalarType?":
      return f" c10::optional<at::ScalarType> {name}_;\n"
    if stype == "ScalarType":
      return f" at::ScalarType {name}_;\n"
    if stype == "AnyScalar":
      return f" at::Scalar {name}_;\n"
    if stype == "[Int64]":
      return f" std::vector<xla::int64> {name}_;\n"
    if stype in builtin_types:
      return f" {builtin_types[stype][2]} {name}_;\n"
    raise ValueError(f"Problem: no such type: {stype}")
  def format_attr_init(arg):
    return f",\n {arg[0]}_(std::move({arg[0]}))"

  shape_fn = None  # f"""{{}}\n#error no shape function for {op["op_node_name"]}\n"""
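
  # A shape_fn naming a tensor argument means "reuse that argument's shape";
  # "shape" refers to the explicit shape constructor argument; anything else
  # is emitted as a call to a C++ helper taking the op's full argument list.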
  def resolve_shape_fn(shape_fn):
    for arg in tensor_args:
      if arg[0] == shape_fn:
        return f"{arg[0]}.shape()"
    if shape_fn == "shape":
      return "shape"
    return f"""{shape_fn}({", ".join(arg[0] for arg in op["args"])})"""
  def format_shape_lower_arg(arg):
    name, stype, _ = arg
    if stype in ("Tensor", "[Tensor]"):
      return f"{name}_ir"
    return name
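
  # Bind each tensor argument to an xla::Parameter (or a parameter list for
  # [Tensor]) on the shape-inference builder, numbered in argument order.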
  param_convert_i = 0
  def param_convert(arg):
    nonlocal param_convert_i
    i = param_convert_i
    param_convert_i += 1
    name, stype, _ = arg
    if stype == "[Tensor]":
      return f" auto {name}_ir = MakeParameterList(&b, {i}, {name}, \"p{i}\");\n"
    return f" auto {name}_ir = xla::Parameter(&b, {i}, {name}.shape(), \"p{i}\");\n"
if "shape_fn" in op:
shape_fn = resolve_shape_fn(op["shape_fn"])
if shape_fn == None:
if op["n_results"] == 1:
shape_fn = f"""[&]() {{
xla::XlaBuilder b("InferOutputShape");
{"".join(param_convert(arg) for arg in tensor_args)} xla::XlaOp result = {op["lower_fn"]}(
{", ".join(format_shape_lower_arg(arg) for arg in op["args"])});
return XlaHelpers::ShapeOfXlaOp(result);
}}"""
else:
shape_fn = f"""[&]() {{
xla::XlaBuilder b("InferOutputShape");
{"".join(param_convert(arg) for arg in tensor_args)} auto results = {op["lower_fn"]}(
{", ".join(format_shape_lower_arg(arg) for arg in op["args"])});
return ShapeOfXlaOpList(results);
}}"""
  num_outputs = op["n_results"]
  ctx = []
  if "needs_lowering_context" in [i[0] for i in op["extras"]]:
    ctx = ["loctx"]
  tensors_ctor = f"""{{{", ".join(arg[0] for arg in tensor_args if arg[1] == "Tensor")}}}"""
  if has_tensor_list_arg:
    if len(tensor_args) == 1:
      tensors_ctor = tensor_args[-1][0]
    else:
      tensors_ctor = f"""TensorArgsConcat({tensors_ctor}, {tensor_args[-1][0]})"""
  lower_body = None
  if num_outputs == 1:
    lower_body = f"""
  xla::XlaOp result = {op["lower_fn"]}(
      {", ".join([format_lower_arg(arg) for arg in op["args"]] + ctx)});
  return ReturnOp(result, loctx);
"""
  else:
    lower_body = f"""
  auto result = {op["lower_fn"]}(
      {", ".join([format_lower_arg(arg) for arg in op["args"]] + ctx)});
  return ReturnOps(result, loctx);
"""
  return f"""