binding-go/client.go (278 lines of code) (raw):

package gym import ( "bytes" "encoding/json" "errors" "fmt" "io" "io/ioutil" "net/http" "net/url" "os" "reflect" "strconv" ) // InstanceID uniquely identifies a running instance. type InstanceID string func (i InstanceID) path() string { return "/v1/envs/" + string(i) } // Space stores information about an action space or an // observation space. type Space struct { // Name is the name of the space, such as "Box", "HighLow", // or "Discrete". Name string `json:"name"` // Properties for Box spaces. Shape []int `json:"shape"` Low []float64 `json:"low"` High []float64 `json:"high"` // Properties for Discrete spaces. N int `json:"n"` // Properties for HighLow spaces. NumRows int `json:"num_rows"` Matrix []float64 `json:"matrix"` } // A Client interfaces with a Gym HTTP server. type Client struct { remoteURL url.URL } // NewClient creates a client with the given base URL. // // For example, the base URL might be // // http://localhost:8080 // // This fails if the baseURL is invalid. func NewClient(baseURL string) (*Client, error) { u, err := url.Parse(baseURL) if err != nil { return nil, fmt.Errorf("create client: %s", err) } return &Client{remoteURL: *u}, nil } // ListAll lists all instantiated environments. // The result maps between instance IDs and environment // IDs. func (c *Client) ListAll() (map[InstanceID]string, error) { var resp struct { Result map[InstanceID]string `json:"all_envs"` } if err := c.get("/v1/envs/", &resp); err != nil { return nil, fmt.Errorf("list environments: %s", err) } return resp.Result, nil } // Create creates a new instance of an environment. func (c *Client) Create(envID string) (InstanceID, error) { var resp struct { Result InstanceID `json:"instance_id"` } req := map[string]string{"env_id": envID} if err := c.post("/v1/envs/", req, &resp); err != nil { return "", fmt.Errorf("create environment: %s", err) } return resp.Result, nil } // Reset resets the environment instance. // // The resulting observation type may vary. // For discrete spaces, it is an int. // For vector spaces, it is a []float64. func (c *Client) Reset(id InstanceID) (observation interface{}, err error) { defer func() { if err != nil { err = fmt.Errorf("reset environment: %s", err) } }() var resp struct { Observation interface{} `json:"observation"` } if err := c.post(id.path()+"/reset/", struct{}{}, &resp); err != nil { return nil, err } return normalizeSpaceElem(resp.Observation) } // Step takes a step in the environment. // // The action type may vary. // For discrete spaces, it should be an int. // For vector spaces, it should be a []float64 or a // []float32. // // See Reset() for information on the observation type. func (c *Client) Step(id InstanceID, action interface{}, render bool) (observation interface{}, reward float64, done bool, info interface{}, err error) { defer func() { if err != nil { err = fmt.Errorf("step environment: %s", err) } }() req := map[string]interface{}{"action": action, "render": render} var resp struct { Observation interface{} `json:"observation"` Reward float64 `json:"reward"` Done bool `json:"done"` Info interface{} `json:"info"` } err = c.post(id.path()+"/step/", req, &resp) if err != nil { return } resp.Observation, err = normalizeSpaceElem(resp.Observation) if err != nil { return } return resp.Observation, resp.Reward, resp.Done, resp.Info, nil } // ActionSpace fetches the action space. func (c *Client) ActionSpace(id InstanceID) (*Space, error) { return c.getSpace(id, "action_space") } // ObservationSpace fetches the observation space. func (c *Client) ObservationSpace(id InstanceID) (*Space, error) { return c.getSpace(id, "observation_space") } // SampleAction samples an action uniformly. // // The action is turned into a Go type just like Reset() // turns observations into Go types. func (c *Client) SampleAction(id InstanceID) (interface{}, error) { var resp struct { Action interface{} `json:"action"` } if err := c.get(id.path()+"/action_space/sample", &resp); err != nil { return nil, fmt.Errorf("sample action: %s", err) } if obs, err := normalizeSpaceElem(resp.Action); err != nil { return nil, fmt.Errorf("sample action: %s", err) } else { return obs, nil } } // ContainsAction checks if an action is contained in the // action space. // // Currently, only int action types are supported. func (c *Client) ContainsAction(id InstanceID, act interface{}) (bool, error) { num, ok := act.(int) if !ok { return false, fmt.Errorf("contains action: unexpected action type %T", act) } var resp struct { Member bool `json:"member"` } path := id.path() + "/action_space/contains/" + strconv.Itoa(num) if err := c.get(path, &resp); err != nil { return false, fmt.Errorf("contains action: %s", err) } return resp.Member, nil } // Close closes the environment instance. func (c *Client) Close(id InstanceID) error { if err := c.post(id.path()+"/close/", struct{}{}, nil); err != nil { return fmt.Errorf("close environment: %s", err) } return nil } // StartMonitor starts monitoring the environment. func (c *Client) StartMonitor(id InstanceID, dir string, force, resume, videoCallable bool) error { req := map[string]interface{}{ "directory": dir, "force": force, "resume": resume, "video_callable": videoCallable, } if err := c.post(id.path()+"/monitor/start/", req, nil); err != nil { return fmt.Errorf("start monitor: %s", err) } return nil } // CloseMonitor stops monitoring the environment. func (c *Client) CloseMonitor(id InstanceID) error { if err := c.post(id.path()+"/monitor/close/", struct{}{}, nil); err != nil { return fmt.Errorf("close monitor: %s", err) } return nil } // Upload uploads the monitor results from the directory // to the Gym website. // // If apiKey is "", then the "OPENAI_GYM_API_KEY" // environment variable is used. func (c *Client) Upload(dir, apiKey, algorithmID string) error { if apiKey == "" { apiKey = os.Getenv("OPENAI_GYM_API_KEY") } data := map[string]string{"training_dir": dir, "api_key": apiKey} if algorithmID != "" { data["algorithm_id"] = algorithmID } if err := c.post("/v1/upload/", data, nil); err != nil { return fmt.Errorf("upload: %s", err) } return nil } // Shutdown stops the server. func (c *Client) Shutdown() error { if err := c.post("/v1/shutdown/", struct{}{}, nil); err != nil { return fmt.Errorf("shutdown: %s", err) } return nil } func (c *Client) getSpace(id InstanceID, name string) (*Space, error) { var resp struct { Space *Space `json:"info"` } if err := c.get(id.path()+"/"+name+"/", &resp); err != nil { return nil, fmt.Errorf("get space: %s", err) } return resp.Space, nil } // post encodes data as JSON and POSTs it to the path. // If result is non-nil, the response is parsed as JSON // into result. func (c *Client) post(path string, data, result interface{}) error { u := c.remoteURL u.Path = path body, err := json.Marshal(data) if err != nil { return err } resp, err := http.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return err } return processResponse(resp.Body, result) } // get requests the URL and decodes the response as JSON // into result. func (c *Client) get(path string, result interface{}) error { u := c.remoteURL u.Path = path resp, err := http.Get(u.String()) if err != nil { return err } return processResponse(resp.Body, result) } func processResponse(body io.ReadCloser, result interface{}) error { defer body.Close() bodyData, err := ioutil.ReadAll(body) if err != nil { return err } if err := responseErrorMessage(bodyData); err != nil { return err } if result != nil { if err := json.Unmarshal(bodyData, &result); err != nil { return err } } return nil } func responseErrorMessage(resp []byte) error { var obj struct { Message string `json:"message"` } json.Unmarshal(resp, &obj) if obj.Message != "" { return errors.New(obj.Message) } return nil } func normalizeSpaceElem(obs interface{}) (interface{}, error) { if obs == nil { return nil, errors.New("unsupported observation: nil") } switch obs := obs.(type) { case float64: return int(obs), nil case []interface{}: if len(obs) == 0 { return nil, errors.New("unsupported observation: empty array") } else if _, isFloat := obs[0].(float64); isFloat { return normalizeOneDimSpace(obs) } else { return normalizeMultiDimSpace(obs) } default: return nil, fmt.Errorf("unsupported observation: %v", obs) } } func normalizeOneDimSpace(obs []interface{}) ([]float64, error) { res := make([]float64, len(obs)) for i, x := range obs { var isFloat bool res[i], isFloat = x.(float64) if !isFloat { return nil, errors.New("unsupported observation: heterogeneous array") } } return res, nil } func normalizeMultiDimSpace(obs []interface{}) (interface{}, error) { firstElem, err := normalizeSpaceElem(obs[0]) if err != nil { return nil, err } elemType := reflect.TypeOf(firstElem) sliceType := reflect.SliceOf(elemType) slice := reflect.MakeSlice(sliceType, len(obs), len(obs)) for i, x := range obs { obj, err := normalizeSpaceElem(x) if err != nil { return nil, err } val := reflect.ValueOf(obj) if val.Type() != elemType { return nil, errors.New("unsupported observation: heterogeneous array") } slice.Index(i).Set(val) } return slice.Interface(), nil }