in Sources/x10/swift_bindings/generate_ops.py [0:0]
def swift_wrapper_define(op):
args = op["args"]
results = op["results"]
def format_swift_arg(arg):
    """Render one argument as a Swift parameter declaration.

    `arg` is unpacked as ``(name, stype, (is_explicit, full_stype))``. When
    `is_explicit` is truthy the argument label is kept (``name: full_stype``);
    otherwise it is suppressed with a leading ``_``.
    """
    name, _, (is_explicit, full_stype) = arg
    return f"{name}: {full_stype}" if is_explicit else f"_ {name}: {full_stype}"
def format_result_type(r):
    """Render one result entry as an element of a Swift tuple return type.

    `r` is indexed as ``(label, swift_type)``. A truthy label produces
    ``label: swift_type``. The colon is required by Swift tuple-type syntax,
    and it mirrors `format_tuple_packing`, which builds the tuple value as
    ``label: value``. An empty or falsy label yields the bare type.
    """
    label, swift_type = r[0], r[1]
    # Previously emitted "label Type" (no colon), which is not valid Swift.
    return f"{label}: {swift_type}" if label else swift_type
def format_tuple_packing(ridx):
    """Build the Swift expression for element `ridx` of the returned tuple.

    Reads `results` from the enclosing scope. When there are exactly two
    results, the handle comes from ``tuple_output.x`` / ``tuple_output.y``;
    for any other count it comes from ``tuple_output.v<ridx>``. A truthy
    label (``results[ridx][0]``) prefixes the expression as ``label: ...``.
    """
    r = results[ridx]
    if len(results) == 2:
        tag = "xy"[ridx]
    else:
        tag = f"v{ridx}"
    value = f"Tensor(_xlaHandle: tuple_output.{tag})"
    if not r[0]:
        return value
    return f"{r[0]}: {value}"
generics = format_args((f"\n {k}: {v}" for k, v in op["generics"].items()),
comma=",",
ending="\n ")
if generics:
generics = f"<{generics}>"
args_gen = format_args(("\n " + format_swift_arg(arg) for arg in args),
comma=",",
ending="\n ")
results_gen = results[0][1] if len(
results
) == 1 else f"({format_args((format_result_type(r) for r in results))})"
def format_defer(arg):
    """Emit a Swift ``defer { _fixLifetime(name) }`` line for Tensor arguments.

    `arg` is unpacked as ``(name, stype, (is_explicit, full_stype))``; only
    `name` and `stype` are used. Arguments whose `stype` is not ``"Tensor"``
    produce an empty string.
    """
    name, stype, (_, _) = arg
    return f"\n defer {{ _fixLifetime({name}) }}" if stype == "Tensor" else ""
defers = format_args((format_defer(arg) for arg in args), comma="")
def format_arg_ref(arg):
    """Return how `arg` is passed in the generated ``XLATensor_<c_name>(...)`` call.

    `arg` is unpacked as ``(name, stype, (is_explicit, full_stype))``. Tensor
    arguments pass ``name.xlaHandle``, AnyScalar arguments pass
    ``name.xlaScalar``, and every other type is passed by its bare name.
    """
    name, stype, (_, _) = arg
    if stype == "AnyScalar":
        return f"{name}.xlaScalar"
    elif stype == "Tensor":
        return f"{name}.xlaHandle"
    else:
        return name
body = ""
last_tensor = None
for arg in args:
if arg[1] == "Tensor":
if last_tensor:
body += (f" checkSameDevice({last_tensor[0]}.device, "
f"{arg[0]}.device)\n")
if last_tensor[2][1] == arg[2][1]:
body += f" checkSamePrecision({last_tensor[0]}, {arg[0]})\n"
else:
last_tensor = arg
withCounter = 0
for arg in args:
if arg[1][0] == "[": # is array type.
withCounter += 1
body += f"""{" " * withCounter} return {arg[0]}.withArrayRef {{ {arg[0]} in\n"""
dispatch = f"""XLATensor_{op["c_name"]}({format_args(format_arg_ref(arg) for arg in args)})"""
if len(results) == 1:
body += f"""{" " * withCounter} return Tensor(_xlaHandle: {dispatch})
"""
else:
body += f"""{" " * withCounter} let tuple_output = {dispatch}
"""
body += f"""{" " * withCounter} return ({format_args(format_tuple_packing(ridx) for ridx in range(len(results)))})
"""
for withCounter in range(withCounter, 0, -1):
body += f"""{" " * withCounter} }}\n"""
protection = "" if op[
"protection"] == "internal" else f"""{op["protection"]} """
return f"""