in rlmeta/cc/segment_tree.h [399:465]
void DefineMinSegmentTree(const std::string& type, py::module& m) {
const std::string pyclass = type + "MinSegmentTree";
py::class_<MinSegmentTree<T>, std::shared_ptr<MinSegmentTree<T>>>(
m, pyclass.c_str())
.def(py::init<int64_t>())
.def("__len__", &MinSegmentTree<T>::size)
.def("size", &MinSegmentTree<T>::size)
.def("capacity", &MinSegmentTree<T>::capacity)
.def("identity_element", &MinSegmentTree<T>::identity_element)
.def("__getitem__",
py::overload_cast<int64_t>(&MinSegmentTree<T>::At, py::const_))
.def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
&MinSegmentTree<T>::At, py::const_))
.def("__getitem__", py::overload_cast<const torch::Tensor&>(
&MinSegmentTree<T>::At, py::const_))
.def("at", py::overload_cast<int64_t>(&MinSegmentTree<T>::At, py::const_))
.def("at", py::overload_cast<const py::array_t<int64_t>&>(
&MinSegmentTree<T>::At, py::const_))
.def("at", py::overload_cast<const torch::Tensor&>(&MinSegmentTree<T>::At,
py::const_))
.def("__setitem__",
py::overload_cast<int64_t, const T&>(&MinSegmentTree<T>::Update))
.def("__setitem__",
py::overload_cast<const py::array_t<int64_t>&, const T&>(
&MinSegmentTree<T>::Update))
.def(
"__setitem__",
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
&MinSegmentTree<T>::Update))
.def("__setitem__", py::overload_cast<const torch::Tensor&, const T&>(
&MinSegmentTree<T>::Update))
.def("__setitem__",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&MinSegmentTree<T>::Update))
.def("update",
py::overload_cast<int64_t, const T&>(&MinSegmentTree<T>::Update))
.def("update", py::overload_cast<const py::array_t<int64_t>&, const T&>(
&MinSegmentTree<T>::Update))
.def(
"update",
py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
&MinSegmentTree<T>::Update))
.def("update", py::overload_cast<const torch::Tensor&, const T&>(
&MinSegmentTree<T>::Update))
.def("update",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&MinSegmentTree<T>::Update))
.def("query", py::overload_cast<int64_t, int64_t>(
&MinSegmentTree<T>::Query, py::const_))
.def("query", py::overload_cast<const py::array_t<int64_t>&,
const py::array_t<int64_t>&>(
&MinSegmentTree<T>::Query, py::const_))
.def("query",
py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
&MinSegmentTree<T>::Query, py::const_))
.def(py::pickle(
[](const MinSegmentTree<T>& s) {
return py::make_tuple(s.DumpValues());
},
[](const py::tuple& t) {
assert(t.size() == 1);
const py::array_t<T>& arr = t[0].cast<py::array_t<T>>();
MinSegmentTree<T> s(arr.size());
s.LoadValues(arr);
return s;
}));
}