in router/core/websocket.go [234:407]
// handleUpgradeRequest upgrades an incoming HTTP request to a websocket
// connection and hands it off to a WebsocketConnectionHandler.
//
// The steps are:
//  1. If an access controller is configured and authentication is NOT taken
//     from the initial payload, the request is authorized before upgrading.
//     A failure is answered with a plain HTTP 401 (ErrUnauthorized) or 403.
//  2. The connection is upgraded, negotiating a supported subprotocol with a
//     path-based fallback for legacy Absinthe clients. Then the matching
//     protocol implementation is created.
//  3. Request options are parsed and the connection handler is created and
//     initialized.
//  4. If authentication from the initial payload is enabled, the initial
//     payload is placed into the request context and access control runs at
//     this point. Optionally, a token from the payload is exported into a
//     request header.
//  5. The connection is registered with the net poller when one is
//     available. Otherwise it is served by handleConnectionSync in a new
//     goroutine.
//
// Once the upgrade has succeeded the ResponseWriter is hijacked. From then on,
// failures are reported over the websocket (where applicable) and the
// connection is closed; no HTTP response is written.
func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.Request) {
	// subProtocol is populated by the upgrader's Protocol callback during
	// subprotocol negotiation.
	var subProtocol string

	requestID := middleware.GetReqID(r.Context())
	requestContext := getRequestContext(r.Context())
	requestLogger := h.logger.With(logging.WithRequestID(requestID), logging.WithTraceID(rtrace.GetTraceID(r.Context())))
	clientInfo := NewClientInfoFromRequest(r, h.clientHeader)

	if h.accessController != nil && !h.config.Authentication.FromInitialPayload.Enabled {
		// Check access control before upgrading the connection. The response
		// writer is still usable here, so reject with a regular HTTP error.
		validatedReq, err := h.accessController.Access(w, r)
		if err != nil {
			statusCode := http.StatusForbidden
			if errors.Is(err, ErrUnauthorized) {
				statusCode = http.StatusUnauthorized
			}
			http.Error(w, http.StatusText(statusCode), statusCode)
			return
		}
		r = validatedReq
		requestContext.expressionContext.Request.Auth = expr.LoadAuth(r.Context())
	}

	upgrader := ws.HTTPUpgrader{
		Timeout: time.Second * 5,
		Protocol: func(s string) bool {
			if wsproto.IsSupportedSubprotocol(s) {
				subProtocol = s
				return true
			}
			return false
		},
	}
	c, _, _, err := upgrader.Upgrade(r, w)
	if err != nil {
		requestLogger.Warn("Websocket upgrade", zap.Error(err))
		// The returned conn is nil when the response writer could not be
		// hijacked (e.g. it does not implement http.Hijacker). Calling Close
		// on a nil interface value would panic.
		if c != nil {
			_ = c.Close()
		}
		return
	}

	// Legacy Absinthe clients don't set the Sec-WebSocket-Protocol header
	// (subprotocol), so detect them by the configured handler path and set
	// the subprotocol manually.
	if subProtocol == "" && h.absintheHandlerEnabled && r.URL.Path == h.absintheHandlerPath {
		subProtocol = wsproto.AbsintheWSSubProtocol
	}

	// After a successful upgrade we can't write to the response writer anymore
	// because it's hijacked by the websocket connection.
	conn := newWSConnectionWrapper(c)
	protocol, err := wsproto.NewProtocol(subProtocol, conn)
	if err != nil {
		requestLogger.Error("Create websocket protocol", zap.Error(err))
		_ = c.Close()
		return
	}

	// Parse the request options before creating the handler. This avoids
	// touching the client request across goroutines.
	executionOptions, traceOptions, err := h.preHandler.parseRequestOptions(r, clientInfo, requestLogger)
	if err != nil {
		requestLogger.Error("Parse request options", zap.Error(err))
		_ = c.Close()
		return
	}

	planOptions := PlanOptions{
		ClientInfo:           clientInfo,
		TraceOptions:         traceOptions,
		ExecutionOptions:     executionOptions,
		TrackSchemaUsageInfo: h.preHandler.trackSchemaUsageInfo,
	}

	handler := NewWebsocketConnectionHandler(h.ctx, WebSocketConnectionHandlerOptions{
		OperationProcessor:        h.operationProcessor,
		OperationBlocker:          h.operationBlocker,
		Planner:                   h.planner,
		GraphQLHandler:            h.graphqlHandler,
		PreHandler:                h.preHandler,
		Metrics:                   h.metrics,
		PlanOptions:               planOptions,
		ResponseWriter:            w,
		Request:                   r,
		Connection:                conn,
		Protocol:                  protocol,
		Logger:                    requestLogger,
		Stats:                     h.stats,
		ConnectionID:              h.connectionIDs.Inc(),
		ClientInfo:                clientInfo,
		InitRequestID:             requestID,
		Config:                    h.config,
		ForwardUpgradeHeaders:     h.forwardUpgradeHeadersConfig,
		ForwardQueryParams:        h.forwardQueryParamsConfig,
		Attributes:                h.attributes,
		DisableVariablesRemapping: h.disableVariablesRemapping,
		ApolloCompatibilityFlags:  h.apolloCompatibilityFlags,
	})

	if err := handler.Initialize(); err != nil {
		// Don't produce error logs here: these are client-side failures
		// (e.g. slow client, aborted connection, invalid JSON), not server
		// errors, so they are only logged at debug level.
		requestLogger.Debug("Initializing websocket connection", zap.Error(err))
		handler.Close()
		return
	}

	// Authenticate the connection using the initial payload.
	fromInitialPayloadConfig := h.config.Authentication.FromInitialPayload
	if fromInitialPayloadConfig.Enabled {
		// Put the initial payload into the context so the websocket
		// initial-payload authenticator can read it.
		r = r.WithContext(authentication.WithWebsocketInitialPayloadContextKey(r.Context(), handler.initialPayload))

		// Access control runs only now, after the initial payload has been
		// read and set into the context.
		if h.accessController != nil {
			validatedReq, err := h.accessController.Access(w, r)
			if err != nil {
				// The connection is already upgraded and w is hijacked, so an
				// HTTP status can no longer be delivered. Report the error
				// over the websocket and close the connection.
				_ = handler.writeErrorMessage(requestID, err)
				handler.Close()
				return
			}
			// Store the validated request only on success. This way the
			// handler never holds the (possibly nil) request returned
			// together with an error.
			handler.request = validatedReq
		}

		// Export the token from the initial payload to the request header.
		if fromInitialPayloadConfig.ExportToken.Enabled {
			var initialPayloadMap map[string]interface{}
			if err := json.Unmarshal(handler.initialPayload, &initialPayloadMap); err != nil {
				requestLogger.Error("Error parsing initial payload", zap.Error(err))
				_ = handler.writeErrorMessage(requestID, err)
				handler.Close()
				return
			}
			jwtToken, ok := initialPayloadMap[fromInitialPayloadConfig.Key].(string)
			if !ok {
				err := fmt.Errorf("invalid JWT token in initial payload: JWT token is not a string")
				requestLogger.Error(err.Error())
				_ = handler.writeErrorMessage(requestID, err)
				handler.Close()
				return
			}
			handler.request.Header.Set(fromInitialPayloadConfig.ExportToken.HeaderKey, jwtToken)
		}
		requestContext.expressionContext.Request.Auth = expr.LoadAuth(handler.request.Context())
	}

	// Only when epoll/kqueue is available. On Windows, epoll is not
	// available, so messages are handled synchronously in a goroutine below.
	if h.netPoll != nil {
		if err := h.addConnection(c, handler); err != nil {
			requestLogger.Error("Adding connection to net poller", zap.Error(err))
			handler.Close()
		}
		return
	}

	// Handle messages synchronously when no net poller implementation is
	// available.
	go h.handleConnectionSync(handler)
}