binding-cpp/gym_binding.cpp (252 lines of code) (raw):
#include "include/gym/gym.h"
#include <boost/enable_shared_from_this.hpp>
#include <curl/curl.h>
#include <json/value.h>
#include <json/reader.h>
#include <stdio.h>
#include <random>
namespace Gym {
// When true, every request and response is echoed to stdout with printf.
// Nothing in the visible part of this file assigns to it, so it is a compile-time debugging switch.
static bool verbose = false;
// Entropy source; used once, to seed rand_generator on the next line.
static std::random_device rd;
// File-wide Mersenne Twister engine that Space::sample() draws from.
static std::mt19937 rand_generator(rd());
// Draws one random point from this space and returns it as a flat vector of floats.
//  * DISCRETE: a single element holding an integer picked uniformly from [0, discreet_n-1].
//  * BOX: one element per cell (product of box_shape), each placed uniformly between
//    box_low[i] and box_high[i].
// Any other space type trips the assert.
std::vector<float> Space::sample()
{
    if (type!=DISCRETE) {
        assert(type==BOX);
        std::uniform_real_distribution<float> unit(0.0f, 1.0f);
        int cells = 1;
        for (int extent: box_shape)
            cells *= extent;
        // The bounds are stored flattened, one entry per cell.
        assert((int)box_high.size()==cells);
        assert((int)box_low.size()==cells);
        std::vector<float> point(cells, 0.0f);
        for (int i=0; i<cells; ++i)
            point[i] = box_low[i] + unit(rand_generator)*(box_high[i]-box_low[i]);
        return point;
    }
    std::uniform_int_distribution<int> pick(0, discreet_n-1);
    std::vector<float> choice(1, 0.0f);
    choice[0] = pick(rand_generator);
    return choice;
}
// Returns member `k` of the JSON object `v`, converted with asString().
// Throws std::runtime_error when `v` is not an object or has no member named `k`.
static
std::string require(const Json::Value& v, const std::string& k)
{
    const bool present = v.isObject() && v.isMember(k);
    if (present)
        return v[k].asString();
    throw std::runtime_error("cannot find required parameter '" + k + "'");
}
// Builds a Space from the server's JSON description of an action/observation space.
//
// Reads the "info" member of `j`:
//  * "name" == "Discrete": "n" gives the number of choices and must be positive.
//  * "name" == "Box": "shape" lists the dimensions, "low" and "high" hold the bounds
//    flattened to one entry per cell, so both must have exactly product(shape) entries.
//
// Throws std::runtime_error when "name" is missing or unknown, or when the description is
// inconsistent. jsoncpp's own conversion failures (asInt/asFloat) propagate unchanged.
static
boost::shared_ptr<Space> space_from_json(const Json::Value& j)
{
    boost::shared_ptr<Space> r(new Space);
    // References instead of copies: the low/high arrays can be large.
    const Json::Value& v = j["info"];
    const std::string type = require(v, "name");
    if (type=="Discrete") {
        r->type = Space::DISCRETE;
        r->discreet_n = v["n"].asInt(); // jsoncpp raises if the value is not convertible to int
        // A missing "n" converts to 0. Space::sample() builds the range [0, n-1], which is
        // only valid for n >= 1, so reject the description here.
        if (r->discreet_n <= 0)
            throw std::runtime_error("cannot parse discrete space, 'n' must be positive");
    } else if (type=="Box") {
        r->type = Space::BOX;
        const Json::Value& shape = v["shape"];
        const Json::Value& low = v["low"];
        const Json::Value& high = v["high"];
        if (!shape.isArray() || !low.isArray() || !high.isArray())
            throw std::runtime_error("cannot parse box space (1)");
        const int l1 = low.size();
        const int l2 = high.size();
        const int ls = shape.size();
        long long cells = 1;
        for (int s=0; s<ls; ++s) {
            const int e = shape[s].asInt();
            if (e < 0)
                throw std::runtime_error("cannot parse box space (negative dimension)");
            r->box_shape.push_back(e);
            // Overflow-free product: only multiply while cells <= l1 (so the result fits in
            // long long). Once the product has passed l1 it can only stay above it, unless a
            // later zero dimension collapses it to 0.
            if (e == 0)
                cells = 0;
            else if (cells <= l1)
                cells *= e;
        }
        if (cells != l1 || l1 != l2)
            throw std::runtime_error("cannot parse box space (2)");
        const int sz = l1;
        r->box_low.resize(sz);
        r->box_high.resize(sz);
        for (int i=0; i<sz; ++i) {
            r->box_low[i] = low[i].asFloat();
            r->box_high[i] = high[i].asFloat();
        }
    } else {
        throw std::runtime_error("unknown space type '" + type + "'");
    }
    return r;
}
// curl
// Write callback installed through CURLOPT_WRITEFUNCTION.
// Appends the received chunk (size*nmemb bytes starting at `buffer`) to the std::string
// that `userp` points to and reports the whole chunk as consumed.
static
std::size_t curl_save_to_string(void* buffer, std::size_t size, std::size_t nmemb, void* userp)
{
    const char* chunk = static_cast<char*>(buffer);
    const std::size_t chunk_len = size*nmemb;
    static_cast<std::string*>(userp)->append(chunk, chunk_len);
    return chunk_len;
}
// HTTP implementation of Client. Sends requests to http://<addr><route> on <port> through
// one reusable libcurl easy handle and exchanges JSON bodies with the server.
class ClientReal: public Client, public boost::enable_shared_from_this<ClientReal> {
public:
    std::string addr;  // host part of the URL; filled in by client_create()
    int port;          // passed to libcurl as CURLOPT_PORT on every request
    // Declared before `h` so it is destroyed after the handle that holds a pointer to it
    // (members are destroyed in reverse declaration order).
    std::vector<char> curl_error_buf;
    boost::shared_ptr<curl_slist> headers;  // "Content-Type: application/json", used by POST
    boost::shared_ptr<CURL> h;              // the easy handle; released with curl_easy_cleanup

    // Creates and configures the easy handle.
    // Throws std::runtime_error if libcurl cannot allocate the handle or the header list.
    ClientReal(): port(0)
    {
        CURL* c = curl_easy_init();
        if (!c)
            throw std::runtime_error("curl_easy_init() failed");
        // Take ownership immediately so the handle is released if anything below throws.
        h.reset(c, curl_easy_cleanup);
        // curl_easy_setopt is variadic: numeric options must be passed as long.
        curl_easy_setopt(c, CURLOPT_NOSIGNAL, 1L);
        curl_easy_setopt(c, CURLOPT_CONNECTTIMEOUT_MS, 3000L);
        curl_easy_setopt(c, CURLOPT_IPRESOLVE, static_cast<long>(CURL_IPRESOLVE_V4));
        curl_easy_setopt(c, CURLOPT_FOLLOWLOCATION, 1L);
        // NOTE(review): certificate and host-name checks are switched off. Requests are built
        // as http:// URLs, so this only matters if a redirect leads to https — confirm this
        // is intended.
        curl_easy_setopt(c, CURLOPT_SSL_VERIFYPEER, 0L);
        curl_easy_setopt(c, CURLOPT_SSL_VERIFYHOST, 0L);
        curl_easy_setopt(c, CURLOPT_WRITEFUNCTION, &curl_save_to_string);
        curl_error_buf.assign(CURL_ERROR_SIZE, 0);
        curl_easy_setopt(c, CURLOPT_ERRORBUFFER, curl_error_buf.data());
        curl_slist* list = curl_slist_append(nullptr, "Content-Type: application/json");
        if (!list)
            throw std::runtime_error("curl_slist_append() failed");
        headers.reset(list, curl_slist_free_all);
    }

    // Performs a GET on `route` and returns the parsed JSON answer.
    // Throws std::runtime_error on transport failure, on an error status from the server,
    // or when the answer is not a JSON object.
    Json::Value GET(const std::string& route)
    {
        const std::string url = "http://" + addr + route;
        if (verbose) printf("GET %s\n", url.c_str());
        std::string answer;
        prepare_request(url, answer);
        curl_easy_setopt(h.get(), CURLOPT_POST, 0L);
        curl_easy_setopt(h.get(), CURLOPT_HTTPHEADER, static_cast<curl_slist*>(nullptr));
        return perform(answer);
    }

    // Performs a POST on `route` with `post_data` as the body (sent with the JSON
    // content-type header) and returns the parsed JSON answer.
    // Throws std::runtime_error under the same conditions as GET.
    Json::Value POST(const std::string& route, const std::string& post_data)
    {
        const std::string url = "http://" + addr + route;
        if (verbose) printf("POST %s\n%s\n", url.c_str(), post_data.c_str());
        std::string answer;
        prepare_request(url, answer);
        curl_easy_setopt(h.get(), CURLOPT_POST, 1L);
        // libcurl does not copy POSTFIELDS; post_data outlives the transfer below.
        curl_easy_setopt(h.get(), CURLOPT_POSTFIELDS, post_data.c_str());
        curl_easy_setopt(h.get(), CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(post_data.size()));
        curl_easy_setopt(h.get(), CURLOPT_HTTPHEADER, headers.get());
        return perform(answer);
    }

    // Inspects the status of the transfer that just finished and parses `answer` into `j`.
    //  * 204: success without a body; `j` becomes an empty object.
    //  * other 2xx: `answer` must be a JSON object, otherwise std::runtime_error is thrown.
    //  * anything else: throws std::runtime_error carrying the server's "message" member
    //    when there is one, or the raw answer otherwise.
    void throw_server_error_or_response_code(const std::string& answer, Json::Value& j)
    {
        long response_code = 0;
        const CURLcode r = curl_easy_getinfo(h.get(), CURLINFO_RESPONSE_CODE, &response_code);
        // The error buffer is only filled by transfers, so describe the code itself here.
        if (r != CURLE_OK) throw std::runtime_error(curl_easy_strerror(r));
        if (verbose) printf("%i\n%s\n", static_cast<int>(response_code), answer.c_str());
        if (response_code == 204) {
            // "No Content": nothing to parse, and not an error.
            j = Json::Value(Json::objectValue);
            return;
        }
        std::string parse_error;
        Json::Reader jr;
        if (!jr.parse(answer, j, false)) {
            parse_error = jr.getFormattedErrorMessages();
            parse_error += "original json that caused error: " + answer;
        } else if (!j.isObject()) {
            parse_error = "top level json is not an object\n";
            parse_error += "original json that caused error: " + answer;
        }
        const bool success = response_code >= 200 && response_code < 300;
        if (!success) {
            if (j.isObject() && j.isMember("message"))
                throw std::runtime_error(j["message"].asString());
            throw std::runtime_error("bad HTTP response code, and also cannot parse server message: " + answer);
        }
        // Successful status, but the body may still be invalid json.
        if (!parse_error.empty())
            throw std::runtime_error(parse_error);
    }

    boost::shared_ptr<Environment> make(const std::string& env_id) override;

private:
    // Settings shared by every request: target URL and port, where the body is collected,
    // and an emptied error buffer so a stale message is never reported.
    void prepare_request(const std::string& url, std::string& answer)
    {
        curl_easy_setopt(h.get(), CURLOPT_URL, url.c_str());
        curl_easy_setopt(h.get(), CURLOPT_PORT, static_cast<long>(port));
        curl_easy_setopt(h.get(), CURLOPT_WRITEDATA, &answer);
        curl_error_buf[0] = 0;
    }

    // Runs the configured transfer and converts the collected body into JSON.
    Json::Value perform(const std::string& answer)
    {
        const CURLcode r = curl_easy_perform(h.get());
        if (r != CURLE_OK) {
            // Not every failure fills the error buffer; fall back to the generic text.
            const char* detail = curl_error_buf.data();
            throw std::runtime_error(detail[0] ? detail : curl_easy_strerror(r));
        }
        Json::Value j;
        throw_server_error_or_response_code(answer, j);
        return j;
    }
};
// Creates a client for the server at `addr` and `port`.
// The two values are only stored on the new ClientReal; no request is sent here.
boost::shared_ptr<Client> client_create(const std::string& addr, int port)
{
    boost::shared_ptr<ClientReal> created(new ClientReal);
    created->port = port;
    created->addr = addr;
    return created;
}
// environment
class EnvironmentReal: public Environment {
public:
std::string instance_id;
boost::shared_ptr<ClientReal> client;
boost::shared_ptr<Space> space_act;
boost::shared_ptr<Space> space_obs;
boost::shared_ptr<Space> action_space() override
{
if (!space_act)
space_act = space_from_json(client->GET("/v1/envs/" + instance_id + "/action_space"));
return space_act;
}
boost::shared_ptr<Space> observation_space() override
{
if (!space_obs)
space_obs = space_from_json(client->GET("/v1/envs/" + instance_id + "/observation_space"));
return space_obs;
}
void observation_parse(const Json::Value& v, std::vector<float>& save_here)
{
if (!v.isArray())
throw std::runtime_error("cannot parse observation, not an array");
int s = v.size();
save_here.resize(s);
for (int i=0; i<s; ++i)
save_here[i] = v[i].asFloat();
}
void reset(State* save_initial_state_here) override
{
Json::Value ans = client->POST("/v1/envs/" + instance_id + "/reset/", "");
observation_parse(ans["observation"], save_initial_state_here->observation);
}
void step(const std::vector<float>& action, bool render, State* save_state_here) override
{
Json::Value act_json;
boost::shared_ptr<Space> aspace = action_space();
if (aspace->type==Space::DISCRETE) {
act_json["action"] = (int) action[0];
} else if (aspace->type==Space::BOX) {
Json::Value& array = act_json["action"];
assert(action.size()==aspace->box_low.size()); // really assert, it's a programming error on C++ part
for (int c=0; c<(int)action.size(); ++c)
array[c] = action[c];
} else {
assert(0);
}
act_json["render"] = render;
Json::Value ans = client->POST("/v1/envs/" + instance_id + "/step/", act_json.toStyledString());
observation_parse(ans["observation"], save_state_here->observation);
save_state_here->done = ans["done"].asBool();
save_state_here->reward = ans["reward"].asFloat();
}
void monitor_start(const std::string& directory, bool force, bool resume) override
{
Json::Value data;
data["directory"] = directory;
data["force"] = force;
data["resume"] = resume;
client->POST("/v1/envs/" + instance_id + "/monitor/start/", data.toStyledString());
}
void monitor_stop() override
{
client->POST("/v1/envs/" + instance_id + "/monitor/close/", "");
}
};
// Asks the server to instantiate environment `env_id` and wraps the returned instance id
// in an EnvironmentReal that shares ownership of this client.
// Throws std::runtime_error if the request fails or the answer has no "instance_id".
boost::shared_ptr<Environment> ClientReal::make(const std::string& env_id)
{
    Json::Value request;
    request["env_id"] = env_id;
    const Json::Value reply = POST("/v1/envs/", request.toStyledString());
    const std::string id = require(reply, "instance_id");
    if (verbose) printf(" * created %s instance_id=%s\n", env_id.c_str(), id.c_str());
    boost::shared_ptr<EnvironmentReal> created(new EnvironmentReal);
    created->instance_id = id;
    created->client = shared_from_this();
    return created;
}
} // namespace