def node_type_define()

in Sources/x10/swift_bindings/generate_ops.py [0:0]


def node_type_define(op):
  tensor_args = []
  attr_args = []
  has_tensor_list_arg = False
  for arg in op["args"]:
    if arg[1] == "Tensor":
      if has_tensor_list_arg:
        raise ValueError("[Tensor] must be the last argument")
      tensor_args.append(arg)
    elif arg[1] == "[Tensor]":
      if has_tensor_list_arg:
        raise ValueError("[Tensor] must be the last argument")
      tensor_args.append(arg)
      has_tensor_list_arg = True
    else: attr_args.append(arg)
  def format_pretty_print(arg):
    """Return a generated C++ statement that prints attribute `arg` via OpFieldToString.

    `arg` is an (name, type, ...) entry from op["args"]; only its name is used.
    The implicit `shape` argument produces no statement (empty string).
    `ss` in the emitted code is presumably the stream of the generated
    ToString-style method -- its declaration is not visible in this chunk.
    """
    attr_name = arg[0]
    if attr_name != "shape":
      return f"    OpFieldToString(ss, \"{attr_name}\", {attr_name}_);\n"
    return ""
  def format_ctor_arg(arg):
    """Return the C++ constructor parameter declaration ("<type> <name>") for `arg`.

    `arg` is a 3-item (name, type, _) entry from op["args"]. An argument named
    `shape` is always declared as `xla::Shape`, regardless of its type. Known
    types map through the table below. Any other type falls back to
    `builtin_types[type][2]`, a module-level table not visible in this chunk.

    Raises:
      ValueError: if the type is neither in the table nor in builtin_types.
    """
    name, stype, _ = arg
    if name == "shape":
      return f"xla::Shape {name}"
    # The fixed table is consulted before builtin_types, matching the original
    # if-chain's precedence.
    cpp_param_types = {
        "Tensor": "const Value&",
        "[Tensor]": "absl::Span<const Value>",
        "Int64": "xla::int64",
        "Bool": "bool",
        "Float": "float",
        "[Int64]": "std::vector<xla::int64>",
        "ScalarType?": "c10::optional<at::ScalarType>",
        "ScalarType": "at::ScalarType",
        "AnyScalar": "at::Scalar",
    }
    if stype in cpp_param_types:
      return f"{cpp_param_types[stype]} {name}"
    if stype in builtin_types:
      return f"{builtin_types[stype][2]} {name}"
    raise ValueError(f"Problem: no such type: {stype}")
  lower_arg_i = 0
  def format_lower_arg(arg):
    nonlocal lower_arg_i
    name, stype, _ = arg
    if name == "shape":
      return "shape()"
    if stype == "Tensor":
      i = lower_arg_i
      lower_arg_i += 1
      return "loctx->GetOutputOp(operand(" + str(i) + "))"
    if stype == "[Tensor]":
      i = lower_arg_i
      lower_arg_i += 1
      return "GetArrayOperands(loctx, operands(), " + str(i) + ")"
    return f"{name}_"
  clone_arg_i = 0
  def format_clone_arg(arg):
    nonlocal clone_arg_i
    name, stype, _ = arg
    if name == "shape":
      return "shape()"
    if stype == "Tensor":
      i = clone_arg_i
      clone_arg_i += 1
      return "operands.at(" + str(i) + ")"
    if stype == "[Tensor]":
      i = clone_arg_i
      clone_arg_i += 1
      return "operands.subspan(" + str(i) + ")"
    return f"{name}_"
  def format_attr_define(arg):
    """Return the C++ member-field declaration line for attribute `arg`.

    `arg` is a 3-item (name, type, _) entry. The result has the form
    "  <cpp type> <name>_;\\n". An argument named `shape` yields "" because no
    field is declared for it here (the generated lowering/clone code uses
    shape() instead). Types outside the fixed table fall back to
    `builtin_types[type][2]`, a module-level table not visible in this chunk.

    Raises:
      ValueError: if the type is neither in the table nor in builtin_types.
    """
    name, stype, _ = arg
    if name == "shape":
      return ""
    cpp_field_types = {
        "Int64": "xla::int64",
        "Bool": "bool",
        "Float": "float",
        "ScalarType?": "c10::optional<at::ScalarType>",
        "ScalarType": "at::ScalarType",
        "AnyScalar": "at::Scalar",
        "[Int64]": "std::vector<xla::int64>",
    }
    if stype in cpp_field_types:
      cpp_type = cpp_field_types[stype]
    elif stype in builtin_types:
      cpp_type = builtin_types[stype][2]
    else:
      raise ValueError(f"Problem: no such type: {stype}")
    # Every declaration ends with a newline. The AnyScalar case previously
    # omitted it, unlike all other types, so whatever followed it would have
    # been glued onto the same line.
    return f"  {cpp_type} {name}_;\n"
  def format_attr_init(arg):
    """Return a constructor mem-initializer moving parameter `arg[0]` into field `arg[0]_`.

    The result begins with ",\\n" so that successive initializers can simply be
    concatenated after a preceding initializer.
    """
    return ",\n        {0}_(std::move({0}))".format(arg[0])

  shape_fn = None # f"""{{}}\n#error no shape function for {op["op_node_name"]}\n"""
  def resolve_shape_fn(shape_fn):
    """Translate op["shape_fn"] into the C++ expression producing the output shape.

    - If it names a tensor argument, the result is "<that arg>.shape()" (the
      first matching entry of the enclosing `tensor_args` wins).
    - The literal "shape" maps to the `shape` argument itself.
    - Otherwise it is treated as a C++ function name, called with every
      argument name from op["args"], in order.
    """
    matching_tensor = next(
        (arg for arg in tensor_args if arg[0] == shape_fn), None)
    if matching_tensor is not None:
      return f"{matching_tensor[0]}.shape()"
    if shape_fn == "shape":
      return "shape"
    call_args = ", ".join(arg[0] for arg in op["args"])
    return f"{shape_fn}({call_args})"
  def format_shape_lower_arg(arg):
    """Return the identifier passed for `arg` in the generated shape-inference lambda.

    Tensor and [Tensor] args are replaced by their "<name>_ir" parameter
    variables. Every other argument is passed through by its plain name.
    """
    name, stype, _ = arg
    is_tensor_like = stype in ("Tensor", "[Tensor]")
    return f"{name}_ir" if is_tensor_like else name
  param_convert_i = 0
  def param_convert(arg):
    nonlocal param_convert_i
    i = param_convert_i
    param_convert_i += 1
    name, stype, _ = arg
    if stype == "[Tensor]":
      return f"       auto {name}_ir = MakeParameterList(&b, {i}, {name}, \"p{i}\");\n"
    else:
      return f"       auto {name}_ir = xla::Parameter(&b, {i}, {name}.shape(), \"p{i}\");\n"

  if "shape_fn" in op:
    shape_fn = resolve_shape_fn(op["shape_fn"])
  if shape_fn == None:
    if op["n_results"] == 1:
      shape_fn = f"""[&]() {{
       xla::XlaBuilder b("InferOutputShape");
{"".join(param_convert(arg) for arg in tensor_args)}       xla::XlaOp result = {op["lower_fn"]}(
         {", ".join(format_shape_lower_arg(arg) for arg in op["args"])});
       return XlaHelpers::ShapeOfXlaOp(result);
     }}"""
    else:
      shape_fn = f"""[&]() {{
       xla::XlaBuilder b("InferOutputShape");
{"".join(param_convert(arg) for arg in tensor_args)}       auto results = {op["lower_fn"]}(
         {", ".join(format_shape_lower_arg(arg) for arg in op["args"])});
       return ShapeOfXlaOpList(results);
     }}"""
  num_outputs = op["n_results"]
  ctx = []
  if "needs_lowering_context" in [i[0] for i in op["extras"]]:
    ctx = ["loctx"]
  tensors_ctor = f"""{{{", ".join(arg[0] for arg in tensor_args if arg[1] == "Tensor")}}}"""
  if has_tensor_list_arg:
    if len(tensor_args) == 1:
      tensors_ctor = tensor_args[-1][0]
    else:
      tensors_ctor = f"""TensorArgsConcat({tensors_ctor}, {tensor_args[-1][0]})"""
  lower_body = None
  if num_outputs == 1:
    lower_body = f"""
    xla::XlaOp result = {op["lower_fn"]}(
        {", ".join([format_lower_arg(arg) for arg in op["args"]] + ctx)});
    return ReturnOp(result, loctx);
  """
  else:
    lower_body = f"""
    auto result = {op["lower_fn"]}(
        {", ".join([format_lower_arg(arg) for arg in op["args"]] + ctx)});
    return ReturnOps(result, loctx);
  """

  return f"""