pkg/eas/queue_client.go (574 lines of code) (raw):
package eas
import (
"bytes"
"context"
"fmt"
"github.com/google/uuid"
"github.com/alibaba/pairec/v2/pkg/eas/types"
"golang.org/x/net/websocket"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
)
// HeaderRequestId is the response header from which Put reads the request id.
const HeaderRequestId = "X-Eas-Queueservice-Request-Id"

// HeaderAuthorization is the request header that carries the access token.
const HeaderAuthorization = "Authorization"

// DefaultGroupName is the group id given to clients built by NewQueueClient.
const DefaultGroupName = "eas"
// QueueUser is the identity a queue client presents to the server: a user id,
// a group id and an access token.
type QueueUser struct {
	uid   string
	gid   string
	token string
}

// NewQueueUser builds a QueueUser from a user id, a group id and a token.
func NewQueueUser(uid, gid, token string) QueueUser {
	return QueueUser{uid: uid, gid: gid, token: token}
}

// Uid returns the user id.
func (u QueueUser) Uid() string { return u.uid }

// Gid returns the group id.
func (u QueueUser) Gid() string { return u.gid }

// Token returns the access token.
func (u QueueUser) Token() string { return u.token }
// QueueClient is a client of the EAS queue server. It exposes the queue
// operations (Put, Get, Watch, Commit, Del, Truncate, End, Attributes) over
// HTTP, with Watch optionally using websocket.
// NOTE(review): the original comment says it also implements the queue service
// interface; presumably an interface in package types — confirm.
// NOTE(review): attr/once are read and written without a lock, so concurrent
// use of one client can race on the cached attributes.
type QueueClient struct {
// httpClient performs every non-websocket request.
httpClient *http.Client
// baseUrl is endpoint + "/api/predict/" + queue name (see NewQueueClient);
// request methods copy it before adding query parameters.
baseUrl *url.URL
// user supplies the uid/gid/token sent with each request.
user types.User
// WebsocketWatch makes Watch use websocket; when false Watch streams over HTTP.
WebsocketWatch bool
// once guards the first lazy fetch of attr in getAttr.
once sync.Once
// attr caches server attributes, including the names of the identity headers.
attr types.Attributes
// codecs for data frame and attributes.
DCodec types.DataFrameCodec
ACodec types.AttributesCodec
}
// NewQueueClient creates a client for the queue named queueName at endpoint.
//
// endpoint may carry a scheme ("http://host", "https://host:443"). When it
// does not ("host", "host:port", "1.2.3.4:port"), "http://" is assumed. A
// trailing slash on endpoint is ignored.
//
// The client identifies itself with a freshly generated uid, DefaultGroupName
// as gid, and token for authorization. Watch uses websocket by default, and
// both codecs use protobuf.
func NewQueueClient(endpoint, queueName, token string) (*QueueClient, error) {
	// The scheme must be added before parsing, not patched afterwards:
	// url.Parse reads "host:port" as scheme "host" and rejects
	// "1.2.3.4:port" outright.
	if !strings.Contains(endpoint, "://") {
		endpoint = "http://" + strings.TrimPrefix(endpoint, "//")
	}
	baseUrl := strings.Join([]string{strings.TrimRight(endpoint, "/"), "api/predict", queueName}, "/")
	u, err := url.Parse(baseUrl)
	if err != nil {
		return nil, err
	}
	uid := uuid.New().String()
	cli := &QueueClient{
		baseUrl:        u,
		httpClient:     &http.Client{},
		user:           NewQueueUser(uid, DefaultGroupName, token),
		WebsocketWatch: true, // Watch through websocket by default
		DCodec:         types.DataFrameCodecFor(types.ContentTypeProtobuf),
		ACodec:         types.AttributesCodecFor(types.ContentTypeProtobuf),
	}
	return cli, nil
}
// readMessage drains reader and returns its contents as a string. If reading
// fails, it returns the error text instead, so the result can always be
// embedded in an error message.
func readMessage(reader io.Reader) string {
	content, readErr := ioutil.ReadAll(reader)
	if readErr == nil {
		return string(content)
	}
	return readErr.Error()
}
// getAttr returns the queue attributes, fetching them from the server when
// needed.
//
// While nothing is cached, the first call fetches the attributes under
// q.once. When force is true and attributes are already cached, they are
// re-fetched. If the first fetch failed, later calls retry the fetch. The
// previous version returned nil attributes with a nil error forever, because
// the failed attempt had consumed q.once.
func (q *QueueClient) getAttr(force bool) (types.Attributes, error) {
	if q.attr == nil {
		var err error
		q.once.Do(func() { err = q.obtainAttr() })
		if err != nil {
			return q.attr, err
		}
		if q.attr != nil {
			return q.attr, nil
		}
		// q.once was already spent by an earlier failed attempt: fetch directly.
		force = true
	}
	if force {
		if err := q.obtainAttr(); err != nil {
			return q.attr, fmt.Errorf("failed to obtain attributes, error: %v", err)
		}
	}
	return q.attr, nil
}
// obtainAttr fetches the queue attributes (GET <baseUrl>?_attrs_=true) and
// stores them in q.attr.
//
// q.attr is replaced only when the request succeeds, the status is 200 and
// decoding succeeds. A failed refresh therefore never leaves a half-decoded or
// empty map behind.
func (q *QueueClient) obtainAttr() error {
	// make a copy of base url.
	u := *q.baseUrl
	qe := u.Query()
	qe.Set("_attrs_", "true")
	u.RawQuery = qe.Encode()
	req, err := http.NewRequest(http.MethodGet, u.String(), nil)
	if err != nil {
		return err
	}
	q.withAuthorization(req)
	req.Header.Set("accept", q.ACodec.MediaType())
	resp, err := q.httpClient.Do(req)
	if err != nil {
		return err
	}
	// Close the body on every path, including a failed read below.
	defer resp.Body.Close()
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("visiting: %s, unexpected status code: %d, body: %s", u.String(), resp.StatusCode, string(body))
	}
	attr := types.Attributes{}
	if err := q.ACodec.Decode(body, &attr); err != nil {
		return err
	}
	q.attr = attr
	return nil
}
// withIdentity adds the user-id header to req, and the group-id header when the
// server defines one. The header names are read from the queue attributes,
// which are fetched lazily on first use. It returns an error when the
// attributes cannot be obtained or do not name a user-id header.
func (q *QueueClient) withIdentity(req *http.Request) error {
	attr, err := q.getAttr(false)
	if err != nil {
		return err
	}
	uidHeader, gidHeader := attr[types.UserIdentifyHeader], attr[types.GroupIdentifyHeader]
	if uidHeader == "" {
		return fmt.Errorf("malformed attributes: %v", attr)
	}
	req.Header.Add(uidHeader, q.user.Uid())
	if gidHeader != "" {
		req.Header.Add(gidHeader, q.user.Gid())
	}
	return nil
}
// withAuthorization adds the Authorization header to req. It does so only when
// the client's user also implements types.UserWithToken; otherwise req is left
// untouched.
func (q *QueueClient) withAuthorization(req *http.Request) {
	tokenUser, hasToken := q.user.(types.UserWithToken)
	if !hasToken {
		return
	}
	req.Header.Add(HeaderAuthorization, tokenUser.Token())
}
// Truncate sends a truncate request for index: DELETE on the queue URL with
// _index_=<index>&_trunc_=true. Any status other than 200 is returned as an
// error that includes the response body.
func (q *QueueClient) Truncate(ctx context.Context, index uint64) error {
	target := *q.baseUrl // copy so the shared base URL is never modified
	params := target.Query()
	params.Set("_index_", strconv.FormatUint(index, 10))
	params.Set("_trunc_", boolString(true))
	target.RawQuery = params.Encode()
	endpoint := target.String()

	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
	if err != nil {
		return err
	}
	if err = q.withIdentity(req); err != nil {
		return err
	}
	q.withAuthorization(req)

	resp, err := q.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusOK {
		return nil
	}
	return fmt.Errorf("visiting: %s, unexpected status code: %d, message: %s", endpoint, resp.StatusCode, readMessage(resp.Body))
}
// End sends the _eos_ request to the queue: POST with _eos_=true, plus
// _force_=true when force is set. Any status other than 200 is returned as an
// error that includes the response body.
func (q *QueueClient) End(ctx context.Context, force bool) error {
	target := *q.baseUrl // copy so the shared base URL is never modified
	params := target.Query()
	params.Set("_eos_", boolString(true))
	if force {
		params.Set("_force_", boolString(true))
	}
	target.RawQuery = params.Encode()
	endpoint := target.String()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
	if err != nil {
		return err
	}
	if err = q.withIdentity(req); err != nil {
		return err
	}
	q.withAuthorization(req)

	resp, err := q.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusOK {
		return nil
	}
	return fmt.Errorf("visiting: %s, unexpected status code: %d, message: %s", endpoint, resp.StatusCode, readMessage(resp.Body))
}
// Put enqueues data, sending tags as query parameters. It returns the index
// parsed from the response body and the request id taken from the
// HeaderRequestId response header.
//
// Statuses 200 and 204 are accepted. Any other status is returned as an error
// that includes the body.
// NOTE(review): a 204 with an empty body fails the ParseUint below. Confirm
// whether the server ever answers Put with 204.
func (q *QueueClient) Put(ctx context.Context, data []byte, tags types.Tags) (index uint64, requestId string, err error) {
	// make a copy of base url.
	u := *q.baseUrl
	qe := u.Query()
	for key, val := range tags {
		qe.Set(key, val)
	}
	u.RawQuery = qe.Encode()
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(data))
	if err != nil {
		return 0, "", err
	}
	if err := q.withIdentity(req); err != nil {
		return 0, "", err
	}
	q.withAuthorization(req)
	resp, err := q.httpClient.Do(req)
	if err != nil {
		return 0, "", err
	}
	// The body was previously never closed, leaking a connection on every Put.
	defer resp.Body.Close()
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return 0, "", err
	}
	requestId = resp.Header.Get(HeaderRequestId)
	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
		return 0, requestId, fmt.Errorf("visiting: %s, unexpected status code: %d, message: %s", u.String(), resp.StatusCode, string(body))
	}
	index, err = strconv.ParseUint(string(body), 0, 64)
	if err != nil {
		return 0, requestId, err
	}
	return index, requestId, nil
}
// GetByIndex requests the single frame at index with a zero timeout, with
// auto-delete enabled and no tags. It is shorthand for
// Get(ctx, index, 1, 0, true, types.Tags{}).
func (q *QueueClient) GetByIndex(ctx context.Context, index uint64) (dfs []types.DataFrame, err error) {
	const (
		frameCount = 1
		autoDelete = true
	)
	var noWait time.Duration
	return q.Get(ctx, index, frameCount, noWait, autoDelete, types.Tags{})
}
// GetByRequestId requests a single frame selected by the "requestId" tag, with
// index 0, a zero timeout and auto-delete enabled.
func (q *QueueClient) GetByRequestId(ctx context.Context, requestId string) (dfs []types.DataFrame, err error) {
	filter := types.Tags{"requestId": requestId}
	var noWait time.Duration
	dfs, err = q.Get(ctx, 0, 1, noWait, true, filter)
	return dfs, err
}
// Get requests length frames starting at index and decodes them with DCodec.
//
// timeout is sent as _timeout_ in time.Duration string form. autoDelete is
// sent as _auto_delete_. tags are validated and then added as query
// parameters. Any non-2xx status is returned as an error that includes the
// response body.
func (q *QueueClient) Get(ctx context.Context, index uint64, length int, timeout time.Duration, autoDelete bool, tags types.Tags) (dfs []types.DataFrame, err error) {
	if err = tags.Validate(); err != nil {
		return nil, err
	}
	u := *q.baseUrl
	eq := u.Query()
	eq.Set("_index_", strconv.FormatUint(index, 10))
	eq.Set("_length_", strconv.FormatInt(int64(length), 10))
	eq.Set("_timeout_", timeout.String())
	eq.Set("_raw_", boolString(false))
	eq.Set("_auto_delete_", boolString(autoDelete))
	for key, val := range tags {
		eq.Set(key, val)
	}
	u.RawQuery = eq.Encode()
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", q.DCodec.MediaType())
	if err := q.withIdentity(req); err != nil {
		return nil, err
	}
	q.withAuthorization(req)
	resp, err := q.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Close before reading so a failed read no longer leaks the body.
	defer resp.Body.Close()
	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("visiting: %s, unexpected status code: %d, message: %s", u.String(), resp.StatusCode, string(data))
	}
	return q.DCodec.DecodeList(data)
}
// boolString renders b as the query-parameter value "true" or "false".
func boolString(b bool) string {
	s := "false"
	if b {
		s = "true"
	}
	return s
}
// websocketWatcher reads websocket messages from conn, decodes each into a
// types.DataFrame and delivers it on ch.
type websocketWatcher struct {
// ctx/cancel control the watcher's lifetime; once ctx is done, run's ping goroutine closes conn.
ctx context.Context
cancel context.CancelFunc
conn *websocket.Conn
decoder types.DataFrameDecoder
// pingFrameWriter sends the once-per-second keep-alive ping frames.
pingFrameWriter io.WriteCloser
// ch carries decoded frames to the consumer; run closes it when it returns.
ch chan types.DataFrame
}
// newWebsocketWatcher dials config, prepares a ping frame writer and starts the
// watcher's read loop in a new goroutine. ctx and cancel control the watcher's
// lifetime.
//
// On error no connection is left open. cancel is not called here, so
// cancelling remains the caller's job.
func newWebsocketWatcher(ctx context.Context, cancel context.CancelFunc, config *websocket.Config, decoder types.DataFrameDecoder) (types.Watcher, error) {
	conn, err := websocket.DialConfig(config)
	if err != nil {
		return nil, err
	}
	ping, err := conn.NewFrameWriter(websocket.PingFrame)
	if err != nil {
		// The dialed connection used to leak on this path.
		conn.Close()
		return nil, err
	}
	w := &websocketWatcher{
		ctx:             ctx,
		cancel:          cancel,
		conn:            conn,
		decoder:         decoder,
		pingFrameWriter: ping,
		ch:              make(chan types.DataFrame, 100),
	}
	go w.run()
	return w, nil
}
// FrameChan returns the receive-only view of the decoded-frame channel. The
// channel is closed when the read loop stops.
func (w *websocketWatcher) FrameChan() <-chan types.DataFrame {
	var frames <-chan types.DataFrame = w.ch
	return frames
}
// Close cancels the watcher's context. The connection itself is closed
// asynchronously by the ping goroutine started in run.
func (w *websocketWatcher) Close() {
	stop := w.cancel
	stop()
}
// pingServer writes one empty ping frame through pingFrameWriter and returns
// any write error.
func (w *websocketWatcher) pingServer() error {
	empty := []byte{}
	if _, err := w.pingFrameWriter.Write(empty); err != nil {
		return err
	}
	return nil
}
// run receives websocket messages, decodes each into a DataFrame and delivers
// it on w.ch.
//
// A message that fails to decode is still delivered, with Message describing
// the error. A receive error is delivered as a final frame with Message set,
// after which run returns. On return, run closes w.ch and cancels w.ctx.
//
// A helper goroutine pings the server every second and closes the connection
// once w.ctx is done. Closing the connection unblocks Receive. Deliveries give
// up once w.ctx is done, so a consumer that stopped reading can no longer
// block this goroutine forever.
func (w *websocketWatcher) run() {
	go func() {
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		defer w.pingFrameWriter.Close()
		for {
			select {
			case <-w.ctx.Done():
				w.conn.Close()
				return
			case <-ticker.C:
				// Ping errors are deliberately ignored. Presumably a dead
				// connection also fails Receive below, where it is reported.
				_ = w.pingServer()
			}
		}
	}()
	defer w.cancel()
	defer close(w.ch)
	// emit delivers df unless the watcher is cancelled first.
	emit := func(df types.DataFrame) bool {
		select {
		case w.ch <- df:
			return true
		case <-w.ctx.Done():
			return false
		}
	}
	var data []byte
	for {
		df := types.DataFrame{}
		err := websocket.Message.Receive(w.conn, &data)
		if err != nil {
			df.Message = fmt.Sprintf("error reading message: %v", err)
			emit(df)
			return
		}
		if err = w.decoder.Decode(data, &df); err != nil {
			df.Message = fmt.Sprintf("failed to decode message: %v", err)
		}
		if !emit(df) {
			return
		}
	}
}
// reconnectWatcher wraps a websocket watcher, forwards its frames to userChan
// and re-dials with the same config when the underlying connection ends.
type reconnectWatcher struct {
// watcher is the current underlying watcher; run replaces it after a reconnect.
watcher types.Watcher
// userChan is returned by FrameChan; run closes it when it returns.
userChan chan types.DataFrame
// ctx/cancel belong to the user-facing watcher; run stops reconnect attempts once ctx is done.
ctx context.Context
cancel context.CancelFunc
}
// newReconnectWatcher dials an initial websocket watcher and wraps it in a
// reconnectWatcher. The wrapper re-dials with the same config and decoder when
// the connection ends.
//
// ctx and cancel belong to the returned watcher. The inner connection runs on
// its own context derived from context.Background.
func newReconnectWatcher(ctx context.Context, cancel context.CancelFunc, config *websocket.Config, decoder types.DataFrameDecoder) (types.Watcher, error) {
	// TODO it can be more generic to cover different kind of watcher
	innerCtx, innerCancel := context.WithCancel(context.Background())
	inner, err := newWebsocketWatcher(innerCtx, innerCancel, config, decoder)
	if err != nil {
		// Always release the CancelFunc, as the context package requires.
		innerCancel()
		return nil, err
	}
	w := &reconnectWatcher{
		watcher:  inner,
		userChan: make(chan types.DataFrame, 100),
		ctx:      ctx,
		cancel:   cancel,
	}
	go w.run(config, decoder)
	return w, nil
}
// FrameChan returns the channel on which forwarded frames are delivered. The
// channel is closed when run returns.
func (w *reconnectWatcher) FrameChan() <-chan types.DataFrame {
	var frames <-chan types.DataFrame = w.userChan
	return frames
}
// Close cancels the reconnect watcher's context, which stops further reconnect
// attempts, and closes the watcher currently in use.
// NOTE(review): run replaces w.watcher on reconnect without a lock, so this
// read can race with it.
func (w *reconnectWatcher) Close() {
	w.cancel()
	current := w.watcher
	current.Close()
}
// run forwards frames from the current websocket watcher to w.userChan.
//
// When the underlying frame channel closes, run re-dials with config and
// decoder once per second until a dial succeeds or w.ctx is done. When w.ctx
// is done, either because Close was called or the context passed to Watch was
// cancelled, run closes the watcher currently in use and returns. w.userChan
// is closed on return.
func (w *reconnectWatcher) run(config *websocket.Config, decoder types.DataFrameDecoder) {
	defer close(w.userChan)

	// reconnect dials a replacement watcher every second. It returns true once
	// one is installed, or false if the user closes this watcher first.
	reconnect := func() bool {
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				ctx, cancel := context.WithCancel(context.Background())
				watcher, err := newWebsocketWatcher(ctx, cancel, config, decoder)
				if err != nil {
					cancel()
					fmt.Printf("Connect to upstream error: %v, retry...\n", err)
					continue
				}
				w.watcher = watcher
				return true
			case <-w.ctx.Done():
				// watcher was closed by user
				return false
			}
		}
	}

	for {
		select {
		case <-w.ctx.Done():
			// Release the current connection as well. Close may have raced with
			// a reconnect and closed the previous watcher instead.
			w.watcher.Close()
			return
		case df, ok := <-w.watcher.FrameChan():
			if !ok {
				// connection was closed by upstream unexpectedly, try to reconnect
				if !reconnect() {
					return
				}
				continue
			}
			select {
			case w.userChan <- df:
			case <-w.ctx.Done():
				w.watcher.Close()
				return
			}
		}
	}
}
// httpWatcher decodes frames read from a streaming HTTP response and delivers
// them on ch.
type httpWatcher struct {
// ctx/cancel control the watcher's lifetime; once ctx is done, run closes reader.
ctx context.Context
cancel context.CancelFunc
// reader yields encoded frames. run treats a nil-error Read as the end of a frame and
// io.ErrShortBuffer as a partial frame (Watch wraps the body in a length-delimited frame reader).
reader io.ReadCloser
decoder types.DataFrameDecoder
// ch carries decoded frames to the consumer; run closes it when it returns.
ch chan types.DataFrame
}
// newHTTPWatcher wraps reader in an httpWatcher and starts its read loop in a
// new goroutine. Decoded frames go to a channel buffered for 100 entries.
func newHTTPWatcher(ctx context.Context, cancel context.CancelFunc, reader io.ReadCloser, decoder types.DataFrameDecoder) *httpWatcher {
	const frameBuffer = 100
	hw := new(httpWatcher)
	hw.ctx, hw.cancel = ctx, cancel
	hw.reader = reader
	hw.decoder = decoder
	hw.ch = make(chan types.DataFrame, frameBuffer)
	go hw.run()
	return hw
}
// FrameChan returns the receive-only view of the decoded-frame channel. The
// channel is closed when the read loop stops.
func (h *httpWatcher) FrameChan() <-chan types.DataFrame {
	var frames <-chan types.DataFrame = h.ch
	return frames
}
// Close cancels the watcher's context and closes the reader directly. The
// goroutine started by run also closes the reader after cancellation.
func (h *httpWatcher) Close() {
	stop := h.cancel
	stop()
	_ = h.reader.Close()
}
// run reads frames from h.reader, decodes each into a DataFrame and delivers it
// on h.ch.
//
// Chunks returned together with io.ErrShortBuffer are accumulated until the
// frame is complete. A frame that fails to decode is delivered with Message
// describing the error, matching websocketWatcher. Previously such a frame
// silently ended the stream.
//
// run stops when a Read returns no data or fails, or when h.ctx is cancelled.
// Deliveries give up on cancellation, so an abandoned consumer cannot block
// this goroutine forever. On return, h.ch is closed and h.ctx is cancelled.
func (h *httpWatcher) run() {
	go func() {
		<-h.ctx.Done()
		h.reader.Close()
	}()
	defer h.cancel()
	defer close(h.ch)
	// emit delivers df unless the watcher is cancelled first.
	emit := func(df types.DataFrame) bool {
		select {
		case h.ch <- df:
			return true
		case <-h.ctx.Done():
			return false
		}
	}
	rbuf := [4096]byte{}
	buf := bytes.NewBuffer(nil)
	for {
		n, err := h.reader.Read(rbuf[:])
		if n == 0 {
			return
		}
		buf.Write(rbuf[:n])
		if err == io.ErrShortBuffer {
			// the current frame is larger than rbuf; keep accumulating it
			continue
		}
		if err != nil {
			return
		}
		df := types.DataFrame{}
		if err := h.decoder.Decode(buf.Bytes(), &df); err != nil {
			// Presumably the stream stays aligned because frames are
			// length-delimited, so the next frame can still be read.
			df.Message = fmt.Sprintf("failed to decode message: %v", err)
		}
		buf.Reset()
		if !emit(df) {
			return
		}
	}
}
// Watch opens a stream of frames from the queue, starting at index.
//
// The request carries _index_, _window_, _index_only_, _auto_commit_ and
// _watch_=true. When q.WebsocketWatch is set, the stream uses websocket with
// automatic reconnection: ws for http endpoints, wss for https ones. Otherwise
// it reads length-delimited frames from a long-lived HTTP response.
//
// The returned watcher runs on a child of ctx; call Close on it to stop. On
// error, the child context is cancelled and no response body is left open.
func (q *QueueClient) Watch(ctx context.Context, index, window uint64, indexOnly bool, autocommit bool) (types.Watcher, error) {
	ctx, cancel := context.WithCancel(ctx)
	u := *q.baseUrl
	eq := u.Query()
	eq.Set("_index_", strconv.FormatUint(index, 10))
	eq.Set("_window_", strconv.FormatUint(window, 10))
	eq.Set("_index_only_", boolString(indexOnly))
	eq.Set("_auto_commit_", boolString(autocommit))
	eq.Set("_watch_", boolString(true))
	u.RawQuery = eq.Encode()
	if q.WebsocketWatch {
		// use websocket watch. Keep TLS when the queue is served over https;
		// plain ws would either fail the handshake or send the token in clear.
		u.Scheme = "ws"
		if q.baseUrl.Scheme == "https" {
			u.Scheme = "wss"
		}
		config, err := websocket.NewConfig(u.String(), q.baseUrl.String())
		if err != nil {
			cancel()
			return nil, err
		}
		attr, err := q.getAttr(true)
		if err != nil {
			cancel()
			return nil, err
		}
		uidHeader := attr[types.UserIdentifyHeader]
		gidHeader := attr[types.GroupIdentifyHeader]
		// Same check as withIdentity: an empty header name would make an invalid handshake.
		if len(uidHeader) == 0 {
			cancel()
			return nil, fmt.Errorf("malformed attributes: %v", attr)
		}
		// set websocket request headers.
		header := http.Header{}
		header.Set(uidHeader, q.user.Uid())
		header.Set("Accept", q.DCodec.MediaType())
		header.Set(HeaderAuthorization, q.user.Token())
		if len(gidHeader) > 0 {
			header.Set(gidHeader, q.user.Gid())
		}
		config.Header = header
		watcher, err := newReconnectWatcher(ctx, cancel, config, q.DCodec)
		if err != nil {
			cancel()
			return nil, err
		}
		return watcher, nil
	}
	// default http watch.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
	if err != nil {
		cancel()
		return nil, err
	}
	req.Header.Set("Accept", q.DCodec.MediaType())
	if err := q.withIdentity(req); err != nil {
		// This path used to return without cancel(), leaking the child context.
		cancel()
		return nil, err
	}
	q.withAuthorization(req)
	resp, err := q.httpClient.Do(req)
	if err != nil {
		cancel()
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		// Read before cancelling, because cancelling the request context can
		// abort the body read. The body is best-effort: it only enriches the error.
		content, _ := ioutil.ReadAll(resp.Body)
		resp.Body.Close()
		cancel()
		return nil, fmt.Errorf("unexpected status code: %d, message: %s", resp.StatusCode, string(content))
	}
	reader := types.NewLengthDelimitedFrameReader(resp.Body)
	return newHTTPWatcher(ctx, cancel, reader, q.DCodec), nil
}
// Commit sends a commit for the given indexes: PUT on the queue URL with
// _indexes_ set to the comma-separated indexes. Any status other than 200 is
// returned as an error that includes the response body.
func (q *QueueClient) Commit(ctx context.Context, indexes ...uint64) error {
	parts := make([]string, 0, len(indexes))
	for _, idx := range indexes {
		parts = append(parts, strconv.FormatUint(idx, 10))
	}
	target := *q.baseUrl // copy so the shared base URL is never modified
	params := target.Query()
	params.Set("_indexes_", strings.Join(parts, ","))
	target.RawQuery = params.Encode()
	endpoint := target.String()

	req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, nil)
	if err != nil {
		return err
	}
	if err = q.withIdentity(req); err != nil {
		return err
	}
	q.withAuthorization(req)

	resp, err := q.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusOK {
		return nil
	}
	return fmt.Errorf("visiting: %s, unexpected status code %d, message: %s", endpoint, resp.StatusCode, readMessage(resp.Body))
}
// Del deletes the frames at the given indexes: DELETE on the queue URL with
// _indexes_ set to the comma-separated indexes. Any status other than 200 is
// returned as an error that includes the response body.
func (q *QueueClient) Del(ctx context.Context, indexes ...uint64) error {
	joined := make([]string, len(indexes))
	for i, idx := range indexes {
		joined[i] = strconv.FormatUint(idx, 10)
	}
	target := *q.baseUrl // copy so the shared base URL is never modified
	params := target.Query()
	params.Set("_indexes_", strings.Join(joined, ","))
	target.RawQuery = params.Encode()
	endpoint := target.String()

	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil)
	if err != nil {
		return err
	}
	if err = q.withIdentity(req); err != nil {
		return err
	}
	q.withAuthorization(req)

	resp, err := q.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusOK {
		return nil
	}
	return fmt.Errorf("visiting: %s, unexpected status code %d, message: %s", endpoint, resp.StatusCode, readMessage(resp.Body))
}
// Attributes returns the queue attributes. It calls getAttr with force=true,
// so an already-cached copy is re-fetched from the server.
func (q *QueueClient) Attributes() (types.Attributes, error) {
	const refresh = true
	attrs, err := q.getAttr(refresh)
	return attrs, err
}