void DefineMinSegmentTree()

in rlmeta/cc/segment_tree.h [399:465]


// Registers a pybind11 binding for MinSegmentTree<T> on module `m`.
//
// The Python class name is `type + "MinSegmentTree"` (e.g. a prefix "X"
// yields "XMinSegmentTree"). Instances use std::shared_ptr as the pybind11
// holder type.
//
// Exposed API:
//   - __len__ / size, capacity, identity_element.
//   - __getitem__ / at: point lookup by an int64 index, a NumPy int64 index
//     array, or a torch::Tensor of indices.
//   - __setitem__ / update: assign a scalar to an int64 index, a scalar to an
//     index array/tensor, or an array/tensor of values to an index
//     array/tensor.
//   - query: range query given (int64, int64), a pair of NumPy int64 arrays,
//     or a pair of torch::Tensors.
// The dunder methods and their named counterparts are bound to the same C++
// member functions. pybind11 resolves overloads by trying them in
// registration order.
//
// Pickling: __getstate__ stores a 1-tuple containing DumpValues().
// __setstate__ constructs a new tree with the stored array's length and
// restores the contents with LoadValues().
//
// Args:
//   type: prefix prepended to "MinSegmentTree" to form the Python class name.
//   m:    module on which the class is defined.
void DefineMinSegmentTree(const std::string& type, py::module& m) {
  const std::string pyclass = type + "MinSegmentTree";
  py::class_<MinSegmentTree<T>, std::shared_ptr<MinSegmentTree<T>>>(
      m, pyclass.c_str())
      .def(py::init<int64_t>())
      .def("__len__", &MinSegmentTree<T>::size)
      .def("size", &MinSegmentTree<T>::size)
      .def("capacity", &MinSegmentTree<T>::capacity)
      .def("identity_element", &MinSegmentTree<T>::identity_element)
      // Point lookup (subscript form).
      .def("__getitem__",
           py::overload_cast<int64_t>(&MinSegmentTree<T>::At, py::const_))
      .def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
                              &MinSegmentTree<T>::At, py::const_))
      .def("__getitem__", py::overload_cast<const torch::Tensor&>(
                              &MinSegmentTree<T>::At, py::const_))
      // Point lookup (method form); same overload set as __getitem__.
      .def("at", py::overload_cast<int64_t>(&MinSegmentTree<T>::At, py::const_))
      .def("at", py::overload_cast<const py::array_t<int64_t>&>(
                     &MinSegmentTree<T>::At, py::const_))
      .def("at", py::overload_cast<const torch::Tensor&>(&MinSegmentTree<T>::At,
                                                         py::const_))
      // Assignment (subscript form).
      .def("__setitem__",
           py::overload_cast<int64_t, const T&>(&MinSegmentTree<T>::Update))
      .def("__setitem__",
           py::overload_cast<const py::array_t<int64_t>&, const T&>(
               &MinSegmentTree<T>::Update))
      .def(
          "__setitem__",
          py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
              &MinSegmentTree<T>::Update))
      .def("__setitem__", py::overload_cast<const torch::Tensor&, const T&>(
                              &MinSegmentTree<T>::Update))
      .def("__setitem__",
           py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
               &MinSegmentTree<T>::Update))
      // Assignment (method form); same overload set as __setitem__.
      .def("update",
           py::overload_cast<int64_t, const T&>(&MinSegmentTree<T>::Update))
      .def("update", py::overload_cast<const py::array_t<int64_t>&, const T&>(
                         &MinSegmentTree<T>::Update))
      .def(
          "update",
          py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
              &MinSegmentTree<T>::Update))
      .def("update", py::overload_cast<const torch::Tensor&, const T&>(
                         &MinSegmentTree<T>::Update))
      .def("update",
           py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
               &MinSegmentTree<T>::Update))
      // Range queries.
      .def("query", py::overload_cast<int64_t, int64_t>(
                        &MinSegmentTree<T>::Query, py::const_))
      .def("query", py::overload_cast<const py::array_t<int64_t>&,
                                      const py::array_t<int64_t>&>(
                        &MinSegmentTree<T>::Query, py::const_))
      .def("query",
           py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
               &MinSegmentTree<T>::Query, py::const_))
      .def(py::pickle(
          // __getstate__: only the values returned by DumpValues() are saved.
          [](const MinSegmentTree<T>& s) {
            return py::make_tuple(s.DumpValues());
          },
          // __setstate__: rebuild a tree from the saved values.
          [](const py::tuple& t) {
            // Validate explicitly instead of with assert(): assert() is a
            // no-op under NDEBUG, so a malformed state tuple would otherwise
            // slip through in release builds. pybind11 translates the
            // exception into a Python error raised from pickle.loads().
            if (t.size() != 1) {
              throw std::runtime_error(
                  "Invalid MinSegmentTree pickle state: expected a 1-tuple.");
            }
            const py::array_t<T> arr = t[0].cast<py::array_t<T>>();
            MinSegmentTree<T> s(arr.size());
            s.LoadValues(arr);
            return s;
          }));
}