void register_python_nested_node()

in nestedtensor/csrc/utils/python_nested_node.cpp [45:119]


// Registers the nested-node bindings on the Python module `m`:
//   * `PythonNode`  - bound directly via py::class_ with str/repr/len/
//                     indexing/unbind and a structural `__eq__`.
//   * `SizeNode`    - registered through add_thp_node with an equality callback.
//   * `IValueNode`  - registered through add_thp_node with an equality callback.
//   * free functions `as_nested_node` and `map`.
//
// Every equality callback follows the same scheme: the two nodes must first
// pass `shape_matches`, then a leaf-wise predicate must hold for all paired
// leaves (evaluated through `all`).
//
// NOTE(review): `add_thp_node`, `shape_matches` and `all` are defined
// elsewhere; presumably `add_thp_node` installs the given callback as the
// class's `__eq__` - confirm against its definition.
void register_python_nested_node(py::module m) {
  py::class_<THPPythonNode>(m, "PythonNode")
      .def("__str__", &THPPythonNode::str)
      .def("unbind", &THPPythonNode::unbind)
      .def("__getitem__", &THPPythonNode::operator[])
      .def("__repr__", &THPPythonNode::str)
      .def("__len__", &THPPythonNode::len)
      .def("__eq__", [](THPPythonNode& a_, THPPythonNode& b_) {
        NestedNode<py::object> a = a_.get_node();
        NestedNode<py::object> b = b_.get_node();
        if (!shape_matches(a, b)) {
          return false;
        }
        // Leaves are arbitrary Python objects, so defer to Python's own `==`.
        auto fn = [](const py::object& lhs, const py::object& rhs) -> bool {
          const int rv = PyObject_RichCompareBool(lhs.ptr(), rhs.ptr(), Py_EQ);
          if (rv == -1) {
            // The comparison raised inside Python; surface that exception.
            throw py::error_already_set();
          }
          return rv == 1;
        };
        return all<decltype(fn)>(std::move(fn), a, b);
      });

  add_thp_node<THPSizeNode>(
      m, "SizeNode", [](THPSizeNode& a_, THPSizeNode& b_) {
        SizeNode a = a_.get_node();
        SizeNode b = b_.get_node();
        if (!shape_matches(a, b)) {
          return false;
        }
        // std::vector::operator== compares the lengths before the elements,
        // so leaves of different length are unequal and never indexed out of
        // bounds.
        auto fn = [](const std::vector<int64_t>& lhs,
                     const std::vector<int64_t>& rhs) -> bool {
          return lhs == rhs;
        };
        return all<decltype(fn)>(std::move(fn), a, b);
      });

  add_thp_node<THPIValueNode>(
      m, "IValueNode", [](THPIValueNode& a_, THPIValueNode& b_) {
        auto a = a_.get_node();
        auto b = b_.get_node();
        if (!shape_matches(a, b)) {
          return false;
        }
        // First pass: every pair of leaves must have the same type, so the
        // value comparison below can dispatch on the left-hand leaf alone.
        auto fn1 = [](const auto& lhs, const auto& rhs) -> bool {
          return (*lhs.type()) == (*rhs.type());
        };
        if (!all<decltype(fn1)>(std::move(fn1), a, b)) {
          return false;
        }
        // Second pass: compare values. Only ints and int lists are supported;
        // any other leaf type fails the TORCH_CHECK below.
        auto fn2 = [](const auto& lhs, const auto& rhs) -> bool {
          if (lhs.isInt()) {
            return lhs.toInt() == rhs.toInt();
          }
          TORCH_CHECK(lhs.isIntList(), "Type not supported for comparison.");
          auto lhs_list = lhs.toIntList();
          auto rhs_list = rhs.toIntList();
          // Lists of different length can never be equal; checking this first
          // also keeps the element loop within bounds of both lists.
          if (lhs_list.size() != rhs_list.size()) {
            return false;
          }
          for (size_t i = 0; i < lhs_list.size(); i++) {
            if (lhs_list[i] != rhs_list[i]) {
              return false;
            }
          }
          return true;
        };
        return all<decltype(fn2)>(std::move(fn2), a, b);
      });

  m.def("as_nested_node", &as_nested_node);
  m.def("map", &py_map);
}