# Excerpt: c_function_define()
# from Sources/x10/swift_bindings/generate_ops.py [0:0]


def c_function_define(op):
  args = op["args"]
  tensor_args = [
      arg for arg in args if arg[1] == "Tensor" or arg[1] == "[Tensor]"
  ]
  tensor_names = [arg[0] for arg in tensor_args]
  first_tensor = None
  if "result_dtype" in op and op["result_dtype"] in tensor_names:
    first_tensor = op["result_dtype"]
  if "shape_fn" in op and op["shape_fn"] in tensor_names:
    first_tensor = op["shape_fn"]
  if not first_tensor:
    if tensor_args[0][1] == "[Tensor]":
      first_tensor = f"swift_xla::FirstTensor({tensor_args[0][0]})"
    else:
      first_tensor = tensor_args[0][0]

  def listify(l):
    """Return `l` unchanged if it is already a list, otherwise `[l]`.

    Used to normalize `op["result_dtype"]`, which may be a single dtype or a
    per-result list of dtypes.
    """
    # isinstance (rather than an exact `type(...) is list` check) so that a
    # list subclass is passed through instead of being wrapped in another list.
    if isinstance(l, list):
      return l
    return [l]

  dtypes = (([None] * op["n_results"])
            if "result_dtype" not in op else listify(op["result_dtype"]))

  def format_arg_def(arg):
    """Render the C parameter declaration for one op argument.

    `arg` is a `(name, stype, _)` triple and the result is `"<c_type> <name>"`.
    The Swift types listed below have fixed C types. Any other type is looked
    up in `builtin_types`, whose entry at index 0 supplies the C type.

    Raises:
      ValueError: if `stype` is neither listed below nor in `builtin_types`.
    """
    name, stype, _ = arg
    fixed_c_types = {
        "Tensor": "OpaqueXLATensor*",
        "[Tensor]": "OpaqueXLATensorArrayRef",
        "Int64": "int64_t",
        "Float": "float",
        "Bool": "bool",
        "ScalarType?": "Optional_XLAScalarType",
        "ScalarType": "XLATensorScalarType",
        "AnyScalar": "XLAScalar",
        "[Int64]": "Int64ArrayRef",
    }
    c_type = fixed_c_types.get(stype)
    if c_type is None:
      # The fixed table takes precedence; builtin_types is only the fallback.
      if stype not in builtin_types:
        raise ValueError("problem unknown type: " + stype)
      c_type = builtin_types[stype][0]
    return f"{c_type} {name}"
  def format_arg_ref(arg):
    """Render the C++ expression that passes `arg` to the IR node constructor.

    `arg` is a `(name, stype, _)` triple from the op's argument list.
    Resolution order:
      1. Tensor / [Tensor]: the `<name>_ir_value` local bound by `unpack_arg`.
      2. An argument named "shape": a `MakeArrayShapeFromDimensions(...)` call.
         Its element type comes from `first_tensor`, or from the argument named
         by `dtypes[0]` when that is set and differs from `first_tensor`.
      3. Types in `builtin_types`: that entry's formatter (index 1) applied to
         the name.
      4. A `("canonicalize", name, other_arg[, helper])` entry in
         `op["extras"]`: a canonicalized dimension index (or index list for
         [Int64]), computed via `helper` when one is given.
      5. ScalarType?, ScalarType, AnyScalar, [Int64]: a fixed conversion.
      6. Anything else: the bare argument name.
    """
    name, stype, _ = arg
    if stype == "Tensor": return name + "_ir_value"
    if stype == "[Tensor]":
      return name + "_ir_value"
    if name == "shape":
      relement_type = f"{first_tensor}->shape().get().element_type()"
      result_dtype_arg = None
      if dtypes[0] and first_tensor != dtypes[0]:
        # Loop variable is deliberately not `arg`: reusing that name would
        # rebind this function's parameter to the last element of `args`.
        # Last match wins, as before.
        for candidate in args:
          if candidate[0] == dtypes[0]:
            result_dtype_arg = candidate
      if result_dtype_arg:
        relement_type = (
            f"swift_xla::MakeXlaPrimitiveType({format_arg_ref(result_dtype_arg)},"
            f" /*device=*/nullptr)")
      # The first piece is intentionally NOT an f-string: "{}" is emitted
      # literally into the generated C++.
      return ("swift_xla::MakeArrayShapeFromDimensions(shape.slice(), {}, " +
              f"{relement_type}, "
              f"{first_tensor}->GetDevice().hw_type)")
    if stype in builtin_types:
      return builtin_types[stype][1](name)
    for extra in op["extras"]:
      if extra[0] == "canonicalize" and extra[1] == name:
        if stype == "[Int64]":
          if len(extra) == 4:
            return (f"swift_xla::ir::ops::{extra[3]}({extra[2]}_ir_value.shape(),"
                    f" {name}.slice())")
          return ("swift_xla::XlaHelpers::GetCanonicalDimensionIndices("
                  f"{name}.slice(), {extra[2]}_ir_value.shape().rank())")
        if len(extra) == 4:
          return f"swift_xla::ir::ops::{extra[3]}({extra[2]}_ir_value, {name})"
        return ("swift_xla::XlaHelpers::GetCanonicalDimensionIndex("
                f"{name}, {extra[2]}_ir_value.shape().rank())")
    if stype == "ScalarType?": return f"{name}.value()"
    if stype == "ScalarType":
      return f"ToScalarType({name})"
    if stype == "AnyScalar":
      return f"atScalar({name})"
    if stype == "[Int64]":
      return f"swift_xla::XlaHelpers::I64List({name}.slice())"
    return name
  def unpack_arg(arg):
    """Emit the C++ line binding `<name>_ir_value` for a tensor argument.

    `arg` is a `(name, stype, _)` triple. A single Tensor reads its IR value
    directly, and a [Tensor] list is unpacked. Every other argument type needs
    no local, so an empty string is returned.
    """
    name, stype, _ = arg
    if stype == "Tensor":
      source = f"{name}->GetIrValue()"
    elif stype == "[Tensor]":
      source = f"swift_xla::UnpackIrValues({name})"
    else:
      return ""
    return f"  auto {name}_ir_value = {source};\n"
  node_ctor = f"""swift_xla::ir::MakeNode<swift_xla::ir::ops::{op["op_node_name"]}>({", ".join(format_arg_ref(arg) for arg in op["args"])})"""
  result_type = None
  if op["n_results"] == 1:
    result_type = "OpaqueXLATensor*"
  elif op["n_results"] == 2:
    result_type = "OpaqueXLATensor_pair"
  elif op["n_results"] == 3:
    result_type = "OpaqueXLATensor_tuple_3"
  else:
    raise ValueError(
        f"""{op["c_name"]} has unsupported number of return values {op["n_results"]}"""
    )

  def format_result(result_i=0, dtype=None):
    """Render the C++ expression wrapping output `result_i` of `result_node`.

    A truthy `dtype` overrides `dtypes[result_i]`. The effective dtype decides
    how the new XLATensor is created:
      * falsy: `CreateFrom` on `first_tensor`;
      * the name of a tensor argument: `CreateFrom` on that tensor;
      * the name of another argument: `CreateFrom` on `first_tensor`, with that
        argument (rendered by `format_arg_ref`) as a second argument;
      * anything else: `CreateFrom` on `first_tensor` with
        `at::ScalarType::<dtype>` as a second argument.
    """
    dtype = dtype or dtypes[result_i]
    ir_value = f"swift_xla::ir::Value(result_node, {result_i})"
    if not dtype:
      return f"new swift_xla::XLATensor({first_tensor}->CreateFrom({ir_value}))"
    if dtype in tensor_names:
      return f"new swift_xla::XLATensor({dtype}->CreateFrom({ir_value}))"
    # If several arguments share the name, the last one wins.
    named = [candidate for candidate in args if candidate[0] == dtype]
    dtype_expr = (format_arg_ref(named[-1]) if named
                  else f"at::ScalarType::{dtype}")
    return (f"new swift_xla::XLATensor({first_tensor}->CreateFrom("
            f"{ir_value}, {dtype_expr}))")

  prelude = f"""
{result_type} XLATensor_{op["c_name"]}({", ".join(format_arg_def(arg) for arg in op["args"])}) {{
{"".join(unpack_arg(arg) for arg in op["args"])}
  auto result_node = {node_ctor};"""
  if op["n_results"] != 1:
    tuple_names = []
    if op["n_results"] == 2:
      tuple_names = ["x", "y"]
    else:
      tuple_names = [f"v{i}" for i in range(op["n_results"])]
    out = f"""{prelude}
  {result_type} result;
"""
    for i in range(op["n_results"]):
      out += f"""  result.{tuple_names[i]} = {format_result(i)};
"""
    out += """  return result;
}
"""
    return out
  else:
    return f"""{prelude}