in rlmeta/cc/segment_tree.h [321:396]
// Registers a Python class named `<type>SumSegmentTree` in module `m` that
// wraps SumSegmentTree<T>, held by std::shared_ptr.
//
// Exposed API:
//   __len__ / size, capacity, identity_element
//   __getitem__ / at          : int64 index, NumPy int64 array, torch::Tensor
//   __setitem__ / update      : scalar/array/tensor index with scalar or
//                               matching array/tensor values
//   query                     : (l, r) as scalars, NumPy int64 arrays, or
//                               tensors
//   scan_lower_bound          : scalar T, NumPy array of T, or torch::Tensor
//   pickling                  : via DumpValues() / LoadValues()
//
// pybind11 tries overloads in registration order, so the scalar overload of
// each method is registered first.
//
// Args:
//   type: prefix for the Python class name (presumably a dtype tag chosen
//         by the caller — confirm against call sites).
//   m:    module the class is registered in.
void DefineSumSegmentTree(const std::string& type, py::module& m) {
  const std::string pyclass = type + "SumSegmentTree";
  py::class_<SumSegmentTree<T>, std::shared_ptr<SumSegmentTree<T>>>(
      m, pyclass.c_str())
      .def(py::init<int64_t>())
      .def("__len__", &SumSegmentTree<T>::size)
      .def("size", &SumSegmentTree<T>::size)
      .def("capacity", &SumSegmentTree<T>::capacity)
      .def("identity_element", &SumSegmentTree<T>::identity_element)
      // Element access.
      .def("__getitem__",
           py::overload_cast<int64_t>(&SumSegmentTree<T>::At, py::const_))
      .def("__getitem__", py::overload_cast<const py::array_t<int64_t>&>(
                              &SumSegmentTree<T>::At, py::const_))
      .def("__getitem__", py::overload_cast<const torch::Tensor&>(
                              &SumSegmentTree<T>::At, py::const_))
      .def("at", py::overload_cast<int64_t>(&SumSegmentTree<T>::At, py::const_))
      .def("at", py::overload_cast<const py::array_t<int64_t>&>(
                     &SumSegmentTree<T>::At, py::const_))
      .def("at", py::overload_cast<const torch::Tensor&>(&SumSegmentTree<T>::At,
                                                         py::const_))
      // Element assignment.
      .def("__setitem__",
           py::overload_cast<int64_t, const T&>(&SumSegmentTree<T>::Update))
      .def("__setitem__",
           py::overload_cast<const py::array_t<int64_t>&, const T&>(
               &SumSegmentTree<T>::Update))
      .def(
          "__setitem__",
          py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
              &SumSegmentTree<T>::Update))
      .def("__setitem__", py::overload_cast<const torch::Tensor&, const T&>(
                              &SumSegmentTree<T>::Update))
      .def("__setitem__",
           py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
               &SumSegmentTree<T>::Update))
      .def("update",
           py::overload_cast<int64_t, const T&>(&SumSegmentTree<T>::Update))
      .def("update", py::overload_cast<const py::array_t<int64_t>&, const T&>(
                         &SumSegmentTree<T>::Update))
      .def(
          "update",
          py::overload_cast<const py::array_t<int64_t>&, const py::array_t<T>&>(
              &SumSegmentTree<T>::Update))
      .def("update", py::overload_cast<const torch::Tensor&, const T&>(
                         &SumSegmentTree<T>::Update))
      .def("update",
           py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
               &SumSegmentTree<T>::Update))
      // Range queries.
      .def("query", py::overload_cast<int64_t, int64_t>(
                        &SumSegmentTree<T>::Query, py::const_))
      .def("query", py::overload_cast<const py::array_t<int64_t>&,
                                      const py::array_t<int64_t>&>(
                        &SumSegmentTree<T>::Query, py::const_))
      .def("query",
           py::overload_cast<const torch::Tensor&, const torch::Tensor&>(
               &SumSegmentTree<T>::Query, py::const_))
      // Prefix-sum search.
      .def("scan_lower_bound",
           py::overload_cast<const T&>(&SumSegmentTree<T>::ScanLowerBound,
                                       py::const_))
      .def("scan_lower_bound",
           py::overload_cast<const py::array_t<T>&>(
               &SumSegmentTree<T>::ScanLowerBound, py::const_))
      .def("scan_lower_bound",
           py::overload_cast<const torch::Tensor&>(
               &SumSegmentTree<T>::ScanLowerBound, py::const_))
      // Pickle support. The state is the 1-tuple (DumpValues(),).
      .def(py::pickle(
          [](const SumSegmentTree<T>& s) {
            return py::make_tuple(s.DumpValues());
          },
          [](const py::tuple& t) {
            // Validate explicitly instead of with assert(): assert() is
            // compiled out under NDEBUG, which would let a malformed state
            // tuple through in release builds. py::value_error is reported
            // to Python as ValueError.
            if (t.size() != 1) {
              throw py::value_error(
                  "SumSegmentTree: invalid pickle state (expected a 1-tuple)");
            }
            const auto arr = t[0].cast<py::array_t<T>>();
            // NOTE(review): assumes DumpValues() emits exactly size() values,
            // so a tree sized to arr.size() reproduces the original — confirm.
            SumSegmentTree<T> s(arr.size());
            s.LoadValues(arr);
            return s;
          }));
}