gym_http_client.py (127 lines of code) (raw):

import requests
import six.moves.urllib.parse as urlparse
import json
import os
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Client(object):
    """
    Gym client to interface with gym_http_server.

    Each public method maps onto one HTTP route resolved against
    ``remote_base``. Request bodies are sent as JSON and responses are
    decoded from JSON.
    """

    def __init__(self, remote_base):
        """
        Args:
            remote_base: Base URL of the server (e.g. 'http://127.0.0.1:5000').
                Routes are resolved against it with ``urljoin``.
        """
        self.remote_base = remote_base
        # One Session so connections are reused across calls.
        self.session = requests.Session()
        self.session.headers.update({'Content-type': 'application/json'})

    def _parse_server_error_or_raise_for_status(self, resp):
        """Decode a JSON response, raising on HTTP errors.

        Args:
            resp: A ``requests.Response``.

        Returns:
            The decoded JSON body. Returns {} if the body is not valid JSON
            and the status is not an HTTP error.

        Raises:
            ServerError: The status is not 200 and the JSON body is an object
                containing a "message" key.
            requests.HTTPError: Any other 4xx/5xx response.
        """
        j = {}
        try:
            j = resp.json()
        except ValueError:
            # Most likely json parse failed because of network error, not server error (server
            # sends its errors in json). Don't let parse exception go up, but rather raise default
            # error. (All requests JSON-decode errors subclass ValueError.)
            resp.raise_for_status()
        if resp.status_code != 200 and isinstance(j, dict) and "message" in j:
            # descriptive message from server side
            raise ServerError(message=j["message"], status_code=resp.status_code)
        resp.raise_for_status()
        return j

    def _post_request(self, route, data):
        """POST ``data`` (JSON-serialized) to ``route`` and return the decoded reply."""
        url = urlparse.urljoin(self.remote_base, route)
        # Serialize once and reuse the string for both the log line and the body.
        body = json.dumps(data)
        logger.info("POST %s\n%s", url, body)
        resp = self.session.post(url, data=body)
        return self._parse_server_error_or_raise_for_status(resp)

    def _get_request(self, route):
        """GET ``route`` and return the decoded reply."""
        url = urlparse.urljoin(self.remote_base, route)
        logger.info("GET %s", url)
        resp = self.session.get(url)
        return self._parse_server_error_or_raise_for_status(resp)

    def env_create(self, env_id):
        """Create an environment on the server; return its 'instance_id'."""
        route = '/v1/envs/'
        data = {'env_id': env_id}
        resp = self._post_request(route, data)
        instance_id = resp['instance_id']
        return instance_id

    def env_list_all(self):
        """Return the server's 'all_envs' field."""
        route = '/v1/envs/'
        resp = self._get_request(route)
        all_envs = resp['all_envs']
        return all_envs

    def env_reset(self, instance_id):
        """Reset the environment; return the server's 'observation' field."""
        route = '/v1/envs/{}/reset/'.format(instance_id)
        resp = self._post_request(route, None)
        observation = resp['observation']
        return observation

    def env_step(self, instance_id, action, render=False):
        """Send one action; return [observation, reward, done, info]."""
        route = '/v1/envs/{}/step/'.format(instance_id)
        data = {'action': action, 'render': render}
        resp = self._post_request(route, data)
        observation = resp['observation']
        reward = resp['reward']
        done = resp['done']
        info = resp['info']
        return [observation, reward, done, info]

    def env_action_space_info(self, instance_id):
        """Return the server's 'info' field for the action space."""
        route = '/v1/envs/{}/action_space/'.format(instance_id)
        resp = self._get_request(route)
        info = resp['info']
        return info

    def env_action_space_sample(self, instance_id):
        """Return the server's 'action' field from the sample route."""
        route = '/v1/envs/{}/action_space/sample'.format(instance_id)
        resp = self._get_request(route)
        action = resp['action']
        return action

    def env_action_space_contains(self, instance_id, x):
        """Return the server's 'member' field for ``x`` in the action space.

        ``x`` is formatted directly into the URL path.
        """
        route = '/v1/envs/{}/action_space/contains/{}'.format(instance_id, x)
        resp = self._get_request(route)
        member = resp['member']
        return member

    def env_observation_space_info(self, instance_id):
        """Return the server's 'info' field for the observation space."""
        route = '/v1/envs/{}/observation_space/'.format(instance_id)
        resp = self._get_request(route)
        info = resp['info']
        return info

    def env_observation_space_contains(self, instance_id, params):
        """POST ``params`` to the observation-space 'contains' route; return 'member'."""
        route = '/v1/envs/{}/observation_space/contains'.format(instance_id)
        resp = self._post_request(route, params)
        member = resp['member']
        return member

    def env_monitor_start(self, instance_id, directory,
                          force=False, resume=False, video_callable=False):
        """Ask the server to start monitoring; the options are forwarded as-is."""
        route = '/v1/envs/{}/monitor/start/'.format(instance_id)
        data = {'directory': directory,
                'force': force,
                'resume': resume,
                'video_callable': video_callable}
        self._post_request(route, data)

    def env_monitor_close(self, instance_id):
        """Ask the server to close the monitor for this environment."""
        route = '/v1/envs/{}/monitor/close/'.format(instance_id)
        self._post_request(route, None)

    def env_close(self, instance_id):
        """Ask the server to close this environment."""
        route = '/v1/envs/{}/close/'.format(instance_id)
        self._post_request(route, None)

    def upload(self, training_dir, algorithm_id=None, api_key=None):
        """Request an upload of ``training_dir``.

        If ``api_key`` is falsy, it falls back to the OPENAI_GYM_API_KEY
        environment variable (None if that is unset).
        """
        if not api_key:
            api_key = os.environ.get('OPENAI_GYM_API_KEY')

        route = '/v1/upload/'
        data = {'training_dir': training_dir,
                'algorithm_id': algorithm_id,
                'api_key': api_key}
        self._post_request(route, data)

    def shutdown_server(self):
        """POST to the server's shutdown route."""
        route = '/v1/shutdown/'
        self._post_request(route, None)


class ServerError(Exception):
    """Error reported by the server through a JSON "message" field."""

    def __init__(self, message, status_code=None):
        # Pass the message to Exception so str(err) and tracebacks show it.
        Exception.__init__(self, message)
        self.message = message
        # Always defined (None when unknown), so reading it never raises AttributeError.
        self.status_code = status_code


if __name__ == '__main__':
    remote_base = 'http://127.0.0.1:5000'
    client = Client(remote_base)

    # Create environment
    env_id = 'CartPole-v0'
    instance_id = client.env_create(env_id)

    # Check properties
    all_envs = client.env_list_all()
    action_info = client.env_action_space_info(instance_id)
    obs_info = client.env_observation_space_info(instance_id)

    # Run a single step
    client.env_monitor_start(instance_id, directory='tmp', force=True)
    init_obs = client.env_reset(instance_id)
    [observation, reward, done, info] = client.env_step(instance_id, 1, True)
    client.env_monitor_close(instance_id)
    client.upload(training_dir='tmp')