void DefineSumSegmentTree()

in rlmeta/cc/segment_tree.h [321:396]


void DefineSumSegmentTree(const std::string& type, py::module& m) {
  const std::string pyclass = type + "SumSegmentTree";
  py::class_<SumSegmentTree<T>, std::shared_ptr<SumSegmentTree<T>>>(
      m, pyclass.c_str())
      .def(py::init<int64_t>())
      .def("__len__", &SumSegmentTree<T>::size)
      .def("size", &SumSegmentTree<T>::size)
      .def("capacity", &SumSegmentTree<T>::capacity)
      .def("identity_element", &SumSegmentTree<T>::identity_element)
      .def("__getitem__",
           py::overload_cast<int64_t>(&SumSegmentTree<T>::At, py::const_))
      .def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
                              &SumSegmentTree<T>::At, py::const_))
      .def("__getitem__", py::overload_cast<const torch::Tensor&>(
                              &SumSegmentTree<T>::At, py::const_))
      .def("at", py::overload_cast<int64_t>(&SumSegmentTree<T>::At, py::const_))
      .def("at", py::overload_cast<const py::array_t<int64_t>&>(
                     &SumSegmentTree<T>::At, py::const_))
      .def("at", py::overload_cast<const torch::Tensor&>(&SumSegmentTree<T>::At,
                                                         py::const_))
      .def("__setitem__",
           py::overload_cast<int64_t, const T&>(&SumSegmentTree<T>::Update))
      .def("__setitem__",
           py::overload_cast<const py::array_t<int64_t>&, const T&>(
               &SumSegmentTree<T>::Update))
      .def(
          "__setitem__",
          py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
              &SumSegmentTree<T>::Update))
      .def("__setitem__", py::overload_cast<const torch::Tensor&, const T&>(
                              &SumSegmentTree<T>::Update))
      .def("__setitem__",
           py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
               &SumSegmentTree<T>::Update))
      .def("update",
           py::overload_cast<int64_t, const T&>(&SumSegmentTree<T>::Update))
      .def("update", py::overload_cast<const py::array_t<int64_t>&, const T&>(
                         &SumSegmentTree<T>::Update))
      .def(
          "update",
          py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
              &SumSegmentTree<T>::Update))
      .def("update", py::overload_cast<const torch::Tensor&, const T&>(
                         &SumSegmentTree<T>::Update))
      .def("update",
           py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
               &SumSegmentTree<T>::Update))
      .def("query", py::overload_cast<int64_t, int64_t>(
                        &SumSegmentTree<T>::Query, py::const_))
      .def("query", py::overload_cast<const py::array_t<int64_t>&,
                                      const py::array_t<int64_t>&>(
                        &SumSegmentTree<T>::Query, py::const_))
      .def("query",
           py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
               &SumSegmentTree<T>::Query, py::const_))
      .def("scan_lower_bound",
           py::overload_cast<const T&>(&SumSegmentTree<T>::ScanLowerBound,
                                       py::const_))
      .def("scan_lower_bound",
           py::overload_cast<const py::array_t<T>&>(
               &SumSegmentTree<T>::ScanLowerBound, py::const_))
      .def("scan_lower_bound",
           py::overload_cast<const torch::Tensor&>(
               &SumSegmentTree<T>::ScanLowerBound, py::const_))
      .def(py::pickle(
          [](const SumSegmentTree<T>& s) {
            return py::make_tuple(s.DumpValues());
          },
          [](const py::tuple& t) {
            assert(t.size() == 1);
            const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
            SumSegmentTree<T> s(arr.size());
            s.LoadValues(arr);
            return s;
          }));
}