router/core/websocket.go (988 lines of code) (raw):
package core
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"regexp"
"slices"
"sync"
"syscall"
"time"
"github.com/buger/jsonparser"
"github.com/go-chi/chi/v5/middleware"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/gorilla/websocket"
"github.com/tidwall/gjson"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/atomic"
"go.uber.org/zap"
"golang.org/x/sync/semaphore"
"github.com/wundergraph/astjson"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
"github.com/wundergraph/graphql-go-tools/v2/pkg/netpoll"
"github.com/wundergraph/cosmo/router/internal/expr"
"github.com/wundergraph/cosmo/router/internal/wsproto"
"github.com/wundergraph/cosmo/router/pkg/authentication"
"github.com/wundergraph/cosmo/router/pkg/config"
"github.com/wundergraph/cosmo/router/pkg/logging"
"github.com/wundergraph/cosmo/router/pkg/statistics"
rtrace "github.com/wundergraph/cosmo/router/pkg/trace"
)
var (
// errClientTerminatedConnection is returned by HandleMessage when the client
// sends a terminate message. The read loops check for it with errors.Is and
// stop serving the connection when they see it.
errClientTerminatedConnection = errors.New("client terminated connection")
)
// WebsocketMiddlewareOptions bundles the dependencies and settings that
// NewWebsocketMiddleware copies into the WebsocketHandler it creates.
type WebsocketMiddlewareOptions struct {
OperationProcessor *OperationProcessor
OperationBlocker *OperationBlocker
Planner *OperationPlanner
GraphQLHandler *GraphQLHandler
PreHandler *PreHandler
Metrics RouterMetrics
AccessController *AccessController
Logger *zap.Logger
Stats statistics.EngineStatistics
// ReadTimeout is added to the current time to form the read deadline that is
// set on the connection before every message read.
ReadTimeout time.Duration
// EnableNetPoll asks the middleware to create a net poller. When the poller
// cannot be created, connections are served by one goroutine each instead.
EnableNetPoll bool
// NetPollTimeout and NetPollConnBufferSize are passed to netpoll.NewPoller.
NetPollTimeout time.Duration
NetPollConnBufferSize int
// WebSocketConfiguration carries the Absinthe, forwarding and authentication settings.
WebSocketConfiguration *config.WebSocketConfiguration
ClientHeader config.ClientHeader
Attributes []attribute.KeyValue
DisableVariablesRemapping bool
ApolloCompatibilityFlags config.ApolloCompatibilityFlags
}
// NewWebsocketMiddleware returns an HTTP middleware that serves WebSocket
// upgrade requests with a WebsocketHandler built from opts. Requests that are
// not WebSocket upgrades are passed through to the next handler untouched.
//
// opts.WebSocketConfiguration is optional: when it is nil, the Absinthe
// protocol and the forwarding of upgrade headers / query params stay disabled.
//
// Allow-list entries that consist only of letters, digits, '_' and '-' are
// treated as literal names (header names are canonicalized); every other
// entry is compiled as a regular expression. Entries that fail to compile are
// logged and skipped.
//
// When opts.EnableNetPoll is set and a poller can be created, a single
// goroutine running runPoller serves all connections until ctx is done.
func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions) func(http.Handler) http.Handler {
	handler := &WebsocketHandler{
		ctx:                       ctx,
		operationProcessor:        opts.OperationProcessor,
		operationBlocker:          opts.OperationBlocker,
		planner:                   opts.Planner,
		graphqlHandler:            opts.GraphQLHandler,
		preHandler:                opts.PreHandler,
		metrics:                   opts.Metrics,
		accessController:          opts.AccessController,
		logger:                    opts.Logger,
		stats:                     opts.Stats,
		readTimeout:               opts.ReadTimeout,
		config:                    opts.WebSocketConfiguration,
		clientHeader:              opts.ClientHeader,
		handlerSem:                semaphore.NewWeighted(128),
		attributes:                opts.Attributes,
		disableVariablesRemapping: opts.DisableVariablesRemapping,
		apolloCompatibilityFlags:  opts.ApolloCompatibilityFlags,
	}

	// Guard every configuration section with the same nil check. Previously only
	// the Absinthe section was guarded and a nil configuration caused a panic.
	if wsConfig := opts.WebSocketConfiguration; wsConfig != nil {
		if wsConfig.AbsintheProtocol.Enabled {
			handler.absintheHandlerEnabled = true
			handler.absintheHandlerPath = wsConfig.AbsintheProtocol.HandlerPath
		}

		if wsConfig.ForwardUpgradeHeaders.Enabled {
			handler.forwardUpgradeHeadersConfig.enabled = true
			for _, str := range wsConfig.ForwardUpgradeHeaders.AllowList {
				if detectNonRegex.MatchString(str) {
					// Header keys are stored canonicalized so they can be compared
					// directly against the keys of http.Header.
					canonicalHeaderKey := http.CanonicalHeaderKey(str)
					handler.forwardUpgradeHeadersConfig.staticAllowList = append(handler.forwardUpgradeHeadersConfig.staticAllowList, canonicalHeaderKey)
					continue
				}
				re, err := regexp.Compile(str)
				if err != nil {
					opts.Logger.Warn("Invalid regex in forward upgrade headers allow list", zap.String("regex", str), zap.Error(err))
					continue
				}
				handler.forwardUpgradeHeadersConfig.regexAllowList = append(handler.forwardUpgradeHeadersConfig.regexAllowList, re)
			}
			handler.forwardUpgradeHeadersConfig.withStaticAllowList = len(handler.forwardUpgradeHeadersConfig.staticAllowList) > 0
			handler.forwardUpgradeHeadersConfig.withRegexAllowList = len(handler.forwardUpgradeHeadersConfig.regexAllowList) > 0
		}

		if wsConfig.ForwardUpgradeQueryParams.Enabled {
			handler.forwardQueryParamsConfig.enabled = true
			for _, str := range wsConfig.ForwardUpgradeQueryParams.AllowList {
				if detectNonRegex.MatchString(str) {
					handler.forwardQueryParamsConfig.staticAllowList = append(handler.forwardQueryParamsConfig.staticAllowList, str)
					continue
				}
				re, err := regexp.Compile(str)
				if err != nil {
					opts.Logger.Warn("Invalid regex in forward upgrade query params allow list", zap.String("regex", str), zap.Error(err))
					continue
				}
				handler.forwardQueryParamsConfig.regexAllowList = append(handler.forwardQueryParamsConfig.regexAllowList, re)
			}
			handler.forwardQueryParamsConfig.withStaticAllowList = len(handler.forwardQueryParamsConfig.staticAllowList) > 0
			handler.forwardQueryParamsConfig.withRegexAllowList = len(handler.forwardQueryParamsConfig.regexAllowList) > 0
		}
	}

	if opts.EnableNetPoll {
		poller, err := netpoll.NewPoller(opts.NetPollConnBufferSize, opts.NetPollTimeout)
		if err == nil {
			opts.Logger.Debug("Net poller is available")
			handler.netPoll = poller
			handler.connections = make(map[int]*WebSocketConnectionHandler)
			go handler.runPoller()
		} else {
			// Not fatal: handleUpgradeRequest falls back to one goroutine per
			// connection when no poller is set. Surface the reason for debugging.
			opts.Logger.Debug("Net poller is not available, handling connections synchronously", zap.Error(err))
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !websocket.IsWebSocketUpgrade(r) {
				next.ServeHTTP(w, r)
				return
			}
			handler.handleUpgradeRequest(w, r)
		})
	}
}
// wsConnectionWrapper wraps the net.Conn of an upgraded WebSocket connection.
// Its write methods and Close take mu, which allows writing from multiple
// goroutines. ReadJSON does not take the mutex.
type wsConnectionWrapper struct {
conn net.Conn
mu sync.Mutex
}
// newWSConnectionWrapper returns a wrapper around conn with an unlocked mutex.
func newWSConnectionWrapper(conn net.Conn) *wsConnectionWrapper {
return &wsConnectionWrapper{
conn: conn,
}
}
// ReadJSON reads the next client text message from the connection and decodes
// it as JSON into v. A read failure is returned as is; otherwise the result of
// the JSON decoding is returned.
func (c *wsConnectionWrapper) ReadJSON(v interface{}) error {
	frame, readErr := wsutil.ReadClientText(c.conn)
	if readErr == nil {
		return json.Unmarshal(frame, v)
	}
	return readErr
}
// WriteText sends text to the client as a server text message. The write is
// performed while holding the wrapper's mutex.
func (c *wsConnectionWrapper) WriteText(text string) error {
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteServerText(c.conn, []byte(text))
}
// WriteJSON encodes v as JSON and sends it to the client as a server text
// message. It returns the encoding error, if any, without touching the
// connection; otherwise it returns the result of the write.
func (c *wsConnectionWrapper) WriteJSON(v interface{}) error {
	// Encode before taking the lock: marshaling does not touch the connection,
	// so there is no reason to block other writers while it runs.
	data, err := json.Marshal(v)
	if err != nil {
		return err
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	return wsutil.WriteServerText(c.conn, data)
}
// Close closes the underlying connection while holding the wrapper's mutex,
// so it does not run in the middle of a WriteText or WriteJSON call.
func (c *wsConnectionWrapper) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.conn.Close()
}
// WebsocketHandler upgrades HTTP requests to WebSocket connections and
// dispatches the messages read from them. One instance is created per
// middleware by NewWebsocketMiddleware and shared by all connections.
type WebsocketHandler struct {
// ctx is the context passed to NewWebsocketMiddleware; the read loops stop when it is done.
ctx context.Context
config *config.WebSocketConfiguration
operationProcessor *OperationProcessor
operationBlocker *OperationBlocker
planner *OperationPlanner
graphqlHandler *GraphQLHandler
preHandler *PreHandler
metrics RouterMetrics
accessController *AccessController
logger *zap.Logger
// netPoll is nil when the net poller is disabled or could not be created.
netPoll netpoll.Poller
// connections maps a socket file descriptor to its connection handler.
// It is only allocated when netPoll is set and is guarded by connectionsMu.
connections map[int]*WebSocketConnectionHandler
connectionsMu sync.RWMutex
// handlerSem is acquired around the execution of each subscribe message.
handlerSem *semaphore.Weighted
// connectionIDs is incremented to produce the ID of each new connection.
connectionIDs atomic.Int64
stats statistics.EngineStatistics
attributes []attribute.KeyValue
// readTimeout is used to compute the read deadline set before each message read.
readTimeout time.Duration
absintheHandlerEnabled bool
absintheHandlerPath string
forwardUpgradeHeadersConfig forwardConfig
forwardQueryParamsConfig forwardConfig
clientHeader config.ClientHeader
disableVariablesRemapping bool
apolloCompatibilityFlags config.ApolloCompatibilityFlags
}
// handleUpgradeRequest upgrades the HTTP request to a WebSocket connection,
// negotiates the subprotocol, initializes the connection handler and hands
// the connection over to either the net poller or a dedicated goroutine.
//
// Access control runs before the upgrade unless authentication from the
// initial payload is enabled; in that case it runs after the initial payload
// has been read and stored in the request context.
//
// Every failure after a successful upgrade closes the connection.
func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.Request) {
	var (
		subProtocol string
	)

	requestID := middleware.GetReqID(r.Context())
	requestContext := getRequestContext(r.Context())
	requestLogger := h.logger.With(logging.WithRequestID(requestID), logging.WithTraceID(rtrace.GetTraceID(r.Context())))
	clientInfo := NewClientInfoFromRequest(r, h.clientHeader)

	// The configuration is optional elsewhere in this file, so do not assume it
	// is set here either.
	authFromInitialPayload := h.config != nil && h.config.Authentication.FromInitialPayload.Enabled

	if h.accessController != nil && !authFromInitialPayload {
		// Check access control before upgrading the connection
		validatedReq, err := h.accessController.Access(w, r)
		if err != nil {
			statusCode := http.StatusForbidden
			if errors.Is(err, ErrUnauthorized) {
				statusCode = http.StatusUnauthorized
			}
			http.Error(w, http.StatusText(statusCode), statusCode)
			return
		}
		r = validatedReq

		if requestContext != nil {
			requestContext.expressionContext.Request.Auth = expr.LoadAuth(r.Context())
		}
	}

	upgrader := ws.HTTPUpgrader{
		Timeout: time.Second * 5,
		Protocol: func(s string) bool {
			if wsproto.IsSupportedSubprotocol(s) {
				subProtocol = s
				return true
			}
			return false
		},
	}
	c, _, _, err := upgrader.Upgrade(r, w)
	if err != nil {
		requestLogger.Warn("Websocket upgrade", zap.Error(err))
		// The connection is not guaranteed to be set when the upgrade fails;
		// closing a nil net.Conn would panic.
		if c != nil {
			_ = c.Close()
		}
		return
	}

	// legacy absinthe clients don't set the Sec-WebSocket-Protocol header (Subprotocol)
	// so we need to check the path to determine if it's an absinthe client and set the subprotocol manually
	if subProtocol == "" && h.absintheHandlerEnabled && r.URL.Path == h.absintheHandlerPath {
		subProtocol = wsproto.AbsintheWSSubProtocol
	}

	// After successful upgrade, we can't write to the response writer anymore
	// because it's hijacked by the websocket connection

	conn := newWSConnectionWrapper(c)
	protocol, err := wsproto.NewProtocol(subProtocol, conn)
	if err != nil {
		requestLogger.Error("Create websocket protocol", zap.Error(err))
		_ = c.Close()
		return
	}

	// We can parse the request options before creating the handler
	// this avoids touching the client request across goroutines
	executionOptions, traceOptions, err := h.preHandler.parseRequestOptions(r, clientInfo, requestLogger)
	if err != nil {
		requestLogger.Error("Parse request options", zap.Error(err))
		_ = c.Close()
		return
	}

	planOptions := PlanOptions{
		ClientInfo:           clientInfo,
		TraceOptions:         traceOptions,
		ExecutionOptions:     executionOptions,
		TrackSchemaUsageInfo: h.preHandler.trackSchemaUsageInfo,
	}

	handler := NewWebsocketConnectionHandler(h.ctx, WebSocketConnectionHandlerOptions{
		OperationProcessor:        h.operationProcessor,
		OperationBlocker:          h.operationBlocker,
		Planner:                   h.planner,
		GraphQLHandler:            h.graphqlHandler,
		PreHandler:                h.preHandler,
		Metrics:                   h.metrics,
		PlanOptions:               planOptions,
		ResponseWriter:            w,
		Request:                   r,
		Connection:                conn,
		Protocol:                  protocol,
		Logger:                    requestLogger,
		Stats:                     h.stats,
		ConnectionID:              h.connectionIDs.Inc(),
		ClientInfo:                clientInfo,
		InitRequestID:             requestID,
		Config:                    h.config,
		ForwardUpgradeHeaders:     h.forwardUpgradeHeadersConfig,
		ForwardQueryParams:        h.forwardQueryParamsConfig,
		Attributes:                h.attributes,
		DisableVariablesRemapping: h.disableVariablesRemapping,
		ApolloCompatibilityFlags:  h.apolloCompatibilityFlags,
	})
	err = handler.Initialize()
	if err != nil {
		// Don't produce errors logs here because it can only be client side errors
		// e.g. slow client, aborted connection, invalid JSON, etc.
		// We log it as debug because it's not a server side error
		requestLogger.Debug("Initializing websocket connection", zap.Error(err))
		handler.Close()
		return
	}

	// Authenticate the connection using the initial payload
	if authFromInitialPayload {
		fromInitialPayloadConfig := h.config.Authentication.FromInitialPayload

		// Setting the initialPayload in the context to be used by the websocketInitialPayloadAuthenticator
		r = r.WithContext(authentication.WithWebsocketInitialPayloadContextKey(r.Context(), handler.initialPayload))

		// Later check access control after initial payload is read and set into the context
		if h.accessController != nil {
			handler.request, err = h.accessController.Access(w, r)
			if err != nil {
				statusCode := http.StatusForbidden
				if errors.Is(err, ErrUnauthorized) {
					statusCode = http.StatusUnauthorized
				}
				// NOTE(review): the connection is already hijacked at this point, so
				// this cannot reach the client as an HTTP response; kept in case a
				// wrapping ResponseWriter records the status code. The client is
				// informed through the websocket error message below.
				http.Error(handler.w, http.StatusText(statusCode), statusCode)
				_ = handler.writeErrorMessage(requestID, err)
				handler.Close()
				return
			}
		}

		// Export the token from the initial payload to the request header
		if fromInitialPayloadConfig.ExportToken.Enabled {
			var initialPayloadMap map[string]interface{}
			err := json.Unmarshal(handler.initialPayload, &initialPayloadMap)
			if err != nil {
				// zap messages are not format strings; the error is attached as a field.
				requestLogger.Error("Error parsing initial payload", zap.Error(err))
				_ = handler.writeErrorMessage(requestID, err)
				handler.Close()
				return
			}
			jwtToken, ok := initialPayloadMap[fromInitialPayloadConfig.Key].(string)
			if !ok {
				err := fmt.Errorf("invalid JWT token in initial payload: JWT token is not a string")
				requestLogger.Error(err.Error())
				_ = handler.writeErrorMessage(requestID, err)
				handler.Close()
				return
			}
			handler.request.Header.Set(fromInitialPayloadConfig.ExportToken.HeaderKey, jwtToken)
		}

		if requestContext != nil {
			requestContext.expressionContext.Request.Auth = expr.LoadAuth(handler.request.Context())
		}
	}

	// Only when epoll/kqueue is available. On Windows, epoll is not available
	if h.netPoll != nil {
		err = h.addConnection(c, handler)
		if err != nil {
			requestLogger.Error("Adding connection to net poller", zap.Error(err))
			handler.Close()
		}
		return
	}

	// Handle messages sync when net poller implementation is not available

	go h.handleConnectionSync(handler)
}
// handleConnectionSync serves a single connection on the calling goroutine.
// It is used when no net poller is available. The loop ends, and the
// connection is closed, when the handler context is done, when the read
// deadline cannot be set, when reading fails with anything other than a
// timeout, or when the client terminates the connection.
func (h *WebsocketHandler) handleConnectionSync(handler *WebSocketConnectionHandler) {
	h.stats.ConnectionsInc()
	defer h.stats.ConnectionsDec()

	shutdown := h.ctx.Done()
	defer handler.Close()

	for {
		// Non-blocking shutdown check before every read attempt.
		select {
		case <-shutdown:
			return
		default:
		}

		// The read deadline bounds the blocking read below so that the shutdown
		// check above is reached again even when the client stays silent.
		if deadlineErr := handler.conn.conn.SetReadDeadline(time.Now().Add(h.readTimeout)); deadlineErr != nil {
			h.logger.Debug("Setting read deadline", zap.Error(deadlineErr))
			return
		}

		msg, readErr := handler.protocol.ReadMessage()
		if readErr != nil {
			if isReadTimeout(readErr) {
				// Nothing arrived within the deadline; try again.
				continue
			}
			h.logger.Debug("Client closed connection")
			return
		}

		handleErr := h.HandleMessage(handler, msg)
		if handleErr == nil {
			continue
		}
		h.logger.Debug("Handling websocket message", zap.Error(handleErr))
		if errors.Is(handleErr, errClientTerminatedConnection) {
			return
		}
	}
}
// addConnection registers handler under the socket file descriptor of conn
// and adds conn to the net poller. It returns an error when the file
// descriptor cannot be determined or when the poller rejects the connection.
//
// On failure nothing is left behind: the connection counter is not
// incremented and the handler is not kept in the connections map, so the
// caller only has to close the handler.
func (h *WebsocketHandler) addConnection(conn net.Conn, handler *WebSocketConnectionHandler) error {
	fd := socketFd(conn)
	if fd == 0 {
		return fmt.Errorf("unable to get socket fd for conn: %d", handler.connectionID)
	}

	h.connectionsMu.Lock()
	defer h.connectionsMu.Unlock()

	// Register the handler before adding the connection to the poller so that
	// the poller can resolve it as soon as the first event arrives.
	h.connections[fd] = handler
	if err := h.netPoll.Add(conn); err != nil {
		delete(h.connections, fd)
		return err
	}

	// Count the connection only once it is fully registered; removeConnection
	// is the counterpart that decrements the counter. This happens before the
	// lock is released, i.e. before the poller can look the handler up.
	h.stats.ConnectionsInc()
	return nil
}
// removeConnection decrements the connection counter, drops the handler
// registered under fd, removes conn from the net poller and closes the
// handler. A failure to remove the connection from the poller is only logged.
func (h *WebsocketHandler) removeConnection(conn net.Conn, handler *WebSocketConnectionHandler, fd int) {
	h.stats.ConnectionsDec()

	// Hold the lock only for the map update.
	func() {
		h.connectionsMu.Lock()
		defer h.connectionsMu.Unlock()
		delete(h.connections, fd)
	}()

	if removeErr := h.netPoll.Remove(conn); removeErr != nil {
		h.logger.Warn("Removing connection from net poller", zap.Error(removeErr))
	}
	handler.Close()
}
// socketFd returns the file descriptor of conn, or 0 when it cannot be
// determined. Connections implementing syscall.Conn are asked through their
// raw connection; otherwise a netpoll.ConnImpl reports its own descriptor.
func socketFd(conn net.Conn) int {
	switch typed := conn.(type) {
	case syscall.Conn:
		rawConn, err := typed.SyscallConn()
		if err != nil {
			return 0
		}
		var descriptor int
		// The error of Control is ignored; descriptor stays 0 unless the callback ran.
		_ = rawConn.Control(func(fd uintptr) {
			descriptor = int(fd)
		})
		return descriptor
	case netpoll.ConnImpl:
		return typed.GetFD()
	default:
		return 0
	}
}
// isReadTimeout reports whether err is, or wraps, a net.Error whose Timeout
// method returns true. It returns false for a nil error.
func isReadTimeout(err error) bool {
	var timeoutErr net.Error
	return err != nil && errors.As(err, &timeoutErr) && timeoutErr.Timeout()
}
// runPoller is the event loop used when a net poller is available. It waits
// for connections with pending data, looks up the handler registered for the
// connection's file descriptor, reads one message and dispatches it through
// HandleMessage. The loop stops when the handler context is done; the poller
// is closed on the way out.
//
// Reading and message handling happen on this goroutine, one connection
// after the other.
func (h *WebsocketHandler) runPoller() {
done := h.ctx.Done()
defer func() {
// Close the poller under the connections lock so that it does not overlap
// with addConnection, which holds the same lock while adding to the poller.
h.connectionsMu.Lock()
_ = h.netPoll.Close(true)
h.connectionsMu.Unlock()
}()
for {
select {
case <-done:
return
default:
connections, err := h.netPoll.Wait(128)
if err != nil {
// NOTE(review): the loop retries immediately; whether Wait blocks or backs
// off after an error is not visible here.
h.logger.Warn("Net Poller wait", zap.Error(err))
continue
}
for i := 0; i < len(connections); i++ {
if connections[i] == nil {
continue
}
// Unchecked type assertion: panics if the poller ever returns a different type.
conn := connections[i].(netpoll.ConnImpl)
// check if the connection is still valid
fd := socketFd(conn)
h.connectionsMu.RLock()
handler, exists := h.connections[fd]
h.connectionsMu.RUnlock()
if !exists {
continue
}
// NOTE(review): addConnection refuses fd 0, so with the lookup above this
// branch looks unreachable unless key 0 is inserted elsewhere - confirm.
if fd == 0 {
h.logger.Debug("Invalid socket fd", zap.Int("fd", fd))
h.removeConnection(conn, handler, fd)
continue
}
err = handler.conn.conn.SetReadDeadline(time.Now().Add(h.readTimeout))
if err != nil {
h.logger.Debug("Setting read deadline", zap.Error(err))
h.removeConnection(conn, handler, fd)
continue
}
// Unlike handleConnectionSync, every read error (including a timeout)
// removes the connection here.
msg, err := handler.protocol.ReadMessage()
if err != nil {
h.logger.Debug("Client closed connection")
h.removeConnection(conn, handler, fd)
continue
}
err = h.HandleMessage(handler, msg)
if err != nil {
h.logger.Debug("Handling websocket message", zap.Error(err))
// Only an explicit client termination removes the connection; other
// handling errors are logged and the connection stays registered.
if errors.Is(err, errClientTerminatedConnection) {
h.removeConnection(conn, handler, fd)
continue
}
}
}
}
}
}
// websocketResponseWriter buffers the response of a single operation and
// sends it to the client through the websocket protocol when flushed.
type websocketResponseWriter struct {
// id is the operation ID that every outgoing message is sent under.
id string
protocol wsproto.Proto
// header is sent as "response_headers" in the message extensions when it is not empty.
header http.Header
// buf collects written bytes until the next Flush.
buf bytes.Buffer
// writtenBytes is the running total of all bytes passed to Write.
writtenBytes int
logger *zap.Logger
stats statistics.EngineStatistics
// propagateErrors controls whether the "errors" of a payload are sent to the
// client as is or replaced by a generic error message.
propagateErrors bool
}
// Compile-time checks that websocketResponseWriter implements the interfaces it is used as.
var _ http.ResponseWriter = (*websocketResponseWriter)(nil)
var _ resolve.SubscriptionResponseWriter = (*websocketResponseWriter)(nil)
// newWebsocketResponseWriter creates a response writer for the operation with
// the given id. The writer starts with an empty header set and a logger that
// carries the id as the "subscription_id" field.
func newWebsocketResponseWriter(id string, protocol wsproto.Proto, propagateErrors bool, logger *zap.Logger, stats statistics.EngineStatistics) *websocketResponseWriter {
	rw := new(websocketResponseWriter)
	rw.id = id
	rw.protocol = protocol
	rw.header = make(http.Header)
	rw.logger = logger.With(zap.String("subscription_id", id))
	rw.stats = stats
	rw.propagateErrors = propagateErrors
	return rw
}
// Header returns the header map of the writer. Headers set on it are sent as
// message extensions by Flush.
func (rw *websocketResponseWriter) Header() http.Header {
return rw.header
}
// WriteHeader only logs the status code at debug level; nothing is sent to the client.
func (rw *websocketResponseWriter) WriteHeader(statusCode int) {
rw.logger.Debug("Response status code", zap.Int("status_code", statusCode))
}
// Complete tells the client through the protocol that the operation is done.
// A failure to send the message is logged at debug level and otherwise ignored.
func (rw *websocketResponseWriter) Complete() {
	if doneErr := rw.protocol.Done(rw.id); doneErr != nil {
		rw.logger.Debug("Sending complete message", zap.Error(doneErr))
	}
}
// Write appends data to the internal buffer and adds its length to the
// written-bytes total. Nothing is sent to the client until Flush is called.
func (rw *websocketResponseWriter) Write(data []byte) (int, error) {
rw.writtenBytes += len(data)
return rw.buf.Write(data)
}
// Flush sends the buffered payload to the client and resets the buffer. It is
// a no-op when nothing is buffered.
//
// When the payload contains a JSON "errors" value, an error message is sent:
// the original errors if error propagation is enabled, a generic "Unable to
// subscribe" error otherwise. Any other payload is sent as data. Response
// headers, if any, are attached as the "response_headers" extension.
//
// If the headers cannot be serialized, the error is returned and the buffer
// is left untouched.
func (rw *websocketResponseWriter) Flush() error {
	pending := rw.buf.Len()
	if pending == 0 {
		return nil
	}
	rw.logger.Debug("flushing", zap.Int("bytes", pending))

	payload := rw.buf.Bytes()

	var (
		extensions []byte
		err        error
	)
	if len(rw.header) > 0 {
		extensions, err = json.Marshal(map[string]any{
			"response_headers": rw.header,
		})
		if err != nil {
			rw.logger.Warn("Serializing response headers", zap.Error(err))
			return err
		}
	}

	// Check if the result is an error
	if errorsResult := gjson.GetBytes(payload, "errors"); errorsResult.Type != gjson.JSON {
		err = rw.protocol.WriteGraphQLData(rw.id, payload, extensions)
	} else {
		errorsPayload := json.RawMessage(`[{"message":"Unable to subscribe"}]`)
		if rw.propagateErrors {
			errorsPayload = json.RawMessage(errorsResult.Raw)
		}
		err = rw.protocol.WriteGraphQLErrors(rw.id, errorsPayload, extensions)
	}

	// The buffer is reset whether or not the write succeeded.
	rw.buf.Reset()
	return err
}
// SubscriptionResponseWriter returns the writer itself as a resolve.SubscriptionResponseWriter.
func (rw *websocketResponseWriter) SubscriptionResponseWriter() resolve.SubscriptionResponseWriter {
return rw
}
// graphqlError is the JSON shape of a single GraphQL error sent to the client.
// Extensions is omitted from the output when it is nil.
type graphqlError struct {
Message string `json:"message"`
Extensions *Extensions `json:"extensions,omitempty"`
}
// WebSocketConnectionHandlerOptions carries everything
// NewWebsocketConnectionHandler needs to build the handler of one connection.
type WebSocketConnectionHandlerOptions struct {
// Config may be nil; in that case forwarding of the initial payload is disabled.
Config *config.WebSocketConfiguration
OperationProcessor *OperationProcessor
OperationBlocker *OperationBlocker
Planner *OperationPlanner
GraphQLHandler *GraphQLHandler
PreHandler *PreHandler
Metrics RouterMetrics
// ResponseWriter and Request belong to the HTTP upgrade request.
ResponseWriter http.ResponseWriter
Request *http.Request
Connection *wsConnectionWrapper
Protocol wsproto.Proto
Logger *zap.Logger
Stats statistics.EngineStatistics
PlanOptions PlanOptions
ConnectionID int64
ClientInfo *ClientInfo
// InitRequestID is the request ID of the upgrade request.
InitRequestID string
ForwardUpgradeHeaders forwardConfig
ForwardQueryParams forwardConfig
Attributes []attribute.KeyValue
DisableVariablesRemapping bool
ApolloCompatibilityFlags config.ApolloCompatibilityFlags
}
// WebSocketConnectionHandler holds the state of a single WebSocket connection
// and executes the operations received on it.
type WebSocketConnectionHandler struct {
ctx context.Context
operationProcessor *OperationProcessor
operationBlocker *OperationBlocker
planner *OperationPlanner
graphqlHandler *GraphQLHandler
plannerOptions PlanOptions
preHandler *PreHandler
metrics RouterMetrics
w http.ResponseWriter
// request is the original client request. It is not safe for concurrent use.
// You have to clone it before using it in a goroutine.
request *http.Request
conn *wsConnectionWrapper
protocol wsproto.Proto
clientInfo *ClientInfo
logger *zap.Logger
// initialPayload is the payload returned by the protocol during Initialize.
initialPayload json.RawMessage
// upgradeRequestHeaders and upgradeRequestQueryParams are the JSON-encoded,
// allow-list filtered headers and query params of the upgrade request. They
// stay nil when forwarding is disabled or nothing passed the filter.
upgradeRequestHeaders json.RawMessage
upgradeRequestQueryParams json.RawMessage
initRequestID string
connectionID int64
// subscriptionIDs is incremented to produce the internal ID of each subscription.
subscriptionIDs atomic.Int64
// subscriptions maps the client-provided operation ID to the internal subscription ID (int64).
subscriptions sync.Map
stats statistics.EngineStatistics
attributes []attribute.KeyValue
forwardInitialPayload bool
forwardUpgradeHeaders *forwardConfig
forwardQueryParams *forwardConfig
disableVariablesRemapping bool
apolloCompatibilityFlags config.ApolloCompatibilityFlags
}
// forwardConfig describes which names (headers or query params) of the
// upgrade request are forwarded. When neither allow list is in use, nothing
// is filtered out.
type forwardConfig struct {
// enabled turns forwarding on.
enabled bool
// withStaticAllowList is true when staticAllowList is not empty.
withStaticAllowList bool
// staticAllowList holds names that are compared literally.
staticAllowList []string
// withRegexAllowList is true when regexAllowList is not empty.
withRegexAllowList bool
// regexAllowList holds compiled patterns that names are matched against.
regexAllowList []*regexp.Regexp
}
var (
// detectNonRegex matches allow-list entries made up only of ASCII letters,
// digits, '_' and '-'. Such entries are treated as literal names; all other
// entries are compiled as regular expressions.
detectNonRegex = regexp.MustCompile(`^[A-Za-z0-9_-]+$`)
)
// NewWebsocketConnectionHandler builds the handler of a single connection
// from opts. The forward configurations are referenced through pointers to
// the copies held in opts. Forwarding of the initial payload is only enabled
// when opts.Config is set and requests it.
func NewWebsocketConnectionHandler(ctx context.Context, opts WebSocketConnectionHandlerOptions) *WebSocketConnectionHandler {
	forwardInitialPayload := false
	if opts.Config != nil {
		forwardInitialPayload = opts.Config.ForwardInitialPayload
	}

	connHandler := &WebSocketConnectionHandler{
		ctx: ctx,

		// Operation pipeline
		operationProcessor: opts.OperationProcessor,
		operationBlocker:   opts.OperationBlocker,
		planner:            opts.Planner,
		plannerOptions:     opts.PlanOptions,
		graphqlHandler:     opts.GraphQLHandler,
		preHandler:         opts.PreHandler,

		// Connection state
		w:             opts.ResponseWriter,
		request:       opts.Request,
		conn:          opts.Connection,
		protocol:      opts.Protocol,
		connectionID:  opts.ConnectionID,
		clientInfo:    opts.ClientInfo,
		initRequestID: opts.InitRequestID,

		// Observability
		logger:     opts.Logger,
		metrics:    opts.Metrics,
		stats:      opts.Stats,
		attributes: opts.Attributes,

		// Behavior switches
		forwardUpgradeHeaders:     &opts.ForwardUpgradeHeaders,
		forwardQueryParams:        &opts.ForwardQueryParams,
		forwardInitialPayload:     forwardInitialPayload,
		disableVariablesRemapping: opts.DisableVariablesRemapping,
		apolloCompatibilityFlags:  opts.ApolloCompatibilityFlags,
	}
	return connHandler
}
// requestError reports err for the current connection. When err is a
// wsutil.ClosedError the client is gone: the error is logged at debug level
// and returned unchanged. Any other error is logged as a warning and its
// message is written to the client as text; the result of that write is returned.
func (h *WebSocketConnectionHandler) requestError(err error) error {
	var closed wsutil.ClosedError
	if !errors.As(err, &closed) {
		h.logger.Warn("Handling websocket connection", zap.Error(err))
		return h.conn.WriteText(err.Error())
	}
	h.logger.Debug("Client closed connection")
	return err
}
// writeErrorMessage sends err to the client as a single GraphQL error for the
// operation with the given ID. It returns an error when the GraphQL errors
// cannot be encoded or when the protocol fails to write them.
func (h *WebSocketConnectionHandler) writeErrorMessage(operationID string, err error) error {
	encoded, marshalErr := json.Marshal([]graphqlError{{Message: err.Error()}})
	if marshalErr != nil {
		return fmt.Errorf("encoding GraphQL errors: %w", marshalErr)
	}
	return h.protocol.WriteGraphQLErrors(operationID, encoded, nil)
}
// parseAndPlan turns the payload of a subscribe message into an executable
// operation. It unmarshals the payload, resolves persisted operations, parses,
// checks the operation blocker, normalizes the operation and its variables,
// remaps variables, validates and finally plans the operation.
//
// On success it returns the parsed operation and the populated operation
// context. On failure the operation context is nil; the parsed operation is
// only returned when planning was the failing step.
//
// NOTE(review): the returned parsed operation is a field of the operation kit,
// which is freed by the deferred Free call when this function returns - confirm
// that it remains valid for callers.
func (h *WebSocketConnectionHandler) parseAndPlan(registration *SubscriptionRegistration) (*ParsedOperation, *operationContext, error) {
operationKit, err := h.operationProcessor.NewKit()
if err != nil {
return nil, nil, err
}
defer operationKit.Free()
opContext := &operationContext{
clientInfo: h.plannerOptions.ClientInfo,
}
if err := operationKit.UnmarshalOperationFromBody(registration.msg.Payload); err != nil {
return nil, nil, err
}
opContext.extensions = operationKit.parsedOperation.Request.Extensions
var (
skipParse bool
isApq bool
)
if operationKit.parsedOperation.IsPersistedOperation {
skipParse, isApq, err = operationKit.FetchPersistedOperation(h.ctx, h.clientInfo)
if err != nil {
return nil, nil, err
}
}
// If the persistent operation is already in the cache, we skip the parse step
// because the operation was already parsed. This is a performance optimization, and we
// can do it because we know that the persisted operation is immutable (identified by the hash)
if !skipParse {
startParsing := time.Now()
if err := operationKit.Parse(); err != nil {
// The timings recorded on the error paths below are not observable by the
// caller because a nil operation context is returned.
opContext.parsingTime = time.Since(startParsing)
return nil, nil, err
}
opContext.parsingTime = time.Since(startParsing)
}
opContext.name = operationKit.parsedOperation.Request.OperationName
opContext.opType = operationKit.parsedOperation.Type
// The request context of the cloned client request supplies the expression
// context that the operation blocker is evaluated with.
reqCtx := getRequestContext(registration.clientRequest.Context())
if reqCtx == nil {
return nil, nil, fmt.Errorf("request context not found")
}
if blocked := h.operationBlocker.OperationIsBlocked(h.logger, reqCtx.expressionContext, operationKit.parsedOperation); blocked != nil {
return nil, nil, blocked
}
// normalizationTime covers operation normalization, variable normalization
// and variable remapping.
startNormalization := time.Now()
if _, err := operationKit.NormalizeOperation(h.clientInfo.Name, isApq); err != nil {
opContext.normalizationTime = time.Since(startNormalization)
return nil, nil, err
}
opContext.normalizationCacheHit = operationKit.parsedOperation.NormalizationCacheHit
if _, err := operationKit.NormalizeVariables(); err != nil {
opContext.normalizationTime = time.Since(startNormalization)
return nil, nil, err
}
if err := operationKit.RemapVariables(h.disableVariablesRemapping); err != nil {
opContext.normalizationTime = time.Since(startNormalization)
return nil, nil, err
}
opContext.hash = operationKit.parsedOperation.ID
opContext.internalHash = operationKit.parsedOperation.InternalID
opContext.remapVariables = operationKit.parsedOperation.RemapVariables
opContext.normalizationTime = time.Since(startNormalization)
opContext.content = operationKit.parsedOperation.NormalizedRepresentation
opContext.variables, err = astjson.ParseBytes(operationKit.parsedOperation.Request.Variables)
if err != nil {
return nil, nil, err
}
startValidation := time.Now()
if _, err := operationKit.Validate(h.plannerOptions.ExecutionOptions.SkipLoader, opContext.remapVariables, &h.apolloCompatibilityFlags); err != nil {
opContext.validationTime = time.Since(startValidation)
return nil, nil, err
}
opContext.validationTime = time.Since(startValidation)
startPlanning := time.Now()
err = h.planner.plan(opContext, h.plannerOptions)
if err != nil {
opContext.planningTime = time.Since(startPlanning)
return operationKit.parsedOperation, nil, err
}
opContext.planningTime = time.Since(startPlanning)
// Make the connection's initial payload available to the execution.
opContext.initialPayload = h.initialPayload
return operationKit.parsedOperation, opContext, nil
}
// executeSubscription plans and executes the operation of a registered
// subscribe message. Planning errors are sent to the client as GraphQL errors.
//
// Depending on the configuration, the upgrade request headers, the upgrade
// request query params and the initial payload are added to the operation
// extensions under "upgradeHeaders", "upgradeQueryParams" and
// "initialPayload" before the resolve context is built.
//
// A synchronous plan is resolved, flushed and completed right away; a
// subscription plan is handed to the resolver for asynchronous execution.
func (h *WebSocketConnectionHandler) executeSubscription(registration *SubscriptionRegistration) {
rw := newWebsocketResponseWriter(registration.msg.ID, h.protocol, h.graphqlHandler.subgraphErrorPropagation.Enabled, h.logger, h.stats)
_, operationCtx, err := h.parseAndPlan(registration)
if err != nil {
wErr := h.writeErrorMessage(registration.msg.ID, err)
if wErr != nil {
h.logger.Warn("writing error message", zap.Error(wErr))
}
return
}
if h.forwardUpgradeHeaders.enabled && h.upgradeRequestHeaders != nil {
// jsonparser.Set needs an existing JSON document to write into.
if operationCtx.extensions == nil {
operationCtx.extensions = json.RawMessage("{}")
}
operationCtx.extensions, err = jsonparser.Set(operationCtx.extensions, h.upgradeRequestHeaders, "upgradeHeaders")
if err != nil {
h.logger.Warn("Setting upgrade request data", zap.Error(err))
_ = h.writeErrorMessage(registration.msg.ID, err)
return
}
}
if h.forwardQueryParams.enabled && h.upgradeRequestQueryParams != nil {
if operationCtx.extensions == nil {
operationCtx.extensions = json.RawMessage("{}")
}
operationCtx.extensions, err = jsonparser.Set(operationCtx.extensions, h.upgradeRequestQueryParams, "upgradeQueryParams")
if err != nil {
h.logger.Warn("Setting upgrade request data", zap.Error(err))
_ = h.writeErrorMessage(registration.msg.ID, err)
return
}
}
if h.forwardInitialPayload && operationCtx.initialPayload != nil {
if operationCtx.extensions == nil {
operationCtx.extensions = json.RawMessage("{}")
}
operationCtx.extensions, err = jsonparser.Set(operationCtx.extensions, operationCtx.initialPayload, "initialPayload")
if err != nil {
h.logger.Warn("Setting initial payload", zap.Error(err))
_ = h.writeErrorMessage(registration.msg.ID, err)
return
}
}
resolveCtx := &resolve.Context{
Variables: operationCtx.Variables(),
Request: resolve.Request{
// The headers come from the per-registration clone of the client request.
Header: registration.clientRequest.Header,
ID: h.initRequestID,
},
RenameTypeNames: h.graphqlHandler.executor.RenameTypeNames,
RemapVariables: operationCtx.remapVariables,
TracingOptions: operationCtx.traceOptions,
Extensions: operationCtx.extensions,
}
if h.forwardInitialPayload && operationCtx.initialPayload != nil {
resolveCtx.InitialPayload = operationCtx.initialPayload
}
reqContext := buildRequestContext(requestContextOptions{
operationContext: operationCtx,
requestLogger: h.logger,
metricSetAttributes: nil,
w: nil,
r: registration.clientRequest,
})
// The resolve context is bound to the handler context, not to the context of
// the upgrade request.
resolveCtx = resolveCtx.WithContext(withRequestContext(h.ctx, reqContext))
if h.graphqlHandler.authorizer != nil {
resolveCtx = WithAuthorizationExtension(resolveCtx)
resolveCtx.SetAuthorizer(h.graphqlHandler.authorizer)
}
resolveCtx = h.graphqlHandler.configureRateLimiting(resolveCtx)
// Put in a closure to evaluate err after defer
defer func() {
// StatusCode has no meaning here. We set it to 0 but set the error.
h.metrics.ExportSchemaUsageInfo(operationCtx, 0, err != nil, false)
}()
switch p := operationCtx.preparedPlan.preparedPlan.(type) {
case *plan.SynchronousResponsePlan:
_, err = h.graphqlHandler.executor.Resolver.ResolveGraphQLResponse(resolveCtx, p.Response, nil, rw)
if err != nil {
h.logger.Warn("Resolving GraphQL response", zap.Error(err))
h.graphqlHandler.WriteError(resolveCtx, err, p.Response, rw)
}
// Send whatever was written (data or error) and signal completion.
_ = rw.Flush()
rw.Complete()
case *plan.SubscriptionResponsePlan:
err = h.graphqlHandler.executor.Resolver.AsyncResolveGraphQLSubscription(resolveCtx, p.Response, rw.SubscriptionResponseWriter(), registration.id)
if err != nil {
h.logger.Warn("Resolving GraphQL subscription", zap.Error(err))
// NOTE(review): unlike the synchronous branch there is no explicit Flush
// here - confirm that WriteError delivers the error to the client.
h.graphqlHandler.WriteError(resolveCtx, err, p.Response.Response, rw)
return
}
}
}
// SubscriptionRegistration is the result of registering a subscribe message
// on a connection. It is the input of executeSubscription.
type SubscriptionRegistration struct {
// id combines the connection ID with the internal subscription ID.
id resolve.SubscriptionIdentifier
// msg is the subscribe message that was registered.
msg *wsproto.Message
// clientRequest is a clone of the connection's client request.
clientRequest *http.Request
}
// registerSubscription records the operation ID of msg on the connection and
// returns the registration needed to execute it. It fails when the message
// has no ID or when the ID is already registered. The existence check and the
// store are separate steps, so this method is not safe for concurrent use.
func (h *WebSocketConnectionHandler) registerSubscription(msg *wsproto.Message) (*SubscriptionRegistration, error) {
	if msg.ID == "" {
		return nil, fmt.Errorf("missing id in subscribe")
	}
	if _, taken := h.subscriptions.Load(msg.ID); taken {
		return nil, fmt.Errorf("subscription with id %q already exists", msg.ID)
	}

	internalID := h.subscriptionIDs.Inc()
	h.subscriptions.Store(msg.ID, internalID)

	identifier := resolve.SubscriptionIdentifier{
		ConnectionID:   h.connectionID,
		SubscriptionID: internalID,
	}
	// executeSubscription is running on a worker pool, so we have to clone the request
	// before passing it to the worker pool. The original request is not safe for concurrent use and
	// is needed later to construct the operation context and to clone the resolver context.
	requestCopy := h.request.Clone(h.request.Context())

	return &SubscriptionRegistration{
		id:            identifier,
		msg:           msg,
		clientRequest: requestCopy,
	}, nil
}
// handleComplete processes a complete message from the client: it removes the
// registration of the operation ID and asks the resolver to unsubscribe the
// matching subscription. Unknown IDs and registrations with an unexpected
// value are reported through requestError.
func (h *WebSocketConnectionHandler) handleComplete(msg *wsproto.Message) error {
	stored, found := h.subscriptions.Load(msg.ID)
	if !found {
		return h.requestError(fmt.Errorf("no subscription was registered for ID %q", msg.ID))
	}
	// The registration is dropped before the stored value is inspected.
	h.subscriptions.Delete(msg.ID)

	internalID, isInt64 := stored.(int64)
	if !isInt64 {
		return h.requestError(fmt.Errorf("invalid subscription state for ID %q", msg.ID))
	}

	return h.graphqlHandler.executor.Resolver.AsyncUnsubscribeSubscription(resolve.SubscriptionIdentifier{
		ConnectionID:   h.connectionID,
		SubscriptionID: internalID,
	})
}
// HandleMessage dispatches a single message read from a connection.
//
//   - terminate: returns errClientTerminatedConnection.
//   - ping: answers with a pong; a failure to send it is ignored.
//   - pong: ignored.
//   - subscribe: registers the subscription and executes it on the calling
//     goroutine while holding a slot of the handler semaphore.
//   - complete: unsubscribes the operation; failures are only logged.
//   - anything else: reported to the client as an unsupported message type.
func (h *WebsocketHandler) HandleMessage(handler *WebSocketConnectionHandler, msg *wsproto.Message) (err error) {
	switch msg.Type {
	case wsproto.MessageTypeTerminate:
		return errClientTerminatedConnection

	case wsproto.MessageTypePing:
		_ = handler.protocol.Pong(msg)
		return nil

	case wsproto.MessageTypePong:
		// "Furthermore, the Pong message may even be sent unsolicited as a unidirectional heartbeat"
		return nil

	case wsproto.MessageTypeSubscribe:
		registration, regErr := handler.registerSubscription(msg)
		if regErr != nil {
			// The cause is logged; the client only gets a generic message.
			h.logger.Warn("Handling subscription registration", zap.Error(regErr))
			return handler.requestError(fmt.Errorf("error registering subscription id: %s", msg.ID))
		}
		if acquireErr := h.handlerSem.Acquire(handler.ctx, 1); acquireErr != nil {
			return acquireErr
		}
		defer h.handlerSem.Release(1)
		handler.executeSubscription(registration)
		return nil

	case wsproto.MessageTypeComplete:
		if completeErr := handler.handleComplete(msg); completeErr != nil {
			h.logger.Warn("Handling complete", zap.Error(completeErr))
		}
		return nil

	default:
		return handler.requestError(fmt.Errorf("unsupported message type %d", msg.Type))
	}
}
// Initialize runs the protocol's initialization and stores the initial
// payload it returns. When forwarding is enabled, it also captures the query
// params and headers of the upgrade request that pass the allow lists and
// keeps them JSON-encoded for later use. Only the first value of each query
// param and header is kept.
//
// It returns the protocol's initialization error, or the error of encoding
// the captured values.
func (h *WebSocketConnectionHandler) Initialize() (err error) {
	h.logger.Debug("Websocket connection", zap.String("protocol", h.protocol.Subprotocol()))

	if h.initialPayload, err = h.protocol.Initialize(); err != nil {
		// The client only gets a generic message; the caller gets the real error.
		_ = h.requestError(fmt.Errorf("error initializing session"))
		return err
	}

	if h.forwardQueryParams.enabled {
		values := h.request.URL.Query()
		forwarded := make(map[string]string, len(values))
		for name := range values {
			if h.ignoreQueryParameter(name) {
				continue
			}
			forwarded[name] = values.Get(name)
		}
		if len(forwarded) != 0 {
			if h.upgradeRequestQueryParams, err = json.Marshal(forwarded); err != nil {
				return err
			}
		}
	}

	if h.forwardUpgradeHeaders.enabled {
		forwarded := make(map[string]string, len(h.request.Header))
		for name := range h.request.Header {
			if !h.ignoreHeader(name) {
				forwarded[name] = h.request.Header.Get(name)
			}
		}
		if len(forwarded) > 0 {
			if h.upgradeRequestHeaders, err = json.Marshal(forwarded); err != nil {
				return err
			}
		}
	}

	return nil
}
// ignoreQueryParameter reports whether the query param k must not be
// forwarded. A param is kept when it is on the static allow list or matches
// one of the allow-list patterns. When no allow list is in use, nothing is
// ignored.
func (h *WebSocketConnectionHandler) ignoreQueryParameter(k string) bool {
	cfg := h.forwardQueryParams

	if cfg.withStaticAllowList && slices.Contains(cfg.staticAllowList, k) {
		return false
	}
	if cfg.withRegexAllowList {
		matches := func(re *regexp.Regexp) bool { return re.MatchString(k) }
		if slices.ContainsFunc(cfg.regexAllowList, matches) {
			return false
		}
	}
	// Not allow-listed: ignore it only if at least one allow list is in use.
	return cfg.withStaticAllowList || cfg.withRegexAllowList
}
// ignoreHeader reports whether the header k must not be forwarded. A header
// is kept when it is on the static allow list or matches one of the
// allow-list patterns. When no allow list is in use, nothing is ignored.
func (h *WebSocketConnectionHandler) ignoreHeader(k string) bool {
	cfg := h.forwardUpgradeHeaders

	if cfg.withStaticAllowList && slices.Contains(cfg.staticAllowList, k) {
		return false
	}
	if cfg.withRegexAllowList {
		matches := func(re *regexp.Regexp) bool { return re.MatchString(k) }
		if slices.ContainsFunc(cfg.regexAllowList, matches) {
			return false
		}
	}
	// Not allow-listed: ignore it only if at least one allow list is in use.
	return cfg.withStaticAllowList || cfg.withRegexAllowList
}
// Complete removes the registration of the writer's operation ID and sends
// the done message for it. The writer is flushed afterwards, but only when
// the done message could be sent; a flush error is ignored.
func (h *WebSocketConnectionHandler) Complete(rw *websocketResponseWriter) {
	h.subscriptions.Delete(rw.id)
	if doneErr := rw.protocol.Done(rw.id); doneErr == nil {
		_ = rw.Flush()
	}
}
// Close asks the resolver to unsubscribe everything that belongs to this
// connection and then closes the connection. Both steps always run; their
// errors are only logged at debug level.
func (h *WebSocketConnectionHandler) Close() {
	// Remove any pending IDs associated with this connection
	if unsubErr := h.graphqlHandler.executor.Resolver.AsyncUnsubscribeClient(h.connectionID); unsubErr != nil {
		h.logger.Debug("Unsubscribing client", zap.Error(unsubErr))
	}
	if closeErr := h.conn.Close(); closeErr != nil {
		h.logger.Debug("Closing websocket connection", zap.Error(closeErr))
	}
}