in src/cc/actorpool.cc [354:460]
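// Runs one actor thread: connects to the environment server at `address`,
// then streams observations in and actions out over a bidirectional gRPC
// stream, batching model forward passes through inference_batcher_ and
// pushing completed rollouts onto learner_queue_.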
void loop(int64_t loop_index, const std::string& address) {
std::shared_ptr<grpc::Channel> channel =
grpc::CreateChannel(address, grpc::InsecureChannelCredentials());
std::unique_ptr<rpcenv::RPCEnvServer::Stub> stub =
rpcenv::RPCEnvServer::NewStub(channel);
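// Allow up to ten minutes for the environment server to come up.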
auto deadline =
std::chrono::system_clock::now() + std::chrono::seconds(10 * 60);
if (loop_index == 0) {
std::cout << "First environment waiting for connection to " << address
<< " ..." << std::flush;
}
if (!channel->WaitForConnected(deadline)) {
throw py::timeout_error("WaitForConnected timed out.");
}
if (loop_index == 0) {
std::cout << " connection established." << std::endl;
}
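// Open the bidirectional streaming RPC. The ClientContext must outlive the
// stream, so both live on this thread's stack.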
grpc::ClientContext context;
std::shared_ptr<grpc::ClientReaderWriter<rpcenv::Action, rpcenv::Step>>
stream(stub->StreamingEnv(&context));
rpcenv::Step step_pb;
if (!stream->Read(&step_pb)) {
throw py::connection_error("Initial read failed.");
}
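// The server speaks first: the initial Read returns the first environment
// step (presumably the post-reset observation) before any action is sent.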
TensorNest initial_agent_state = initial_agent_state_;  // Per-thread copy.
TensorNest env_outputs = ActorPool::step_pb_to_nest(&step_pb);
TensorNest compute_inputs(std::vector({env_outputs, initial_agent_state}));
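// Inference maps (env_outputs, agent_state) to
// ((action, ...), new_agent_state); compute() blocks until the batcher has
// run the model on a batch of requests from the actor threads.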
TensorNest all_agent_outputs =
inference_batcher_->compute(compute_inputs);  // Returns a copy.
// Check the output structure once per thread; later iterations assume it.
if (!all_agent_outputs.is_vector()) {
throw py::value_error("Expected agent output to be a tuple");
}
if (all_agent_outputs.get_vector().size() != 2) {
throw py::value_error(
"Expected agent output to be ((action, ...), new_state) but got a "
"sequence of length " +
std::to_string(all_agent_outputs.get_vector().size()));
}
TensorNest agent_state = all_agent_outputs.get_vector()[1];
TensorNest agent_outputs = all_agent_outputs.get_vector()[0];
if (!agent_outputs.is_vector()) {
throw py::value_error(
"Expected first entry of agent output to be an (action, ...) tuple");
}
TensorNest last(std::vector({env_outputs, agent_outputs}));
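// `last` pairs environment outputs with the agent outputs computed from
// them; it is carried over so that consecutive rollouts share a boundary
// timestep.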
rpcenv::Action action_pb;
std::vector<TensorNest> rollout;
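// Each rollout holds unroll_length_ + 1 timesteps: the carried-over
// boundary step plus unroll_length_ newly generated ones.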
try {
while (true) {
rollout.push_back(std::move(last));
for (int t = 1; t <= unroll_length_; ++t) {
all_agent_outputs = inference_batcher_->compute(compute_inputs);
agent_state = all_agent_outputs.get_vector()[1];
agent_outputs = all_agent_outputs.get_vector()[0];
// agent_outputs was validated above to be an (action, ...) tuple.
const TensorNest& action = agent_outputs.get_vector().front();
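// Serialize the action into the request protobuf. start_dim=2 presumably
// skips the leading (time, batch) dimensions added for inference, so the
// server receives an unbatched action.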
action_pb.Clear();
fill_nest_pb(
action_pb.mutable_nest_action(), action,
[&](rpcenv::NDArray* array, const torch::Tensor& tensor) {
return fill_ndarray_pb(array, tensor, /*start_dim=*/2);
});
if (!stream->Write(action_pb)) {
throw py::connection_error("Write failed.");
}
if (!stream->Read(&step_pb)) {
throw py::connection_error("Read failed.");
}
env_outputs = ActorPool::step_pb_to_nest(&step_pb);
compute_inputs = TensorNest(std::vector({env_outputs, agent_state}));
last = TensorNest(std::vector({env_outputs, agent_outputs}));
rollout.push_back(std::move(last));
}
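// Keep a copy of the boundary timestep: `rollout` is about to be cleared,
// and `last` seeds the next rollout at the top of the while loop.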
last = rollout.back();
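// Hand the finished rollout to the learner together with the agent state
// from the start of the rollout; batch() presumably stacks the timesteps
// along a new leading (time) dimension.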
learner_queue_->enqueue({
TensorNest(std::vector(
{batch(rollout, 0), std::move(initial_agent_state)})),
});
rollout.clear();
initial_agent_state = agent_state;  // Copy; the next rollout's start state.
count_ += unroll_length_;
}
} catch (const ClosedBatchingQueue& e) {
// Thrown when inference_batcher_ and learner_queue_ are closed. Stop.
stream->WritesDone();
grpc::Status status = stream->Finish();
if (!status.ok()) {
std::cerr << "RPC failed on Finish: " << status.error_message()
<< std::endl;
}
}
}