def swift_wrapper_define()

in Sources/x10/swift_bindings/generate_ops.py [0:0]


def swift_wrapper_define(op):
  args = op["args"]
  results = op["results"]

  def format_swift_arg(arg):
    """Render one Swift parameter declaration for the wrapper signature.

    `arg` is (name, stype, (is_explicit, full_stype)). Explicitly labelled
    parameters keep their label; all others get Swift's `_` (unlabelled)
    external name.
    """
    arg_name, _, (has_label, swift_type) = arg
    label_prefix = "" if has_label else "_ "
    return f"{label_prefix}{arg_name}: {swift_type}"

  def format_result_type(r):
    """Render one element of a multi-result (tuple) Swift return type.

    `r` is (label, swift_type). An unlabelled element is just the type.
    A labelled element must be written `label: Type`, because Swift
    tuple-type syntax requires the colon. This matches the `label: value`
    form that format_tuple_packing emits for the returned tuple value.
    """
    label, swift_type = r[0], r[1]
    return f"{label}: {swift_type}" if label else swift_type

  def format_tuple_packing(ridx):
    """Render the Swift expression for result number `ridx` of a tuple return.

    Reads `results` from the enclosing scope. With exactly two results, the
    C tuple fields are named `x`/`y`; otherwise they are `v0`, `v1`, ....
    A labelled result is prefixed with `label: `.
    """
    label = results[ridx][0]
    field = f"v{ridx}" if len(results) != 2 else "xy"[ridx]
    packed = f"Tensor(_xlaHandle: tuple_output.{field})"
    if not label:
      return packed
    return f"{label}: {packed}"

  generics = format_args((f"\n    {k}: {v}" for k, v in op["generics"].items()),
                         comma=",",
                         ending="\n  ")
  if generics:
    generics = f"<{generics}>"
  args_gen = format_args(("\n    " + format_swift_arg(arg) for arg in args),
                         comma=",",
                         ending="\n  ")
  results_gen = results[0][1] if len(
      results
  ) == 1 else f"({format_args((format_result_type(r) for r in results))})"

  def format_defer(arg):
    """Emit a `defer { _fixLifetime(name) }` line for Tensor arguments.

    Non-Tensor arguments produce an empty string. The leading newline and
    4-space indent position the statement inside the generated Swift body.
    """
    arg_name, swift_kind, (_, _) = arg
    if swift_kind != "Tensor":
      return ""
    return f"\n    defer {{ _fixLifetime({arg_name}) }}"

  defers = format_args((format_defer(arg) for arg in args), comma="")

  def format_arg_ref(arg):
    """Render how argument `arg` is passed to the XLATensor_* C call.

    Tensor arguments pass their `.xlaHandle`, AnyScalar arguments pass their
    `.xlaScalar`, and every other argument is passed through by name
    unchanged.
    """
    arg_name, swift_kind, (_, _) = arg
    accessor = {"Tensor": "xlaHandle", "AnyScalar": "xlaScalar"}.get(swift_kind)
    if accessor is None:
      return arg_name
    return f"{arg_name}.{accessor}"

  body = ""
  last_tensor = None
  for arg in args:
    if arg[1] == "Tensor":
      if last_tensor:
        body += (f"    checkSameDevice({last_tensor[0]}.device, "
                 f"{arg[0]}.device)\n")
        if last_tensor[2][1] == arg[2][1]:
          body += f"    checkSamePrecision({last_tensor[0]}, {arg[0]})\n"
      else:
        last_tensor = arg
  withCounter = 0
  for arg in args:
    if arg[1][0] == "[":  # is array type.
      withCounter += 1
      body += f"""{"  " * withCounter}  return {arg[0]}.withArrayRef {{ {arg[0]} in\n"""
  dispatch = f"""XLATensor_{op["c_name"]}({format_args(format_arg_ref(arg) for arg in args)})"""
  if len(results) == 1:
    body += f"""{"  " * withCounter}    return Tensor(_xlaHandle: {dispatch})
"""
  else:
    body += f"""{"  " * withCounter}    let tuple_output = {dispatch}
"""
    body += f"""{"  " * withCounter}    return ({format_args(format_tuple_packing(ridx) for ridx in range(len(results)))})
"""

  for withCounter in range(withCounter, 0, -1):
    body += f"""{"  " * withCounter}  }}\n"""

  protection = "" if op[
      "protection"] == "internal" else f"""{op["protection"]} """
  return f"""