graphql/handler/transport/websocket.go (416 lines of code) (raw):
package transport
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/errcode"
)
type (
	// Websocket is a graphql.Transport that serves GraphQL operations over a
	// websocket connection, speaking either the graphql-ws or the
	// graphql-transport-ws subprotocol.
	Websocket struct {
		// Upgrader performs the HTTP-to-websocket upgrade in Do.
		Upgrader websocket.Upgrader
		// InitFunc, if set, is called with the connection_init payload. The
		// context it returns replaces the connection context, a non-nil
		// *InitPayload is sent back as the connection_ack payload, and a
		// non-nil error is sent to the client as a connection error before the
		// connection is closed.
		InitFunc WebsocketInitFunc
		// InitTimeout, if non-zero, bounds how long the server waits for the
		// first client message before closing the connection with a protocol
		// error.
		InitTimeout time.Duration
		// ErrorFunc, if set, receives read and write errors wrapped in a
		// WebsocketError.
		ErrorFunc WebsocketErrorFunc
		// CloseFunc, if set, is called at most once per connection, after the
		// socket has been closed, with the close code that was sent.
		CloseFunc WebsocketCloseFunc
		// KeepAlivePingInterval, if non-zero, makes the server send a
		// keep-alive message at this interval on graphql-ws connections
		// (including connections that negotiated no subprotocol).
		KeepAlivePingInterval time.Duration
		// PongOnlyInterval, if non-zero, makes the server send an unsolicited
		// pong at this interval on graphql-transport-ws connections.
		PongOnlyInterval time.Duration
		// PingPongInterval, if non-zero, makes the server send a ping at this
		// interval on graphql-transport-ws connections. When the server sends
		// a ping it sets a ReadDeadline of PingPongInterval*2 (re-armed only
		// after a pong has been received), and if the client doesn't respond
		// with pong before that deadline is reached then the connection will
		// die with a 1006 error code.
		PingPongInterval time.Duration
		// MissingPongOk, if true, tells the server to not use a ReadDeadline,
		// such that a missing/slow pong response from the client doesn't kill
		// the connection.
		MissingPongOk bool
		// NOTE(review): presumably records whether injectGraphQLWSSubprotocols
		// (defined elsewhere) has already added the GraphQL subprotocols to
		// Upgrader — confirm against that method.
		didInjectSubprotocols bool
	}
	// wsConnection holds the state of a single websocket connection served by
	// Websocket.Do.
	wsConnection struct {
		Websocket
		// ctx is the connection context: the upgrade request's context, or the
		// context returned by InitFunc when one is configured.
		ctx context.Context
		// conn is the underlying websocket connection.
		conn *websocket.Conn
		// me encodes and decodes messages for the negotiated subprotocol.
		me messageExchanger
		// active maps operation IDs to the cancel functions of running
		// subscriptions. Guarded by mu.
		active map[string]context.CancelFunc
		// mu serialises writes to conn and guards active, receivedPong and
		// closed.
		mu sync.Mutex
		// keepAliveTicker drives keep-alive messages (graphql-ws).
		keepAliveTicker *time.Ticker
		// pongOnlyTicker drives unsolicited pongs (graphql-transport-ws).
		pongOnlyTicker *time.Ticker
		// pingPongTicker drives pings (graphql-transport-ws).
		pingPongTicker *time.Ticker
		// receivedPong records whether a pong arrived since the last ping
		// tick. Guarded by mu.
		receivedPong bool
		// exec creates and dispatches the GraphQL operations.
		exec graphql.GraphExecutor
		// closed is set once the connection has been closed. Guarded by mu.
		closed bool
		// headers are the upgrade request's headers, attached to every
		// operation's params.
		headers http.Header
		// initPayload is the decoded connection_init payload, if any.
		initPayload InitPayload
	}
	// WebsocketInitFunc is called with the connection_init payload. It returns
	// the context to use for the rest of the connection, an optional payload
	// for connection_ack, and an error that rejects the connection.
	WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, *InitPayload, error)
	// WebsocketErrorFunc is called with read and write errors, each wrapped in
	// a WebsocketError.
	WebsocketErrorFunc func(ctx context.Context, err error)
	// WebsocketCloseFunc is called when the websocket is closed, with the close
	// code that was sent to the client.
	WebsocketCloseFunc func(ctx context.Context, closeCode int)
)
// errReadTimeout is returned by nextMessageWithTimeout when no message arrives
// before the timeout elapses.
var errReadTimeout = errors.New("read timeout")

// WebsocketError wraps an error that occurred while reading from or writing to
// the websocket connection. It is the error value handed to ErrorFunc.
type WebsocketError struct {
	// Err is the underlying read or write error.
	Err error
	// IsReadError flags whether the error occurred on read or write to the websocket
	IsReadError bool
}

// Error implements the error interface, prefixing the wrapped error with the
// direction ("read" or "write") of the failed operation.
func (e WebsocketError) Error() string {
	if e.IsReadError {
		return fmt.Sprintf("websocket read: %v", e.Err)
	}
	return fmt.Sprintf("websocket write: %v", e.Err)
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can inspect
// it, e.g. to check for net.ErrClosed inside an ErrorFunc.
func (e WebsocketError) Unwrap() error {
	return e.Err
}
// Compile-time check that Websocket satisfies graphql.Transport.
var _ graphql.Transport = Websocket{}

// Compile-time check that WebsocketError satisfies the error interface.
var _ error = WebsocketError{}
// Supports reports whether r asks for a protocol upgrade, i.e. whether it
// carries any non-empty Upgrade header (not only "websocket").
func (t Websocket) Supports(r *http.Request) bool {
	upgrade := r.Header.Get("Upgrade")
	return len(upgrade) > 0
}
// Do upgrades the request to a websocket, selects the message exchanger for
// the negotiated subprotocol, performs the connection_init handshake and then
// serves the connection. It blocks for the lifetime of the connection.
func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
	t.injectGraphQLWSSubprotocols()
	ws, err := t.Upgrader.Upgrade(w, r, http.Header{})
	if err != nil {
		// Per the gorilla/websocket documentation, a failed Upgrade has already
		// replied to the client with an HTTP error response. Writing a second
		// response here would only cause a superfluous WriteHeader and append
		// a stray body to the one Upgrade wrote.
		log.Printf("unable to upgrade %T to websocket: %s", w, err.Error())
		return
	}

	var me messageExchanger
	switch ws.Subprotocol() {
	case graphqlwsSubprotocol, "":
		// clients are required to send a subprotocol, to be backward compatible with the previous implementation we select
		// "graphql-ws" by default
		me = graphqlwsMessageExchanger{c: ws}
	case graphqltransportwsSubprotocol:
		me = graphqltransportwsMessageExchanger{c: ws}
	default:
		msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol()))
		_ = ws.WriteMessage(websocket.CloseMessage, msg)
		// The connection has been hijacked from net/http, so nothing else will
		// release it: close it explicitly after the close frame.
		_ = ws.Close()
		return
	}

	conn := wsConnection{
		active:    map[string]context.CancelFunc{},
		conn:      ws,
		ctx:       r.Context(),
		exec:      exec,
		me:        me,
		headers:   r.Header,
		Websocket: t,
	}

	if !conn.init() {
		return
	}

	conn.run()
}
// handlePossibleError forwards err to ErrorFunc, wrapped in a WebsocketError
// that records whether it came from a read or a write. It is a no-op when err
// is nil or no ErrorFunc is configured.
func (c *wsConnection) handlePossibleError(err error, isReadError bool) {
	if err == nil || c.ErrorFunc == nil {
		return
	}
	wsErr := WebsocketError{IsReadError: isReadError, Err: err}
	c.ErrorFunc(c.ctx, wsErr)
}
// nextMessageWithTimeout reads the next message from the exchanger and gives
// up with errReadTimeout if nothing arrives within timeout. The read runs in
// its own goroutine; if the timeout wins, that goroutine keeps blocking until
// NextMessage returns, and its result goes into a buffered channel so the
// goroutine can still finish.
func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) {
	type readResult struct {
		msg message
		err error
	}
	done := make(chan readResult, 1)
	go func() {
		msg, err := c.me.NextMessage()
		done <- readResult{msg: msg, err: err}
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case res := <-done:
		if res.err != nil {
			return message{}, res.err
		}
		return res.msg, nil
	case <-timer.C:
		return message{}, errReadTimeout
	}
}
// init performs the connection_init handshake. It reads the first client
// message (waiting at most InitTimeout when that is non-zero), decodes the
// optional init payload, runs InitFunc, and replies with connection_ack
// followed by a keep-alive message.
//
// It returns true when the connection is ready to serve operations. On every
// failure path it sends a connection error where appropriate, closes the
// connection and returns false.
func (c *wsConnection) init() bool {
	var m message
	var err error

	if c.InitTimeout != 0 {
		m, err = c.nextMessageWithTimeout(c.InitTimeout)
	} else {
		m, err = c.me.NextMessage()
	}

	if err != nil {
		switch {
		case errors.Is(err, errReadTimeout):
			c.close(websocket.CloseProtocolError, "connection initialisation timeout")
			return false
		case errors.Is(err, errInvalidMsg):
			c.sendConnectionError("invalid json")
		}
		c.close(websocket.CloseProtocolError, "decoding error")
		return false
	}

	switch m.t {
	case initMessageType:
		if len(m.payload) > 0 {
			c.initPayload = make(InitPayload)
			if unmarshalErr := json.Unmarshal(m.payload, &c.initPayload); unmarshalErr != nil {
				// The connection has been hijacked from net/http, so it must be
				// closed here. Returning without closing it leaks the socket and
				// never runs CloseFunc.
				c.sendConnectionError("invalid json")
				c.close(websocket.CloseProtocolError, "decoding error")
				return false
			}
		}

		var initAckPayload *InitPayload
		if c.InitFunc != nil {
			var ctx context.Context
			ctx, initAckPayload, err = c.InitFunc(c.ctx, c.initPayload)
			if err != nil {
				c.sendConnectionError("%s", err.Error())
				c.close(websocket.CloseNormalClosure, "terminated")
				return false
			}
			c.ctx = ctx
		}

		if initAckPayload != nil {
			initJSONAckPayload, marshalErr := json.Marshal(*initAckPayload)
			if marshalErr != nil {
				panic(marshalErr)
			}
			c.write(&message{t: connectionAckMessageType, payload: initJSONAckPayload})
		} else {
			c.write(&message{t: connectionAckMessageType})
		}
		c.write(&message{t: keepAliveMessageType})
	case connectionCloseMessageType:
		c.close(websocket.CloseNormalClosure, "terminated")
		return false
	default:
		c.sendConnectionError("unexpected message %s", m.t)
		c.close(websocket.CloseProtocolError, "unexpected message")
		return false
	}

	return true
}
// write sends msg through the message exchanger. Sends are serialised by c.mu,
// so concurrent writers (subscriptions, heartbeat goroutines, the read loop)
// never interleave frames on the connection.
//
// Once the connection has been closed the message is dropped. The close frame
// has already been written and the socket closed, so a send could only fail
// and report a spurious write error. This mirrors run(), which does not report
// read errors caused by our own close.
//
// A send error is passed to ErrorFunc after c.mu is released, so a slow
// ErrorFunc cannot stall other writers, ping() or close().
func (c *wsConnection) write(msg *message) {
	c.mu.Lock()
	if c.closed {
		c.mu.Unlock()
		return
	}
	err := c.me.Send(msg)
	c.mu.Unlock()

	c.handlePossibleError(err, false)
}
// run serves an initialised connection. It starts the heartbeat goroutines
// configured for the negotiated subprotocol, then reads and dispatches client
// messages until a read fails, the client terminates the connection, or an
// unexpected message arrives.
//
// Returning cancels the derived context. That stops the heartbeat goroutines
// and, through closeOnCancel, closes the connection.
func (c *wsConnection) run() {
	ctx, cancel := context.WithCancel(c.ctx)
	defer cancel()

	// graphql-ws (also the default when no subprotocol was negotiated): send a
	// keep-alive message every interval.
	if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 {
		c.mu.Lock()
		c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
		c.mu.Unlock()

		go c.keepAlive(ctx)
	}

	// graphql-transport-ws: send just a pong message every interval.
	if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PongOnlyInterval != 0 {
		c.mu.Lock()
		c.pongOnlyTicker = time.NewTicker(c.PongOnlyInterval)
		c.mu.Unlock()

		go c.keepAlivePongOnly(ctx)
	}

	// graphql-transport-ws: send a ping every interval and expect a pong.
	if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 {
		c.mu.Lock()
		c.pingPongTicker = time.NewTicker(c.PingPongInterval)
		c.mu.Unlock()

		if !c.MissingPongOk {
			// Note: when the connection is closed by this deadline, the client
			// will receive an "invalid close code"
			_ = c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
		}

		go c.ping(ctx)
	}

	// Close the connection when the context is cancelled.
	// Will optionally send a "close reason" that is retrieved from the context.
	go c.closeOnCancel(ctx)

	for {
		start := graphql.Now()
		m, err := c.me.NextMessage()
		if err != nil {
			// If the connection got closed by us, don't report the error
			if !errors.Is(err, net.ErrClosed) {
				c.handlePossibleError(err, true)
			}
			return
		}

		switch m.t {
		case startMessageType:
			c.subscribe(start, &m)
		case stopMessageType:
			c.mu.Lock()
			closer := c.active[m.id]
			c.mu.Unlock()
			if closer != nil {
				closer()
			}
		case connectionCloseMessageType:
			c.close(websocket.CloseNormalClosure, "terminated")
			return
		case pingMessageType:
			c.write(&message{t: pongMessageType, payload: m.payload})
		case pongMessageType:
			c.mu.Lock()
			c.receivedPong = true
			// Clear the read deadline (the zero time disables it) while still
			// holding mu. If it were cleared after unlocking, ping() could arm
			// a fresh deadline and reset receivedPong in between. This call
			// would then wipe that deadline, and because ping() only re-arms
			// after a pong, an unresponsive client would never time out.
			_ = c.conn.SetReadDeadline(time.Time{})
			c.mu.Unlock()
		default:
			c.sendConnectionError("unexpected message %s", m.t)
			c.close(websocket.CloseProtocolError, "unexpected message")
			return
		}
	}
}
// keepAlivePongOnly sends an unsolicited pong on every tick of pongOnlyTicker
// until ctx is cancelled, and stops the ticker on the way out.
func (c *wsConnection) keepAlivePongOnly(ctx context.Context) {
	ticker := c.pongOnlyTicker
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			c.write(&message{t: pongMessageType, payload: json.RawMessage{}})
		case <-ctx.Done():
			return
		}
	}
}
// keepAlive sends a keep-alive message on every tick of keepAliveTicker until
// ctx is cancelled, and stops the ticker on the way out.
func (c *wsConnection) keepAlive(ctx context.Context) {
	ticker := c.keepAliveTicker
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			c.write(&message{t: keepAliveMessageType})
		case <-ctx.Done():
			return
		}
	}
}
// ping sends a ping on every tick of pingPongTicker until ctx is cancelled.
// After each ping it re-arms the read deadline to 2*PingPongInterval, but only
// when a pong has arrived since the previous tick and MissingPongOk is false.
// Otherwise the deadline armed earlier (initially in run) is left to expire.
// receivedPong is reset on every tick.
func (c *wsConnection) ping(ctx context.Context) {
	ticker := c.pingPongTicker
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		c.write(&message{t: pingMessageType, payload: json.RawMessage{}})

		c.mu.Lock()
		if c.receivedPong && !c.MissingPongOk {
			_ = c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
		}
		c.receivedPong = false
		c.mu.Unlock()
	}
}
// closeOnCancel blocks until ctx is done. It then sends a connection error
// carrying the close reason stored in ctx (when there is one) and closes the
// connection with a normal closure.
func (c *wsConnection) closeOnCancel(ctx context.Context) {
	<-ctx.Done()

	reason := closeReasonForContext(ctx)
	if reason != "" {
		c.sendConnectionError("%s", reason)
	}

	c.close(websocket.CloseNormalClosure, "terminated")
}
func (c *wsConnection) subscribe(start time.Time, msg *message) {
ctx := graphql.StartOperationTrace(c.ctx)
var params *graphql.RawParams
if err := jsonDecode(bytes.NewReader(msg.payload), ¶ms); err != nil {
c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"})
c.complete(msg.id)
return
}
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
params.Headers = c.headers
rc, err := c.exec.CreateOperationContext(ctx, params)
if err != nil {
resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
switch errcode.GetErrorKind(err) {
case errcode.KindProtocol:
c.sendError(msg.id, resp.Errors...)
default:
c.sendResponse(msg.id, &graphql.Response{Errors: err})
}
c.complete(msg.id)
return
}
ctx = graphql.WithOperationContext(ctx, rc)
if c.initPayload != nil {
ctx = withInitPayload(ctx, c.initPayload)
}
ctx, cancel := context.WithCancel(ctx)
c.mu.Lock()
c.active[msg.id] = cancel
c.mu.Unlock()
go func() {
ctx = withSubscriptionErrorContext(ctx)
defer func() {
if r := recover(); r != nil {
err := rc.Recover(ctx, r)
var gqlerr *gqlerror.Error
if !errors.As(err, &gqlerr) {
gqlerr = &gqlerror.Error{}
if err != nil {
gqlerr.Message = err.Error()
}
}
c.sendError(msg.id, gqlerr)
}
if errs := getSubscriptionError(ctx); len(errs) != 0 {
c.sendError(msg.id, errs...)
} else {
c.complete(msg.id)
}
c.mu.Lock()
delete(c.active, msg.id)
c.mu.Unlock()
cancel()
}()
responses, ctx := c.exec.DispatchOperation(ctx, rc)
for {
response := responses(ctx)
if response == nil {
break
}
c.sendResponse(msg.id, response)
}
// complete and context cancel comes from the defer
}()
}
// sendResponse marshals response and sends it as a data message for the
// operation id. It panics if the response cannot be marshalled.
func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
	payload, err := json.Marshal(response)
	if err != nil {
		panic(err)
	}

	out := message{t: dataMessageType, id: id, payload: payload}
	c.write(&out)
}
// complete sends a complete message for the operation identified by id.
func (c *wsConnection) complete(id string) {
	out := message{t: completeMessageType, id: id}
	c.write(&out)
}
// sendError sends gqlErrs, encoded as a JSON array, in an error message for the
// operation id. It panics if the errors cannot be marshalled.
func (c *wsConnection) sendError(id string, gqlErrs ...*gqlerror.Error) {
	asErrors := make([]error, 0, len(gqlErrs))
	for _, gqlErr := range gqlErrs {
		asErrors = append(asErrors, gqlErr)
	}

	payload, err := json.Marshal(asErrors)
	if err != nil {
		panic(err)
	}

	c.write(&message{t: errorMessageType, id: id, payload: payload})
}
// sendConnectionError formats a message with fmt.Sprintf semantics and sends it
// to the client as a connection-level error (not tied to any operation id).
// It panics if the error cannot be marshalled.
func (c *wsConnection) sendConnectionError(format string, args ...any) {
	connErr := &gqlerror.Error{Message: fmt.Sprintf(format, args...)}

	payload, err := json.Marshal(connErr)
	if err != nil {
		panic(err)
	}

	c.write(&message{t: connectionErrorMessageType, payload: payload})
}
// close shuts the connection down: it writes a close frame carrying closeCode
// and reason, cancels every active subscription, marks the connection closed,
// closes the socket and finally invokes CloseFunc (if set). Only the first call
// has any effect; later calls return immediately.
func (c *wsConnection) close(closeCode int, reason string) {
	c.mu.Lock()
	if c.closed {
		c.mu.Unlock()
		return
	}

	// Best effort: the peer may already be gone, so a failed close frame is
	// not treated as an error.
	frame := websocket.FormatCloseMessage(closeCode, reason)
	_ = c.conn.WriteMessage(websocket.CloseMessage, frame)

	for _, cancelOperation := range c.active {
		cancelOperation()
	}
	c.closed = true
	c.mu.Unlock()

	_ = c.conn.Close()

	if c.CloseFunc == nil {
		return
	}
	c.CloseFunc(c.ctx, closeCode)
}