pkg/eas/predict_client.go

package eas

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net/http"
	"sync/atomic"
	"time"
)

const (
	// EndpointTypeDefault is the default endpoint type, using the gateway mode.
	EndpointTypeDefault = "DEFAULT"
	// EndpointTypeVipserver is only used for services that have registered
	// vipserver domains in Alibaba internal clusters.
	EndpointTypeVipserver = "VIPSERVER"
	// EndpointTypeDirect is used to access a service's instances directly,
	// either from inside an EAS service or from a user's client ECS.
	EndpointTypeDirect = "DIRECT"
	// EndpointTypeDocker is for services deployed with the EAS docker image
	// directly, rather than on PAI-EAS.
	EndpointTypeDocker = "DOCKER"
)

// Client-side error codes reported through PredictError.
const (
	ErrorCodeServiceDiscovery = 510
	ErrorCodeCreateRequest    = 511
	ErrorCodePerformRequest   = 512
	ErrorCodeReadResponse     = 513
)

// PredictError is the error type returned by the prediction client.
type PredictError struct {
	Code       int
	Message    string
	RequestURL string
}

// Error implements the error interface.
func (err *PredictError) Error() string {
	return fmt.Sprintf("Url: [%v] Code: [%d], Message: [%s]", err.RequestURL, err.Code, err.Message)
}

// NewPredictError constructs a PredictError.
func NewPredictError(code int, url string, msg string) *PredictError {
	return &PredictError{
		Code:       code,
		Message:    msg,
		RequestURL: url,
	}
}
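
// A minimal sketch of inspecting a returned error (illustrative only):
//
//	if perr, ok := err.(*PredictError); ok && perr.Code == ErrorCodeServiceDiscovery {
//		// no endpoint could be resolved for the service
//	}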

// PredictClient accesses a prediction service through a fixed-size pool of
// persistent HTTP connections.
type PredictClient struct {
	retryCount         int
	maxConnectionCount int
	token              string
	headers            map[string]string
	host               string
	endpoint           Endpoint
	endpointType       string
	endpointName       string
	serviceName        string
	stop               int32
	client             http.Client
}

// NewPredictClient returns an instance of PredictClient.
func NewPredictClient(endpointName string, serviceName string) *PredictClient {
	return &PredictClient{
		endpointName: endpointName,
		serviceName:  serviceName,
		retryCount:   5,
		stop:         0,
		headers:      map[string]string{},
		client: http.Client{
			Timeout: 5000 * time.Millisecond,
			Transport: &http.Transport{
				MaxConnsPerHost: 100,
			},
		},
	}
}

// Init creates the configured endpoint and, for discovery-based endpoint
// types, starts the background endpoint synchronization.
func (p *PredictClient) Init() error {
	switch p.endpointType {
	case "", EndpointTypeDefault:
		p.endpoint = newGatewayEndpoint(p.endpointName)
	case EndpointTypeVipserver:
		p.endpoint = newVipServerEndpoint(p.endpointName)
		go p.syncHandler()
	case EndpointTypeDirect:
		p.endpoint = newCacheServerEndpoint(p.endpointName, p.serviceName)
		go p.syncHandler()
	default:
		return NewPredictError(http.StatusBadRequest, "", "Unsupported endpoint type: "+p.endpointType)
	}
	return nil
}
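
// A minimal setup sketch (endpoint and service names below are placeholders,
// not real values):
//
//	client := NewPredictClient("your-endpoint-domain", "your_service_name")
//	client.SetToken("your-service-token")
//	client.SetEndpointType(EndpointTypeDirect) // optional; defaults to gateway mode
//	if err := client.Init(); err != nil {
//		// handle the configuration error
//	}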

// Shutdown stops the background endpoint synchronization; the client should
// not be used again after calling it.
func (p *PredictClient) Shutdown() {
	atomic.StoreInt32(&p.stop, 1)
}

// syncHandler periodically synchronizes the service's endpoints from the
// upstream discovery server until Shutdown is called.
func (p *PredictClient) syncHandler() {
	p.endpoint.Sync()
	// Sync endpoints from upstream every 3 seconds.
	ticker := time.NewTicker(3 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		if atomic.LoadInt32(&p.stop) == 1 {
			// Shutdown was called; exit the goroutine instead of only
			// skipping the current iteration.
			return
		}
		p.endpoint.Sync()
	}
}

// SetEndpoint sets the service's endpoint for the client.
func (p *PredictClient) SetEndpoint(endpointName string) {
	p.endpointName = endpointName
}

// SetEndpointType sets the endpoint type for the client.
func (p *PredictClient) SetEndpointType(endpointType string) {
	p.endpointType = endpointType
}

// SetToken sets the service's access token for the client.
func (p *PredictClient) SetToken(token string) {
	p.token = token
}

// AddHeader adds a custom HTTP header that is sent with every request.
func (p *PredictClient) AddHeader(headerName, headerValue string) {
	p.headers[headerName] = headerValue
}

// SetHost overrides the HTTP Host header used for requests.
func (p *PredictClient) SetHost(host string) {
	p.host = host
}

// SetRetryCount sets the max retry count for the client.
func (p *PredictClient) SetRetryCount(cnt int) {
	p.retryCount = cnt
}

// SetHttpTransport sets the http.Transport used by the underlying Go HTTP client.
func (p *PredictClient) SetHttpTransport(transport *http.Transport) {
	p.client.Transport = transport
}

// SetTimeout sets the request timeout in milliseconds, 5000ms by default.
func (p *PredictClient) SetTimeout(timeout int) {
	p.client.Timeout = time.Duration(timeout) * time.Millisecond
}
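
// A sketch of tuning the HTTP layer (values are illustrative):
//
//	client.SetTimeout(2000) // 2 second request timeout
//	client.SetHttpTransport(&http.Transport{
//		MaxConnsPerHost:     200,
//		MaxIdleConnsPerHost: 50,
//	})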

// SetServiceName sets the target service name for the client.
func (p *PredictClient) SetServiceName(serviceName string) {
	p.serviceName = serviceName
}

// tryNext asks the endpoint for the next candidate host after the given one.
func (p *PredictClient) tryNext(host string) string {
	return p.endpoint.TryNext(host)
}

// createUrl builds the request URL for the given host, trimming a trailing
// slash from the service name.
func (p *PredictClient) createUrl(host string) string {
	if len(p.serviceName) != 0 {
		if p.serviceName[len(p.serviceName)-1] == '/' {
			p.serviceName = p.serviceName[:len(p.serviceName)-1]
		}
	}
	return fmt.Sprintf("http://%s/api/predict/%s", host, p.serviceName)
}
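
// For example, with host "10.0.0.3:8080" and service name "mnist", createUrl
// yields "http://10.0.0.3:8080/api/predict/mnist".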

// generateSignature builds the authentication headers for a request. The
// HMAC-based signature scheme below is currently disabled; the access token
// is sent directly in the Authorization header instead.
func (p *PredictClient) generateSignature(requestData []byte) map[string]string {
	//canonicalizedResource := fmt.Sprintf("/api/predict/%s", p.serviceName)
	//contentMd5 := md5sum(requestData)
	contentType := "application/octet-stream"
	//currentTime := time.Now().Format("Mon, 02 Jan 2006 15:04:05 GMT")
	//verb := "POST"
	//auth := fmt.Sprintf("%s\n%s\n%s\n%s\n%s", verb, contentMd5, contentType, currentTime, canonicalizedResource)
	//authorization := fmt.Sprintf("EAS %s", hmacSha256(auth, p.token))
	return map[string]string{
		//"Content-MD5": contentMd5,
		//"Date":        currentTime,
		"Content-Type":   contentType,
		"Content-Length": fmt.Sprintf("%d", len(requestData)),
		//"Authorization": authorization,
		"Authorization": p.token,
	}
}
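
// For a token "my-token" and a 16-byte payload, the generated headers would be:
//
//	Content-Type:   application/octet-stream
//	Content-Length: 16
//	Authorization:  my-token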

// BytesPredict sends the raw request data as a byte array over the pooled
// HTTP connections and retries the request automatically when an error occurs.
func (p *PredictClient) BytesPredict(requestData []byte) ([]byte, error) {
	host := p.tryNext("")
	headers := p.generateSignature(requestData)
	for i := 0; i <= p.retryCount; i++ {
		if i != 0 {
			host = p.tryNext(host)
		}
		if len(host) == 0 {
			return nil, NewPredictError(ErrorCodeServiceDiscovery, host,
				fmt.Sprintf("No available endpoint found for service: %v", p.serviceName))
		}
		url := p.createUrl(host)
		req, err := http.NewRequest("POST", url, bytes.NewReader(requestData))
		if err != nil {
			// retry
			if i != p.retryCount {
				continue
			}
			return nil, NewPredictError(ErrorCodeCreateRequest, url, err.Error())
		}
		if p.token != "" {
			for headerName, headerValue := range headers {
				req.Header.Set(headerName, headerValue)
			}
		}
		for headerName, headerValue := range p.headers {
			req.Header.Set(headerName, headerValue)
		}
		if p.host != "" {
			req.Host = p.host
		}
		resp, err := p.client.Do(req)
		if err != nil {
			// retry
			if i != p.retryCount {
				continue
			}
			return nil, NewPredictError(ErrorCodePerformRequest, url, err.Error())
		}
		body, err := ioutil.ReadAll(resp.Body)
		// Close the body before any retry to avoid leaking connections.
		resp.Body.Close()
		if err != nil {
			// retry
			if i != p.retryCount {
				continue
			}
			return nil, NewPredictError(ErrorCodeReadResponse, url, err.Error())
		}
		if resp.StatusCode != 200 {
			// retry
			if i != p.retryCount {
				continue
			}
			return body, NewPredictError(resp.StatusCode, url, string(body))
		}
		return body, nil
	}
	return []byte{}, nil
}
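
// A minimal sketch of sending a raw payload (payload content is illustrative):
//
//	body, err := client.BytesPredict([]byte(`{"some": "payload"}`))
//	if err != nil {
//		// handle or log the *PredictError
//	}
//	_ = body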

// Request is implemented by request types that can be serialized into the
// payload sent to the service.
type Request interface {
	ToString() (string, error)
}

// Response is implemented by response types that can be parsed from the raw
// prediction output.
type Response interface {
	unmarshal(body []byte) error
}

// Predict sends a Request to the service and parses the raw result into the
// matching Response type.
func (p *PredictClient) Predict(request Request) (Response, error) {
	req, err2 := request.ToString()
	if err2 != nil {
		return nil, err2
	}
	body, err := p.BytesPredict([]byte(req))
	if err != nil {
		return nil, err
	}
	switch request.(type) {
	case TFRequest:
		resp := TFResponse{}
		unmarshalErr := resp.unmarshal(body)
		return &resp, unmarshalErr
	case TorchRequest:
		resp := TorchResponse{}
		unmarshalErr := resp.unmarshal(body)
		return &resp, unmarshalErr
	default:
		return nil, NewPredictError(-1, "", "Unknown request type, currently support StringRequest, TFRequest and TorchRequest.")
	}
}

// StringPredict sends the input string and returns the predicted result as a string.
func (p *PredictClient) StringPredict(str string) (string, error) {
	body, err := p.BytesPredict([]byte(str))
	return string(body), err
}
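
// For example (the input format depends on the deployed model; this is
// illustrative only):
//
//	result, err := client.StringPredict("[[1.0, 2.0, 3.0]]")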

// TorchPredict sends the input data and returns the PyTorch predicted result.
func (p *PredictClient) TorchPredict(request TorchRequest) (*TorchResponse, error) {
	resp, err := p.Predict(request)
	if err != nil {
		return nil, err
	}
	return resp.(*TorchResponse), err
}

// TFPredict sends the input data and returns the TensorFlow predicted result.
func (p *PredictClient) TFPredict(request TFRequest) (*TFResponse, error) {
	resp, err := p.Predict(request)
	if err != nil {
		return nil, err
	}
	return resp.(*TFResponse), err
}
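
// A sketch of the typed flow; how TFRequest feeds and fetches are populated is
// defined elsewhere in this package (see the TFRequest type), so the request
// construction is left abstract here:
//
//	var req TFRequest // populate feeds/fetches per the TFRequest API
//	resp, err := client.TFPredict(req)
//	if err != nil {
//		// handle the *PredictError
//	}
//	_ = resp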