in src/cc/rpcenv.cc [44:129]
// Serves a single environment over a bidirectional gRPC stream.
//
// Builds a Python environment object with env_init_(), calls its reset() and
// sends the resulting observation as the first Step. Afterwards, for every
// Action read from the stream, the environment's step() is called and a Step
// holding observation, reward, done, episode_step and episode_return is
// written back. When step() reports done, reset() is called right away: the
// Step that is sent keeps the reward, done flag and episode counters of the
// finished episode but carries the observation returned by reset().
//
// Returns grpc::Status::OK once the stream can no longer be written to or
// read from, or a grpc::INTERNAL status carrying the Python error message if
// env_init_(), reset() or step() raises a Python exception.
virtual grpc::Status StreamingEnv(
    grpc::ServerContext *context,
    grpc::ServerReaderWriter<Step, Action> *stream) override {
  // Declared first so that it is destroyed last, i.e. after pyenv and every
  // other Python object below.
  py::gil_scoped_acquire acquire;
  py::object pyenv;
  py::object stepfunc;
  py::object resetfunc;

  PyArrayNest observation;
  float reward = 0.0;
  bool done = true;
  int episode_step = 0;
  float episode_return = 0.0;

  // Python-callable sinks: calling them with a Python return value converts
  // it and stores the result in the locals above.
  auto set_observation = py::cpp_function(
      [&observation](PyArrayNest o) { observation = std::move(o); },
      py::arg("observation"));
  // Positional values after (observation, reward, done) are accepted through
  // py::args and ignored.
  auto set_observation_reward_done = py::cpp_function(
      [&observation, &reward, &done](PyArrayNest o, float r, bool d,
                                     py::args) {
        observation = std::move(o);
        reward = r;
        done = d;
      },
      py::arg("observation"), py::arg("reward"), py::arg("done"));

  // Logs a Python error and converts it into the status returned to the
  // client. The error must be handled here and not re-raised, as this isn't
  // in a Python thread.
  auto python_error_status = [](const pybind11::error_already_set &e) {
    std::cerr << e.what() << std::endl;
    return grpc::Status(grpc::INTERNAL, e.what());
  };

  // Copies the current reward/done/episode counters into a Step message.
  auto fill_scalars = [&reward, &done, &episode_step,
                       &episode_return](Step *pb) {
    pb->set_reward(reward);
    pb->set_done(done);
    pb->set_episode_step(episode_step);
    pb->set_episode_return(episode_return);
  };

  try {
    pyenv = env_init_();
    stepfunc = pyenv.attr("step");
    resetfunc = pyenv.attr("reset");
    set_observation(resetfunc());
  } catch (const pybind11::error_already_set &e) {
    return python_error_status(e);
  }

  Step step_pb;
  fill_nest_pb(step_pb.mutable_observation(), std::move(observation),
               fill_ndarray_pb);
  fill_scalars(&step_pb);

  Action action_pb;
  while (true) {
    {
      py::gil_scoped_release release;  // Release while doing transfer.
      // A failed Write is treated like a failed Read: stop serving instead
      // of stepping the environment for a client that cannot receive the
      // result.
      if (!stream->Write(step_pb) || !stream->Read(&action_pb)) {
        break;
      }
    }
    try {
      // I'm not sure if this is fast, but it's convenient.
      set_observation_reward_done(*stepfunc(nest_pb_to_nest(
          action_pb.mutable_nest_action(), array_pb_to_nest)));
      episode_step += 1;
      episode_return += reward;

      step_pb.Clear();
      // Must happen before the counters are reset below, so that the Step
      // for a finished episode reports that episode's totals.
      fill_scalars(&step_pb);

      if (done) {
        set_observation(resetfunc());
        // Reset episode_* for the _next_ step.
        episode_step = 0;
        episode_return = 0.0;
      }
    } catch (const pybind11::error_already_set &e) {
      return python_error_status(e);
    }
    fill_nest_pb(step_pb.mutable_observation(), std::move(observation),
                 fill_ndarray_pb);
  }
  return grpc::Status::OK;
}