void initPythonIRBindings()

in torch/csrc/jit/python/python_ir.cpp [192:1048]


// Registers the pybind11 bindings for the TorchScript IR on `module_`:
// AliasDb, Graph, Value, Block, Node, the c10::Type hierarchy, InferredType,
// Use and _ShapeComputeGraphMapping.
//
// Every lambda below is stored by pybind11 and invoked after this function
// has returned, so none of them captures anything by reference; globals such
// as `global_print_source_ranges` are referenced directly instead.
void initPythonIRBindings(PyObject* module_) {
  auto m = py::handle(module_).cast<py::module>();

  py::class_<AliasDb, std::shared_ptr<AliasDb>>(m, "AliasDb")
      .def("dump", &AliasDb::dump)
      .def("to_graphviz_str", &AliasDb::toGraphviz)
      .def(
          "may_contain_alias",
          [](AliasDb& db, Value* v1, Value* v2) {
            return db.mayContainAlias(v1, v2);
          })
      .def("__str__", &AliasDb::toString);

#define GS(name) def(#name, &Graph ::name)
  py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
      .def(py::init<>())
      .def(
          "__repr__",
          [](Graph& g) { return g.toString(global_print_source_ranges); })
      .def("str", &Graph::toString, py::arg("print_source_ranges") = true)
      .def_readonly_static(
          "global_print_source_ranges", &global_print_source_ranges)
      .def_static(
          "set_global_print_source_ranges",
          [](const bool enabled) { global_print_source_ranges = enabled; },
          py::arg("enabled") = true)
      .def(
          "alias_db",
          [](std::shared_ptr<Graph> g) {
            return std::make_shared<AliasDb>(std::move(g));
          })
      .def(
          "dump_alias_db",
          [](std::shared_ptr<Graph> g) {
            AliasDb db(std::move(g));
            db.dump();
          })
      // Serializes the graph to an ONNX ModelProto. Returns a tuple of
      // (serialized model bytes, {identifier: raw tensor bytes},
      // use_external_data_format flag reported by export_onnx).
      .def(
          "_export_onnx",
          [](const std::shared_ptr<Graph>& g,
             const std::map<std::string, at::Tensor>& initializers,
             int64_t onnx_opset_version,
             const std::unordered_map<
                 std::string,
                 std::unordered_map<int64_t, std::string>>& dynamic_axes,
             bool defer_weight_export,
             ::torch::onnx::OperatorExportTypes operator_export_type,
             bool strip_doc_string,
             bool keep_initializers_as_inputs,
             const std::map<std::string, int>& custom_opsets,
             bool add_node_names,
             const std::string& onnx_file_path,
             const NodeAttrNameMap& node_attr_to_name) {
            std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
            RawDataExportMap export_map;
            SymbolDimMap symbol_map;
            // Passed into export_onnx and then overwritten with the value it
            // returns in its result tuple.
            bool val_use_external_data_format = false;
            std::tie(
                model_proto,
                export_map,
                symbol_map,
                val_use_external_data_format) =
                export_onnx(
                    g,
                    initializers,
                    onnx_opset_version,
                    dynamic_axes,
                    defer_weight_export,
                    operator_export_type,
                    strip_doc_string,
                    keep_initializers_as_inputs,
                    custom_opsets,
                    add_node_names,
                    val_use_external_data_format,
                    onnx_file_path,
                    node_attr_to_name);
            std::unordered_map<std::string, py::bytes>
                python_serialized_export_map;
            python_serialized_export_map.reserve(export_map.size());
            for (const auto& kv : export_map) {
              const auto& t = kv.second;
              // NOTE(review): copies numel * element_size bytes starting at
              // data_ptr(), which presumes the tensor is contiguous.
              size_t copy_bytes = t.element_size() * t.numel();
              // TODO: this is an unnecessary copy. In theory we can directly
              // return the map from identifier to Tensor, but we need some API
              // in Python to get raw `bytes` containing the raw tensor data.
              python_serialized_export_map[kv.first] =
                  py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
            }
            std::string graph = serialize_model_proto_to_string(model_proto);
            return std::make_tuple(
                py::bytes(graph),
                std::move(python_serialized_export_map),
                val_use_external_data_format);
          },
          py::arg("initializers"),
          py::arg("onnx_opset_version") = 0,
          py::arg("dynamic_axes"),
          py::arg("defer_weight_export") = false,
          py::arg("operator_export_type") =
              ::torch::onnx::OperatorExportTypes::ONNX,
          py::arg("strip_doc_string") = true,
          py::arg("keep_initializers_as_inputs") = true,
          py::arg("custom_opsets"),
          py::arg("add_node_names") = true,
          py::arg("onnx_file_path") = std::string(),
          py::arg("node_attr_to_name") = NodeAttrNameMap())
      .def(
          "_pretty_print_onnx",
          [](const std::shared_ptr<Graph>& g,
             const std::map<std::string, at::Tensor>& initializers,
             int64_t onnx_opset_version,
             bool defer_weight_export,
             ::torch::onnx::OperatorExportTypes operator_export_type,
             bool google_printer,
             bool keep_initializers_as_inputs,
             const std::map<std::string, int>& custom_opsets,
             bool add_node_names) {
            return pretty_print_onnx(
                g,
                initializers,
                onnx_opset_version,
                defer_weight_export,
                operator_export_type,
                google_printer,
                keep_initializers_as_inputs,
                custom_opsets,
                add_node_names);
          },
          py::arg("initializers"),
          py::arg("onnx_opset_version") = 0,
          py::arg("defer_weight_export") = false,
          py::arg("operator_export_type") =
              ::torch::onnx::OperatorExportTypes::ONNX,
          py::arg("google_printer") = false,
          py::arg("keep_initializers_as_inputs") = true,
          py::arg("custom_opsets"),
          py::arg("add_node_names") = true)
      // The inputs/outputs/nodes iterators keep the graph alive while the
      // iterator lives (keep_alive<0, 1>). Destroying nodes during iteration
      // might still be hazardous.
      .def(
          "inputs",
          [](Graph& g) {
            return py::make_iterator(g.inputs().begin(), g.inputs().end());
          },
          py::keep_alive<0, 1>())
      .def(
          "outputs",
          [](Graph& g) {
            return py::make_iterator(g.outputs().begin(), g.outputs().end());
          },
          py::keep_alive<0, 1>())
      .def(
          "nodes",
          [](Graph& g) {
            return py::make_iterator(g.nodes().begin(), g.nodes().end());
          },
          py::keep_alive<0, 1>())
      .def(
          "findNode",
          [](Graph& g, const std::string& kind, bool recurse) {
            return findNode(g.block(), Symbol::fromQualString(kind), recurse);
          },
          "Find Node",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def(
          "findAllNodes",
          [](Graph& g, const std::string& kind, bool recurse) {
            return findAllNodes(g, Symbol::fromQualString(kind), recurse);
          },
          "Find all nodes",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def("addInput", [](Graph& g) { return g.addInput(); })
      .def("copy", [](Graph& g) { return g.copy(); })
      .GS(eraseInput)
      .GS(eraseOutput)
      .GS(registerOutput)
      // Overloads of create(kind[, inputs][, noutputs]). A Python None in
      // `inputs` arrives as a null Value*, which is rejected up front.
      .def(
          "create",
          [](Graph& g, const char* str) {
            return g.create(Symbol::fromQualString(str));
          })
      .def(
          "create",
          [](Graph& g, const char* str, size_t noutputs) {
            return g.create(Symbol::fromQualString(str), noutputs);
          })
      .def(
          "create",
          [](Graph& g, const char* str, const std::vector<Value*>& inputs) {
            TORCH_CHECK_VALUE(
                std::all_of(
                    inputs.begin(),
                    inputs.end(),
                    [](Value* v) { return (v != nullptr); }),
                "cannot pass None in inputs");
            return g.create(Symbol::fromQualString(str), inputs);
          })
      .def(
          "create",
          [](Graph& g,
             const char* str,
             const std::vector<Value*>& inputs,
             size_t noutputs) {
            TORCH_CHECK_VALUE(
                std::all_of(
                    inputs.begin(),
                    inputs.end(),
                    [](Value* v) { return (v != nullptr); }),
                "cannot pass None in inputs");
            return g.create(Symbol::fromQualString(str), inputs, noutputs);
          })
      .def("param_node", [](Graph& g) { return g.block()->param_node(); })
      .def("return_node", [](Graph& g) { return g.block()->return_node(); })
      .def(
          "createFusionGroup",
          [](Graph& g) { return g.createWithSubgraph(prim::FusionGroup); })
      .def(
          "createCudaFusionGroup",
          [](Graph& g) { return g.createWithSubgraph(prim::CudaFusionGroup); })
      .def(
          "createClone",
          [](Graph& g, Node* n, py::object fn) {
            // `fn` maps each value of `n` to its replacement; the inner
            // lambda only runs synchronously inside createClone.
            return g.createClone(
                n, [&](Value* e) { return fn(e).cast<Value*>(); });
          })
      .GS(appendNode)
      .GS(prependNode)
      .def(
          "makeMultiOutputIntoTuple",
          [](Graph& g) {
            // Pack all current outputs into one tuple node placed before the
            // return node, then replace the graph outputs with that tuple.
            auto tup = g.createTuple(g.outputs());
            tup->insertBefore(g.return_node());
            while (!g.outputs().empty()) {
              g.eraseOutput(0);
            }
            g.registerOutput(tup->output());
          })
      .def(
          "insertConstant",
          [](Graph& g, const IValue& ival) { return g.insertConstant(ival); })
      .GS(lint)
      .GS(insertNode);
#undef GS

#define VS(name) def(#name, &Value ::name)
  // Note: `type` is bound once via VS(type). A second identical-signature
  // overload would never be selected by pybind11's in-order resolution.
  py::class_<Value, unwrapping_shared_ptr<Value>>(m, "Value")
      .def(
          "__repr__",
          [](Value& n) {
            std::stringstream ss;
            ss << n.debugName() << " defined in (" << *n.node() << ")";
            return ss.str();
          })
      .VS(type)
      .VS(setType)
      .def(
          "inferTypeFrom",
          py::overload_cast<const at::Tensor&>(&Value::inferTypeFrom))
      .def(
          "inferTypeFrom",
          py::overload_cast<const c10::intrusive_ptr<c10::ivalue::Object>&>(
              &Value::inferTypeFrom))
      // skip owningGraph because it returns a raw pointer to an otherwise
      // std::shared_ptr stored graph object, and would cause a double free
      .VS(unique)
      .VS(debugName)
      .VS(setDebugName)
      .VS(offset)
      .VS(uses)
      .VS(replaceAllUsesWith)
      .VS(replaceAllUsesAfterNodeWith)
      .def("node", [](Value& v) { return v.node(); })
      .def(
          "setTypeAs",
          [](Value* self, Value* other) {
            self->setType(other->type());
            return self;
          })
      .VS(copyMetadata)
      .VS(isCompleteTensor)
      .VS(requires_grad)
      .def(
          "requiresGrad",
          [](Value& n) {
            return n.type()->expectRef<TensorType>().requiresGrad();
          })
      .def("toIValue", [](Value& n) { return toIValue(&n); });
#undef VS

  py::class_<Block, unwrapping_shared_ptr<Block>>(m, "Block")
      .def(
          "nodes",
          [](Block& b) {
            return py::make_iterator(b.nodes().begin(), b.nodes().end());
          })
      .def(
          "findNode",
          [](Block& b, const std::string& kind, bool recurse) {
            return findNode(&b, Symbol::fromQualString(kind), recurse);
          },
          "Find Node",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def(
          "findAllNodes",
          [](Block& b, const std::string& kind, bool recurse) {
            return findAllNodes(b, Symbol::fromQualString(kind), recurse);
          },
          "Find all nodes",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def(
          "inputs",
          [](Block& b) {
            return py::make_iterator(b.inputs().begin(), b.inputs().end());
          })
      .def(
          "outputs",
          [](Block& b) {
            return py::make_iterator(b.outputs().begin(), b.outputs().end());
          })
      .def("returnNode", [](Block& b) { return b.return_node(); })
      .def("paramNode", [](Block& b) { return b.param_node(); })
      .def("owningNode", [](Block& b) { return b.owningNode(); })
      .def(
          "addNode",
          [](Block& b, const char* str, const std::vector<Value*>& inputs) {
            return addNodeToBlock(&b, Symbol::fromQualString(str), inputs);
          })
      .def("addInputToBlock", [](Block& b) { return addInputToBlock(&b); })
      .def("registerOutput", [](Block& b, Value* value) {
        return b.registerOutput(value);
      });

#define NS(name) def(#name, &Node ::name)
  py::class_<Node, unwrapping_shared_ptr<Node>>(m, "Node")
      .def(
          "__repr__",
          [](Node& n) {
            std::stringstream ss;
            ss << n;
            return ss.str();
          })
      .def("sourceRange", [](Node& n) { return n.sourceRange().str(); })
      .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
      .def("inputsSize", [](Node& n) { return n.inputs().size(); })
      .def("outputsSize", [](Node& n) { return n.outputs().size(); })
      .NS(kind)
      .def("owningBlock", [](Node& n) { return n.owningBlock(); })
      .def("inputsAt", [](Node& n, size_t i) { return n.inputs().at(i); })
      .def(
          "inputs",
          [](Node& n) {
            return py::make_iterator(n.inputs().begin(), n.inputs().end());
          })
      .def(
          "schema",
          [](Node& n) {
            // Reuse the schema pointer from maybeSchema() instead of looking
            // the schema up a second time.
            std::stringstream ss;
            if (auto sch = n.maybeSchema()) {
              ss << *sch;
            } else {
              ss << "(no schema)";
            }
            return ss.str();
          })
      .def(
          "outputs",
          [](Node& n) {
            return py::make_iterator(n.outputs().begin(), n.outputs().end());
          })
      .def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); })
      .def(
          "findNode",
          [](Node& n, const std::string& kind, bool recurse) {
            return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
          },
          "Find Node",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def(
          "findAllNodes",
          [](Node& n, const std::string& kind, bool recurse) {
            return findAllNodes(
                n.blocks(), Symbol::fromQualString(kind), recurse);
          },
          "Find all nodes",
          py::arg("kind"),
          py::arg("recurse") = true)
      .def("input", [](Node& n) { return n.input(); })
      .def("output", [](Node& n) { return n.output(); })
      .def(
          "getModuleHierarchy",
          [](Node& n) { return torch::jit::utils::getNodesModuleHierarchy(n); })
      .NS(addInput)
      .NS(copyMetadata)
      .NS(replaceInput)
      .NS(replaceInputWith)
      .NS(replaceAllUsesWith)
      .NS(insertBefore)
      .NS(insertAfter)
      .NS(isBefore)
      .NS(isAfter)
      .NS(moveAfter)
      .NS(moveBefore)
      .NS(removeInput)
      .NS(removeAllInputs)
      .NS(destroy)
      .NS(hasUses)
      .NS(eraseOutput)
      .NS(addOutput)
      .NS(scopeName)
      .NS(isNondeterministic)
      .def(
          "blocks",
          [](Node& n) {
            return py::make_iterator(n.blocks().begin(), n.blocks().end());
          })
      .NS(addBlock)
      .NS(mustBeNone)

#define AS(name) def(#name, &Node::name)
      // methods from Attributes
      .AS(copyAttributes)
      .AS(hasAttributes)
#undef AS
#define AS(name) def(#name, &Node::name##S)
      // The default method names take a Symbol, but the string conversion for
      // Symbol requires you to qualify names with attr::. This is not very
      // user friendly for attributes, so expose the string variants instead.
      .AS(hasAttribute)
      .AS(kindOf)
      .AS(removeAttribute)
      .AS(attributeNames)
#undef AS
// For each attribute kind, binds a setter `<method>_(name, value)` and a
// getter `<method>(name)`, both taking an unqualified attribute name.
#define CREATE_ACCESSOR(Kind, method)                                       \
  def(#method "_", [](Node& n, const char* name, Kind##Attr::ValueType v) { \
    return n.method##_(Symbol::attr(name), std::move(v));                   \
  }).def(#method, [](Node& n, const char* name) {                           \
    return n.method(Symbol::attr(name));                                    \
  })
      .CREATE_ACCESSOR(Float, f)
      .CREATE_ACCESSOR(Floats, fs)
      .CREATE_ACCESSOR(Complex, c)
      .CREATE_ACCESSOR(String, s)
      .CREATE_ACCESSOR(Strings, ss)
      .CREATE_ACCESSOR(Int, i)
      .CREATE_ACCESSOR(Ints, is)
      .CREATE_ACCESSOR(Graph, g)
      .CREATE_ACCESSOR(Graphs, gs)
#undef CREATE_ACCESSOR
      // Tensor (t_) -- manually written to unwrap the variable into a tensor.
      .def(
          "t_",
          [](Node& n, const char* name, const torch::autograd::Variable& v) {
            AT_ASSERT(!v.requires_grad());
            return n.t_(Symbol::attr(name), v);
          })
      .def(
          "t",
          [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
      // Tensors (ts_) -- manually written to unwrap variables into tensors.
      .def(
          "ts_",
          [](Node& n,
             const char* name,
             const std::vector<torch::autograd::Variable>& vs) {
            std::vector<at::Tensor> tensors;
            tensors.reserve(vs.size());
            for (auto& variable : vs) {
              AT_ASSERT(!variable.requires_grad());
              tensors.push_back(variable);
            }
            return n.ts_(Symbol::attr(name), std::move(tensors));
          })
      .def(
          "ts",
          [](Node& n, const char* name) {
            auto tensors = n.ts(Symbol::attr(name));
            std::vector<torch::autograd::Variable> variables;
            variables.reserve(tensors.size());
            for (auto& tensor : tensors) {
              variables.emplace_back(std::move(tensor));
            }
            return variables;
          })
      // z_/zs_ store tensors viewed with an empty size list (0-dim) and with
      // requires_grad cleared; z/zs read them back through t/ts.
      .def(
          "z_",
          [](Node& n, const char* name, const at::Tensor& v) {
            return n.t_(
                Symbol::attr(name),
                autograd::Variable(v.view(std::vector<int64_t>{}))
                    .set_requires_grad(false));
          })
      .def(
          "z",
          [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
      .def(
          "zs_",
          [](Node& n, const char* name, TensorsAttr::ValueType v) {
            for (auto& i : v) {
              i = autograd::Variable(i.view(std::vector<int64_t>{}))
                      .set_requires_grad(false);
            }
            return n.ts_(Symbol::attr(name), std::move(v));
          })
      .def(
          "zs",
          [](Node& n, const char* name) { return n.ts(Symbol::attr(name)); })
      // Accessors for ConcretePythonOp nodes; each fails via expect<> if the
      // node is not a ConcretePythonOp.
      .def(
          "pyobj",
          [](Node& n) {
            return py::handle(n.expect<ConcretePythonOp>()->pyobj.get())
                .cast<py::object>();
          })
      .def("cconv", [](Node& n) { return n.expect<ConcretePythonOp>()->cconv; })
      .def(
          "pyname",
          [](Node& n) { return n.expect<ConcretePythonOp>()->name(); })
      .def("scalar_args", [](Node& n) {
        auto op = n.expect<ConcretePythonOp>();
        auto scalars = py::list();
        auto append = scalars.attr("append");
        for (auto& arg : op->scalar_args) {
          append(py::handle(arg.get()));
        }
        return scalars;
      });

  // NOTE(review): several Type accessors below test the result of
  // t.expect<TensorType>() for null. expect<> appears to throw for non-tensor
  // types, so those checks are likely redundant and non-tensor types raise
  // rather than returning None -- confirm before relying on the None path.
  using ::c10::Type;
  py::class_<Type, TypePtr>(m, "Type")
      .def("__repr__", [](Type& t) { return t.annotation_str(); })
      .def(
          "str",
          [](Type& t) {
            std::ostringstream s;
            s << t;
            return s.str();
          })
      .def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
      .def(
          "dim",
          [](Type& t) {
            auto vshape = t.expectRef<TensorType>().sizes();
            return vshape.size() ? py::cast(*vshape.size())
                                 : py::cast<py::none>(Py_None);
          })
      .def(
          "undefined",
          [](Type& t) {
            auto undef = t.expectRef<TensorType>().undefined();
            return undef.has_value() ? py::cast(*undef)
                                     : py::cast<py::none>(Py_None);
          })
      .def(
          "sizes",
          [](Type& t) -> py::object {
            if (auto ptt = t.expect<TensorType>()) {
              if (auto cs = ptt->sizes().concrete_sizes()) {
                return py::cast(*cs);
              }
            }
            return py::none();
          })
      .def(
          "symbolic_sizes",
          [](Type& t) -> py::object {
            if (auto ptt = t.expect<TensorType>()) {
              auto ss = ptt->symbolic_sizes();
              if (!ss.rank().has_value()) {
                return py::none();
              }

              std::vector<int64_t> ss_vals;
              ss_vals.reserve(*ss.rank());
              for (size_t i = 0; i < *ss.rank(); ++i) {
                ss_vals.push_back(ss.at(i).value());
              }
              return py::cast(ss_vals);
            }
            return py::none();
          })
      .def(
          "with_sizes",
          [](Type& t, c10::optional<std::vector<c10::optional<int64_t>>> sizes)
              -> py::object {
            auto ptt = t.expect<TensorType>();
            if (!ptt) {
              return py::none();
            }
            // A missing size list yields a default-constructed SymbolicShape.
            if (!sizes) {
              return py::cast(ptt->withSymbolicShapes(c10::SymbolicShape()));
            }
            return py::cast(ptt->withSymbolicShapes(*sizes));
          })
      .def(
          "varyingSizes",
          [](Type& t) -> py::object {
            if (auto ptt = t.expect<TensorType>()) {
              if (auto s = ptt->sizes().sizes()) {
                return py::cast(s.value());
              }
            }
            return py::none();
          })
      .def(
          "strides",
          [](Type& t) -> py::object {
            if (auto ptt = t.expect<TensorType>()) {
              if (auto cs = ptt->strides().concrete_sizes()) {
                return py::cast(*cs);
              }
            }
            return py::none();
          })
      .def(
          "contiguous",
          [](Type& t) {
            return std::static_pointer_cast<Type>(
                t.expectRef<TensorType>().contiguous());
          })
      .def(
          "scalarType",
          [](Type& t) {
            auto scalar_type = t.expectRef<TensorType>().scalarType();
            return (scalar_type) ? toString(*scalar_type) : nullptr;
          })
      .def(
          "device",
          [](Type& t) -> py::object {
            auto device = t.expectRef<TensorType>().device();
            if (!device) {
              return py::none();
            }
            // THPDevice_New returns a newly created object (a new reference),
            // so ownership is stolen. Borrowing here would add an extra
            // reference that is never released and leak every device object.
            PyObject* thp_device = THPDevice_New(device.value());
            return py::reinterpret_steal<py::object>(thp_device);
          })
      .def(
          "with_device",
          [](Type& t, py::object device) -> py::object {
            at::Device c_device = python::detail::py_object_to_device(device);
            if (auto ptt = t.expect<TensorType>()) {
              return py::cast(ptt->withDevice(c_device));
            }
            return py::none();
          })
      .def(
          "dtype",
          [](Type& t) -> py::object {
            auto scalar_type = t.expectRef<TensorType>().scalarType();
            if (!scalar_type) {
              return py::none();
            }
            // presumably getTHPDtype hands back a dtype object owned by torch's
            // dtype registry (a borrowed reference), hence the borrow -- verify.
            THPDtype* thp_dtype = torch::getTHPDtype(*scalar_type);
            py::object dtype =
                py::reinterpret_borrow<py::object>((PyObject*)thp_dtype);
            return dtype;
          })
      .def(
          "with_dtype",
          [](Type& t, py::object dtype) -> py::object {
            at::ScalarType scalar_type =
                python::detail::py_object_to_dtype(dtype);

            if (auto ptt = t.expect<TensorType>()) {
              return py::cast(ptt->withScalarType(scalar_type));
            }
            return py::none();
          })
      // Comparisons against a null `other` (Python None) return false.
      .def(
          "__eq__",
          [](const TypePtr& self, const TypePtr& other) {
            if (!other) {
              return false;
            }
            return *self == *other;
          })
      .def(
          "isSubtypeOf",
          [](const TypePtr& self, const TypePtr& other) {
            if (!other) {
              return false;
            }
            return self->isSubtypeOf(other);
          })
      .def(
          "is_interface_type",
          [](const TypePtr& self) {
            return self->castRaw<InterfaceType>() != nullptr;
          })
      .def_property_readonly(
          "annotation_str", [](const std::shared_ptr<Type>& self) {
            return self->annotation_str();
          });

  // Singleton types exposed through a static `get()`.
  py::class_<AnyType, Type, AnyTypePtr>(m, "AnyType")
      .def_static("get", &AnyType::get);
  py::class_<NumberType, Type, NumberTypePtr>(m, "NumberType")
      .def_static("get", &NumberType::get);
  py::class_<IntType, Type, IntTypePtr>(m, "IntType")
      .def_static("get", &IntType::get);
  py::class_<FloatType, Type, FloatTypePtr>(m, "FloatType")
      .def_static("get", &FloatType::get);
  py::class_<ComplexType, Type, ComplexTypePtr>(m, "ComplexType")
      .def_static("get", &ComplexType::get);
  py::class_<TensorType, Type, TensorTypePtr>(m, "TensorType")
      .def_static("get", &TensorType::get)
      .def_static("getInferred", &TensorType::getInferred)
      .def_static("create_from_tensor", [](const at::Tensor& t) {
        return TensorType::create(t);
      });
  py::class_<BoolType, Type, BoolTypePtr>(m, "BoolType")
      .def_static("get", &BoolType::get);
  py::class_<StringType, Type, StringTypePtr>(m, "StringType")
      .def_static("get", &StringType::get);
  py::class_<DeviceObjType, Type, DeviceObjTypePtr>(m, "DeviceObjType")
      .def_static("get", &DeviceObjType::get);
  py::class_<StreamObjType, Type, StreamObjTypePtr>(m, "StreamObjType")
      .def_static("get", &StreamObjType::get);
  py::class_<PyObjectType, Type, PyObjectTypePtr>(m, "PyObjectType")
      .def_static("get", &PyObjectType::get);
  py::class_<NoneType, Type, NoneTypePtr>(m, "NoneType")
      .def_static("get", &NoneType::get);

  // Parameterized container types.
  py::class_<TupleType, Type, TupleTypePtr>(m, "TupleType")
      .def(py::init([](std::vector<TypePtr> a) {
        return TupleType::create(std::move(a));
      }))
      .def("elements", [](TupleType& self) {
        std::vector<TypePtr> types;
        for (const auto& type : self.elements()) {
          types.push_back(type);
        }
        return types;
      });
  py::class_<UnionType, Type, UnionTypePtr>(m, "UnionType")
      .def(py::init(
          [](const std::vector<TypePtr>& a) { return UnionType::create(a); }))
      .def("containedTypes", [](UnionType& self) {
        return self.containedTypes().vec();
      });
  py::class_<ListType, Type, ListTypePtr>(m, "ListType")
      .def(py::init([](TypePtr a) { return ListType::create(a); }))
      .def_static("ofInts", &ListType::ofInts)
      .def_static("ofTensors", &ListType::ofTensors)
      .def_static("ofFloats", &ListType::ofFloats)
      .def_static("ofComplexDoubles", &ListType::ofComplexDoubles)
      .def_static("ofBools", &ListType::ofBools)
      .def("getElementType", &ListType::getElementType);
  py::class_<DictType, Type, DictTypePtr>(m, "DictType")
      .def(py::init([](TypePtr key, TypePtr value) {
        return DictType::create(std::move(key), std::move(value));
      }))
      .def("getKeyType", &DictType::getKeyType)
      .def("getValueType", &DictType::getValueType);
  py::class_<OptionalType, Type, OptionalTypePtr>(m, "OptionalType")
      .def(py::init(
          [](TypePtr a) { return OptionalType::create(std::move(a)); }))
      .def_static("ofTensor", &OptionalType::ofTensor)
      .def("getElementType", &OptionalType::getElementType);
  py::class_<RRefType, Type, RRefTypePtr>(m, "RRefType")
      .def(py::init([](TypePtr a) { return RRefType::create(std::move(a)); }))
      .def("getElementType", &RRefType::getElementType);

  py::class_<FutureType, Type, FutureTypePtr>(m, "FutureType")
      .def(py::init([](TypePtr a) { return FutureType::create(std::move(a)); }))
      .def("getElementType", &FutureType::getElementType);

  // Named types resolved through the Python compilation unit.
  py::class_<ClassType, Type, ClassTypePtr>(m, "ClassType")
      .def(py::init([](const std::string& qualified_name) {
        return get_python_cu()->get_class(c10::QualifiedName(qualified_name));
      }))
      .def("name", [](ClassType& self) { return self.name()->name(); })
      .def("qualified_name", [](ClassType& self) {
        return self.name()->qualifiedName();
      });
  py::class_<EnumType, Type, EnumTypePtr>(m, "EnumType")
      .def(py::init([](const std::string& qualified_name,
                       TypePtr value_type,
                       const std::vector<py::object>& enum_names_values) {
        // Each element is read through its `name` and `value` attributes; the
        // value is converted to an IValue of `value_type`.
        std::vector<std::pair<std::string, IValue>> names_values;
        names_values.reserve(enum_names_values.size());
        for (const auto& enum_name_value : enum_names_values) {
          auto enum_name = py::cast<std::string>(enum_name_value.attr("name"));
          auto enum_value = toIValue(enum_name_value.attr("value"), value_type);
          names_values.emplace_back(
              std::move(enum_name), std::move(enum_value));
        }
        return EnumType::create(
            c10::QualifiedName(qualified_name),
            std::move(value_type),
            std::move(names_values),
            get_python_cu());
      }));
  py::class_<InterfaceType, Type, InterfaceTypePtr>(m, "InterfaceType")
      .def(py::init([](const std::string& qualified_name) {
        return get_python_cu()->get_interface(
            c10::QualifiedName(qualified_name));
      }))
      .def(
          "getMethod",
          [](InterfaceType& self, const std::string& name) {
            return self.getMethod(name);
          },
          py::return_value_policy::reference)
      .def("getMethodNames", [](InterfaceType& self) {
        std::vector<std::string> names;
        names.reserve(self.methods().size());
        for (const FunctionSchema& fn : self.methods()) {
          names.emplace_back(fn.name());
        }
        return names;
      });
  using ::c10::InferredType;
  py::class_<InferredType, std::shared_ptr<InferredType>>(m, "InferredType")
      .def(py::init([](std::shared_ptr<Type> type) {
        return std::make_shared<InferredType>(std::move(type));
      }))
      .def(py::init([](std::string reason) {
        return std::make_shared<InferredType>(std::move(reason));
      }))
      .def(
          "type",
          [](const std::shared_ptr<InferredType>& self) {
            return self->type();
          })
      .def(
          "success",
          [](const std::shared_ptr<InferredType>& self) {
            return self->success();
          })
      .def("reason", [](const std::shared_ptr<InferredType>& self) {
        return self->reason();
      });

  py::class_<Use>(m, "Use")
      .def_readonly("user", &Use::user)
      .def_readonly("offset", &Use::offset)
      .def("isAfter", [](Use& self, Use& other_use) {
        return isBeforeOrAfter(self, other_use, false);
      });

  py::class_<torch::jit::ShapeComputeGraphMapping>(
      m, "_ShapeComputeGraphMapping")
      .def(
          "partial_eval_shape_graph",
          [](ShapeComputeGraphMapping& g) {
            return g.partial_eval_shape_graph;
          })
      .def(
          "graph_output_to_symbolic_shape_dim",
          [](ShapeComputeGraphMapping& g) {
            return g.graph_output_to_symbolic_shape_dim_;
          });
}