func()

in router/core/websocket.go [234:407]


// handleUpgradeRequest upgrades an incoming HTTP request to a WebSocket
// connection and hands the connection off for message processing.
//
// Authentication runs in one of two places when an access controller is
// configured:
//   - before the upgrade (default): failures are answered with a plain
//     HTTP 401/403 because the ResponseWriter is still usable;
//   - after the upgrade, when Authentication.FromInitialPayload is enabled:
//     the initial payload read by handler.Initialize is placed into the
//     request context and access control runs then. At that point the
//     ResponseWriter has been hijacked, so failures are reported over the
//     websocket (writeErrorMessage) and the connection is closed.
//
// Once set up, the connection is registered with the net poller when one is
// available (h.netPoll != nil); otherwise it is served by a goroutine running
// handleConnectionSync.
func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.Request) {
	var (
		subProtocol string
	)

	requestID := middleware.GetReqID(r.Context())
	requestContext := getRequestContext(r.Context())

	requestLogger := h.logger.With(logging.WithRequestID(requestID), logging.WithTraceID(rtrace.GetTraceID(r.Context())))
	clientInfo := NewClientInfoFromRequest(r, h.clientHeader)

	if h.accessController != nil && !h.config.Authentication.FromInitialPayload.Enabled {
		// Check access control before upgrading the connection. The response
		// writer is not hijacked yet, so a regular HTTP error can be returned.
		validatedReq, err := h.accessController.Access(w, r)
		if err != nil {
			statusCode := http.StatusForbidden
			if errors.Is(err, ErrUnauthorized) {
				statusCode = http.StatusUnauthorized
			}
			http.Error(w, http.StatusText(statusCode), statusCode)
			return
		}
		r = validatedReq

		requestContext.expressionContext.Request.Auth = expr.LoadAuth(r.Context())
	}

	upgrader := ws.HTTPUpgrader{
		Timeout: time.Second * 5,
		// Select a supported subprotocol from those offered by the client and
		// remember it so the matching protocol implementation is created below.
		Protocol: func(s string) bool {
			if wsproto.IsSupportedSubprotocol(s) {
				subProtocol = s
				return true
			}
			return false
		},
	}
	c, _, _, err := upgrader.Upgrade(r, w)
	if err != nil {
		requestLogger.Warn("Websocket upgrade", zap.Error(err))
		// The connection is nil when the response writer could not be
		// hijacked; calling Close on a nil net.Conn would panic.
		if c != nil {
			_ = c.Close()
		}
		return
	}

	// legacy absinthe clients don't set the Sec-WebSocket-Protocol header (Subprotocol)
	// so we need to check the path to determine if it's an absinthe client and set the subprotocol manually
	if subProtocol == "" && h.absintheHandlerEnabled && r.URL.Path == h.absintheHandlerPath {
		subProtocol = wsproto.AbsintheWSSubProtocol
	}

	// After successful upgrade, we can't write to the response writer anymore
	// because it's hijacked by the websocket connection

	conn := newWSConnectionWrapper(c)
	protocol, err := wsproto.NewProtocol(subProtocol, conn)
	if err != nil {
		requestLogger.Error("Create websocket protocol", zap.Error(err))
		_ = c.Close()
		return
	}

	// We can parse the request options before creating the handler
	// this avoids touching the client request across goroutines

	executionOptions, traceOptions, err := h.preHandler.parseRequestOptions(r, clientInfo, requestLogger)
	if err != nil {
		requestLogger.Error("Parse request options", zap.Error(err))
		_ = c.Close()
		return
	}

	planOptions := PlanOptions{
		ClientInfo:           clientInfo,
		TraceOptions:         traceOptions,
		ExecutionOptions:     executionOptions,
		TrackSchemaUsageInfo: h.preHandler.trackSchemaUsageInfo,
	}

	handler := NewWebsocketConnectionHandler(h.ctx, WebSocketConnectionHandlerOptions{
		OperationProcessor:        h.operationProcessor,
		OperationBlocker:          h.operationBlocker,
		Planner:                   h.planner,
		GraphQLHandler:            h.graphqlHandler,
		PreHandler:                h.preHandler,
		Metrics:                   h.metrics,
		PlanOptions:               planOptions,
		ResponseWriter:            w,
		Request:                   r,
		Connection:                conn,
		Protocol:                  protocol,
		Logger:                    requestLogger,
		Stats:                     h.stats,
		ConnectionID:              h.connectionIDs.Inc(),
		ClientInfo:                clientInfo,
		InitRequestID:             requestID,
		Config:                    h.config,
		ForwardUpgradeHeaders:     h.forwardUpgradeHeadersConfig,
		ForwardQueryParams:        h.forwardQueryParamsConfig,
		Attributes:                h.attributes,
		DisableVariablesRemapping: h.disableVariablesRemapping,
		ApolloCompatibilityFlags:  h.apolloCompatibilityFlags,
	})
	err = handler.Initialize()
	if err != nil {

		// Don't produce errors logs here because it can only be client side errors
		// e.g. slow client, aborted connection, invalid JSON, etc.
		// We log it as debug because it's not a server side error

		requestLogger.Debug("Initializing websocket connection", zap.Error(err))

		handler.Close()
		return
	}

	// Authenticate the connection using the initial payload
	fromInitialPayloadConfig := h.config.Authentication.FromInitialPayload
	if fromInitialPayloadConfig.Enabled {
		// Setting the initialPayload in the context to be used by the websocketInitialPayloadAuthenticator
		r = r.WithContext(authentication.WithWebsocketInitialPayloadContextKey(r.Context(), handler.initialPayload))

		// Later check access control after initial payload is read and set into the context
		if h.accessController != nil {
			handler.request, err = h.accessController.Access(w, r)
			if err != nil {
				// The connection is already upgraded and the response writer
				// hijacked, so an HTTP error response can no longer be written
				// (net/http would only log a "hijacked connection" warning).
				// Report the failure over the websocket and close it instead.
				_ = handler.writeErrorMessage(requestID, err)
				handler.Close()
				return
			}
		}

		// Export the token from the initial payload to the request header
		if fromInitialPayloadConfig.ExportToken.Enabled {
			var initialPayloadMap map[string]interface{}
			if err := json.Unmarshal(handler.initialPayload, &initialPayloadMap); err != nil {
				// zap does not interpolate printf verbs; the error is attached as a field.
				requestLogger.Error("Error parsing initial payload", zap.Error(err))
				_ = handler.writeErrorMessage(requestID, err)
				handler.Close()
				return
			}
			// A missing key and a non-string value both fail this type assertion.
			jwtToken, ok := initialPayloadMap[fromInitialPayloadConfig.Key].(string)
			if !ok {
				err := fmt.Errorf("invalid JWT token in initial payload: JWT token is not a string")
				requestLogger.Error(err.Error())
				_ = handler.writeErrorMessage(requestID, err)
				handler.Close()
				return
			}
			handler.request.Header.Set(fromInitialPayloadConfig.ExportToken.HeaderKey, jwtToken)
		}

		requestContext.expressionContext.Request.Auth = expr.LoadAuth(handler.request.Context())
	}

	// Only when epoll/kqueue is available. On Windows, epoll is not available
	if h.netPoll != nil {
		err = h.addConnection(c, handler)
		if err != nil {
			requestLogger.Error("Adding connection to net poller", zap.Error(err))
			handler.Close()
		}
		return
	}

	// Handle messages sync when net poller implementation is not available

	go h.handleConnectionSync(handler)
}