void init_actorpool(py::module& m)

in src/cc/actorpool.cc [566:632]

void init_actorpool(py::module& m) {
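  // Map C++ exceptions thrown by the pool/batcher machinery to Python
  // exception types on the module.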
  py::register_exception<std::future_error>(m, "AsyncError");
  py::register_exception<ClosedBatchingQueue>(m, "ClosedBatchingQueue");
  py::register_exception<std::bad_variant_access>(m, "NestError");

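  // ActorPool: constructed with the learner queue, the inference batcher and
  // the env server addresses; run() releases the GIL for the lifetime of the
  // call so Python threads keep running while the actors work.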
  py::class_<ActorPool>(m, "ActorPool")
      .def(py::init<int, std::shared_ptr<BatchingQueue<>>,
                    std::shared_ptr<DynamicBatcher>, std::vector<std::string>,
                    TensorNest>(),
           py::arg("unroll_length"), py::arg("learner_queue").none(false),
           py::arg("inference_batcher").none(false),
           py::arg("env_server_addresses"), py::arg("initial_agent_state"))
      .def("run", &ActorPool::run, py::call_guard<py::gil_scoped_release>())
      .def("count", &ActorPool::count);

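  // Batch: a single dynamically batched inference request. get_inputs()
  // returns the stacked inputs; set_outputs() hands results back with the
  // GIL released.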
  py::class_<DynamicBatcher::Batch, std::shared_ptr<DynamicBatcher::Batch>>(
      m, "Batch")
      .def("get_inputs", &DynamicBatcher::Batch::get_inputs)
      .def("set_outputs", &DynamicBatcher::Batch::set_outputs,
           py::arg("outputs"), py::call_guard<py::gil_scoped_release>());

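  // DynamicBatcher: groups concurrent inference requests into batches.
  // Iterating over it (__iter__/__next__) yields Batch objects; __next__
  // calls get_batch() with the GIL released.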
  py::class_<DynamicBatcher, std::shared_ptr<DynamicBatcher>>(m,
                                                              "DynamicBatcher")
      .def(py::init<int64_t, int64_t, int64_t, std::optional<int>, bool>(),
           py::arg("batch_dim") = 1, py::arg("minimum_batch_size") = 1,
           py::arg("maximum_batch_size") = 1024, py::arg("timeout_ms") = 100,
           py::arg("check_outputs") = true, R"docstring(
             DynamicBatcher class.
             If timeout_ms is set to None, the batcher will not allow data to be
             retrieved until at least minimum_batch_size inputs are provided.
             If timeout_ms is not None (default behaviour), the batcher will
             allow data to be retrieved when the timeout expires, even if the
             number of inputs received is smaller than minimum_batch_size.
           )docstring")
      .def("close", &DynamicBatcher::close)
      .def("is_closed", &DynamicBatcher::is_closed)
      .def("size", &DynamicBatcher::size)
      .def("compute", &DynamicBatcher::compute,
           py::call_guard<py::gil_scoped_release>())
      .def("__iter__",
           [](std::shared_ptr<DynamicBatcher> batcher) { return batcher; })
      .def("__next__", &DynamicBatcher::get_batch,
           py::call_guard<py::gil_scoped_release>());

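  // BatchingQueue: a FIFO of TensorNests consumed in batches. enqueue() wraps
  // the nest with an Empty payload; iterating yields the batched tensors
  // returned by dequeue_many(), again with the GIL released while waiting.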
  py::class_<BatchingQueue<>, std::shared_ptr<BatchingQueue<>>>(m,
                                                                "BatchingQueue")
      .def(py::init<int64_t, int64_t, int64_t, std::optional<int>, bool,
                    std::optional<uint64_t>>(),
           py::arg("batch_dim") = 1, py::arg("minimum_batch_size") = 1,
           py::arg("maximum_batch_size") = 1024,
           py::arg("timeout_ms") = std::nullopt, py::arg("check_inputs") = true,
           py::arg("maximum_queue_size") = std::nullopt)
      .def("enqueue",
           [](std::shared_ptr<BatchingQueue<>> queue, TensorNest tensors) {
             queue->enqueue({std::move(tensors), Empty()});
           })
      .def("close", &BatchingQueue<>::close)
      .def("is_closed", &BatchingQueue<>::is_closed)
      .def("size", &BatchingQueue<>::size)
      .def("__iter__",
           [](std::shared_ptr<BatchingQueue<>> queue) { return queue; })
      .def("__next__", [](BatchingQueue<>& queue) {
        py::gil_scoped_release release;
        std::pair<TensorNest, std::vector<Empty>> pair = queue.dequeue_many();
        return pair.first;
      });
}
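For orientation, here is a minimal sketch of how these bindings are typically driven from Python. Everything outside the bound classes is an assumption for illustration: the libtorchbeast module name, the model and learner placeholders, the server address, and the unroll length are not defined in this file.

# Hypothetical usage of the bindings above; names marked as placeholders
# are assumptions, not part of actorpool.cc.
import threading

import libtorchbeast  # assumed name of the module init_actorpool() populates

batcher = libtorchbeast.DynamicBatcher(
    batch_dim=1, minimum_batch_size=1, maximum_batch_size=1024, timeout_ms=100)
learner_queue = libtorchbeast.BatchingQueue(batch_dim=1, minimum_batch_size=4)

pool = libtorchbeast.ActorPool(
    unroll_length=80,                          # placeholder value
    learner_queue=learner_queue,
    inference_batcher=batcher,
    env_server_addresses=["localhost:4431"],   # placeholder address
    initial_agent_state=(),                    # empty TensorNest for illustration
)

def model(inputs):
    return inputs  # placeholder for a real network forward pass

def serve_inference():
    # __next__ calls get_batch() with the GIL released, so this thread blocks
    # cheaply until a batch of concurrent requests is ready.
    for batch in batcher:
        batch.set_outputs(model(batch.get_inputs()))

threading.Thread(target=serve_inference, daemon=True).start()
threading.Thread(target=pool.run, daemon=True).start()  # run() also drops the GIL

for tensors in learner_queue:  # batched trajectories for the learner
    pass  # placeholder for a learner step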