in src/cc/actorpool.cc [566:632]
// Registers the actor-pool Python bindings on module `m`: three exception
// translations and the ActorPool, Batch, DynamicBatcher and BatchingQueue
// classes.
//
// Bindings tagged with py::call_guard<py::gil_scoped_release> run their C++
// body with the Python GIL released.
void init_actorpool(py::module& m) {
  // Expose these C++ exception types to Python under the given names.
  py::register_exception<std::future_error>(m, "AsyncError");
  py::register_exception<ClosedBatchingQueue>(m, "ClosedBatchingQueue");
  py::register_exception<std::bad_variant_access>(m, "NestError");

  // ActorPool: `learner_queue` and `inference_batcher` reject None.
  py::class_<ActorPool>(m, "ActorPool")
      .def(py::init<int, std::shared_ptr<BatchingQueue<>>,
                    std::shared_ptr<DynamicBatcher>, std::vector<std::string>,
                    TensorNest>(),
           py::arg("unroll_length"), py::arg("learner_queue").none(false),
           py::arg("inference_batcher").none(false),
           py::arg("env_server_addresses"), py::arg("initial_agent_state"))
      .def("run", &ActorPool::run, py::call_guard<py::gil_scoped_release>())
      .def("count", &ActorPool::count);

  // Batch: the element type produced by iterating a DynamicBatcher.
  py::class_<DynamicBatcher::Batch, std::shared_ptr<DynamicBatcher::Batch>>(
      m, "Batch")
      .def("get_inputs", &DynamicBatcher::Batch::get_inputs)
      .def("set_outputs", &DynamicBatcher::Batch::set_outputs,
           py::arg("outputs"), py::call_guard<py::gil_scoped_release>());

  // DynamicBatcher: a Python iterator (`__iter__` returns the batcher itself,
  // `__next__` forwards to get_batch()).
  py::class_<DynamicBatcher, std::shared_ptr<DynamicBatcher>>(m,
                                                              "DynamicBatcher")
      .def(py::init<int64_t, int64_t, int64_t, std::optional<int>, bool>(),
           py::arg("batch_dim") = 1, py::arg("minimum_batch_size") = 1,
           py::arg("maximum_batch_size") = 1024, py::arg("timeout_ms") = 100,
           py::arg("check_outputs") = true, R"docstring(
DynamicBatcher class.
If timeout_ms is set to None, the batcher will not allow data to be
retrieved until at least minimum_batch_size inputs are provided.
If timeout_ms is not None (default behaviour), the batcher will
allow data to be retrieved when the timeout expires, even if the
number of inputs received is smaller than minimum_batch_size.
)docstring")
      .def("close", &DynamicBatcher::close)
      .def("is_closed", &DynamicBatcher::is_closed)
      .def("size", &DynamicBatcher::size)
      .def("compute", &DynamicBatcher::compute,
           py::call_guard<py::gil_scoped_release>())
      .def("__iter__",
           [](std::shared_ptr<DynamicBatcher> batcher) { return batcher; })
      .def("__next__", &DynamicBatcher::get_batch,
           py::call_guard<py::gil_scoped_release>());

  // BatchingQueue: a Python iterator as well. Unlike DynamicBatcher,
  // `timeout_ms` defaults to None here.
  py::class_<BatchingQueue<>, std::shared_ptr<BatchingQueue<>>>(m,
                                                                "BatchingQueue")
      .def(py::init<int64_t, int64_t, int64_t, std::optional<int>, bool,
                    std::optional<uint64_t>>(),
           py::arg("batch_dim") = 1, py::arg("minimum_batch_size") = 1,
           py::arg("maximum_batch_size") = 1024,
           py::arg("timeout_ms") = std::nullopt, py::arg("check_inputs") = true,
           py::arg("maximum_queue_size") = std::nullopt)
      // Wraps the nest in a queue item with an Empty payload.
      // NOTE(review): this binding keeps the GIL held while enqueue() runs;
      // if enqueue() can block (e.g. when maximum_queue_size is set), confirm
      // that this cannot stall other Python threads.
      .def("enqueue",
           [](std::shared_ptr<BatchingQueue<>> queue, TensorNest tensors) {
             queue->enqueue({std::move(tensors), Empty()});
           })
      .def("close", &BatchingQueue<>::close)
      .def("is_closed", &BatchingQueue<>::is_closed)
      .def("size", &BatchingQueue<>::size)
      .def("__iter__",
           [](std::shared_ptr<BatchingQueue<>> queue) { return queue; })
      .def("__next__", [](BatchingQueue<>& queue) {
        // Declared first so the GIL stays released until every other local
        // has been destroyed.
        py::gil_scoped_release release;
        std::pair<TensorNest, std::vector<Empty>> pair = queue.dequeue_many();
        // `pair.first` is a member access, so it is not implicitly moved on
        // return; move explicitly to avoid copying the whole nest. Only the
        // tensors are returned; the Empty payloads are dropped.
        return std::move(pair.first);
      });
}