router/core/errors.go (276 lines of code) (raw):
package core
import (
"context"
"encoding/json"
"errors"
"net"
"net/http"
"github.com/hashicorp/go-multierror"
"github.com/wundergraph/astjson"
rErrors "github.com/wundergraph/cosmo/router/internal/errors"
"github.com/wundergraph/cosmo/router/internal/persistedoperation"
"github.com/wundergraph/cosmo/router/internal/unique"
"github.com/wundergraph/cosmo/router/pkg/pubsub"
rtrace "github.com/wundergraph/cosmo/router/pkg/trace"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
"github.com/wundergraph/graphql-go-tools/v2/pkg/graphqlerrors"
"github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
// errorType classifies an error produced while handling a request (see
// getErrorType). Values are assigned with iota, so the order of the constants
// below determines their numeric values.
type errorType int

const (
	// errorTypeUnknown is returned by getErrorType when no other category matches.
	errorTypeUnknown errorType = iota
	// errorTypeRateLimit: the error chain matches ErrRateLimitExceeded (errors.Is).
	errorTypeRateLimit
	// errorTypeUnauthorized: the error chain matches ErrUnauthorized (errors.Is).
	errorTypeUnauthorized
	// errorTypeContextCanceled: the error chain matches context.Canceled (errors.Is).
	errorTypeContextCanceled
	// errorTypeContextTimeout: the error chain contains a net.Error whose Timeout() is true.
	errorTypeContextTimeout
	// errorTypeUpgradeFailed: the error chain contains a *graphql_datasource.UpgradeRequestError.
	errorTypeUpgradeFailed
	// errorTypeEDFS: the error chain contains a *pubsub.Error.
	errorTypeEDFS
	// errorTypeInvalidWsSubprotocol: the error chain contains a graphql_datasource.InvalidWsSubprotocolError.
	errorTypeInvalidWsSubprotocol
	// errorTypeEDFSInvalidMessage: the error chain contains an *astjson.ParseError.
	errorTypeEDFSInvalidMessage
	// errorTypeMergeResult: the error chain contains a resolve.ErrMergeResult.
	errorTypeMergeResult
)
type (
	// GraphQLErrorResponse is the JSON body of a GraphQL response that carries
	// errors. "errors" and "data" are always emitted (a nil Data encodes as
	// null); "extensions" is omitted when Extensions is nil.
	GraphQLErrorResponse struct {
		Errors     []graphqlError `json:"errors"`
		Data       any            `json:"data"`
		Extensions *Extensions    `json:"extensions,omitempty"`
	}

	// Extensions is the top-level "extensions" object of a GraphQLErrorResponse.
	// Every field is tagged omitempty, so zero values are left out of the JSON.
	Extensions struct {
		// RateLimit is embedded verbatim as pre-encoded JSON.
		RateLimit json.RawMessage `json:"rateLimit,omitempty"`
		// Authorization is embedded verbatim as pre-encoded JSON.
		Authorization json.RawMessage `json:"authorization,omitempty"`
		// Trace is embedded verbatim as pre-encoded JSON.
		Trace      json.RawMessage `json:"trace,omitempty"`
		StatusCode int             `json:"statusCode,omitempty"`
		Code       string          `json:"code,omitempty"`
	}
)
// getErrorType maps err onto an errorType by inspecting its wrapped chain.
// The categories are tested in a fixed priority order and the first match
// wins; errorTypeUnknown is returned when none of them match.
func getErrorType(err error) errorType {
	var (
		upgradeErr              *graphql_datasource.UpgradeRequestError
		netErr                  net.Error
		edfsErr                 *pubsub.Error
		invalidWsSubprotocolErr graphql_datasource.InvalidWsSubprotocolError
		jsonParseErr            *astjson.ParseError
		mergeResultErr          resolve.ErrMergeResult
	)

	switch {
	case errors.Is(err, ErrRateLimitExceeded):
		return errorTypeRateLimit
	case errors.Is(err, ErrUnauthorized):
		return errorTypeUnauthorized
	case errors.Is(err, context.Canceled):
		return errorTypeContextCanceled
	case errors.As(err, &upgradeErr):
		return errorTypeUpgradeFailed
	// A net.Error that is not a timeout is not classified here; evaluation
	// continues with the remaining cases.
	case errors.As(err, &netErr) && netErr.Timeout():
		return errorTypeContextTimeout
	case errors.As(err, &edfsErr):
		return errorTypeEDFS
	case errors.As(err, &invalidWsSubprotocolErr):
		return errorTypeInvalidWsSubprotocol
	case errors.As(err, &jsonParseErr):
		return errorTypeEDFSInvalidMessage
	case errors.As(err, &mergeResultErr):
		return errorTypeMergeResult
	default:
		return errorTypeUnknown
	}
}
// logInternalErrorsFromReport emits one "internal error" log entry that
// aggregates every entry of report.InternalErrors into a single multierror.
// Nothing is logged when the report holds no internal errors. report must be
// non-nil; requestLogger is only used when there is something to log.
func logInternalErrorsFromReport(report *operationreport.Report, requestLogger *zap.Logger) {
	internalErrs := report.InternalErrors
	if len(internalErrs) == 0 {
		return
	}
	requestLogger.Error("internal error", zap.Error(multierror.Append(nil, internalErrs...)))
}
// trackFinalResponseError sets the final response error on the request context and
// attaches it to the span. This is used to process the error in the outer middleware
// and therefore only intended to be used in the GraphQL handler.
func trackFinalResponseError(ctx context.Context, err error) {
if err == nil {
return
}
span := trace.SpanFromContext(ctx)
requestContext := getRequestContext(ctx)
if requestContext == nil {
return
}
requestContext.SetError(err)
requestContext.graphQLErrorServices = getAggregatedSubgraphServiceNames(requestContext.error)
requestContext.graphQLErrorCodes = getAggregatedSubgraphErrorCodes(requestContext.error)
rtrace.AttachErrToSpan(span, err)
}
// getAggregatedSubgraphErrorCodes collects, in order and with duplicates kept,
// the Codes() of the first *resolve.SubgraphError found in the chain of each
// error returned by err's Unwrap() []error. It returns nil when err does not
// itself implement multiError, and a non-nil (possibly empty) slice otherwise.
func getAggregatedSubgraphErrorCodes(err error) []string {
	multi, isMulti := err.(multiError)
	if !isMulti {
		return nil
	}

	children := multi.Unwrap()
	codes := make([]string, 0, len(children))
	for i := range children {
		var sgErr *resolve.SubgraphError
		if !errors.As(children[i], &sgErr) {
			continue
		}
		codes = append(codes, sgErr.Codes()...)
	}
	return codes
}
// getSubgraphNames returns the Name of every data source in ds, in the same
// order. The result is always a non-nil slice of length len(ds).
func getSubgraphNames(ds []resolve.DataSourceInfo) []string {
	names := make([]string, len(ds))
	for i := range ds {
		names[i] = ds[i].Name
	}
	return names
}
// getAggregatedSubgraphServiceNames collects the DataSourceInfo.Name of the
// first *resolve.SubgraphError found in the chain of each error returned by
// err's Unwrap() []error, then deduplicates them with unique.SliceElements.
// It returns nil when err does not itself implement multiError.
func getAggregatedSubgraphServiceNames(err error) []string {
	multi, isMulti := err.(multiError)
	if !isMulti {
		return nil
	}

	children := multi.Unwrap()
	names := make([]string, 0, len(children))
	for i := range children {
		var sgErr *resolve.SubgraphError
		if errors.As(children[i], &sgErr) {
			names = append(names, sgErr.DataSourceInfo.Name)
		}
	}
	// Several errors can originate from the same subgraph; report each name once.
	return unique.SliceElements(names)
}
// propagateSubgraphErrors records the subgraph errors collected on the resolve
// context as the final response error via trackFinalResponseError. It does
// nothing when no subgraph errors were collected.
func propagateSubgraphErrors(ctx *resolve.Context) {
	subgraphErr := ctx.SubgraphErrors()
	if subgraphErr == nil {
		return
	}
	trackFinalResponseError(ctx.Context(), subgraphErr)
}
// writeRequestErrors writes requestErrors to w as a GraphQL error response.
//
// The wire format is negotiated from r via NegotiateSubscriptionParams:
//   - SSE: subscription headers are set and the errors are sent as a single
//     "next" event, terminated by the blank line the SSE format requires.
//   - multipart: subscription headers are set and the errors are written as a
//     multipart chunk by writeMultipartError.
//   - otherwise: a regular "application/json" response.
//
// A statusCode of 0 skips the explicit WriteHeader call. Nothing is written
// when requestErrors is nil. Write failures are logged only when requestLogger
// is non-nil; broken pipes are logged at warn level, everything else at error.
func writeRequestErrors(r *http.Request, w http.ResponseWriter, statusCode int, requestErrors graphqlerrors.RequestErrors, requestLogger *zap.Logger) {
	if requestErrors == nil {
		return
	}

	const writeErrMsg = "Error writing response"
	logWriteErr := func(err error, msg string) {
		if requestLogger == nil {
			return
		}
		if rErrors.IsBrokenPipe(err) {
			requestLogger.Warn("Broken pipe, error writing response", zap.Error(err))
			return
		}
		requestLogger.Error(msg, zap.Error(err))
	}

	// According to the tests requestContext can be nil (when called from module WriteResponseError),
	// so every access to it is guarded.
	requestContext := getRequestContext(r.Context())
	isSubscription := requestContext != nil && requestContext.operation != nil && requestContext.operation.opType == "subscription"
	wgRequestParams := NegotiateSubscriptionParams(r, !isSubscription)

	// Headers must be set before the status line is written.
	if wgRequestParams.UseSse || wgRequestParams.UseMultipart {
		setSubscriptionHeaders(wgRequestParams, r, w)
	} else {
		w.Header().Set("Content-Type", "application/json")
	}
	if statusCode != 0 {
		w.WriteHeader(statusCode)
	}

	switch {
	case wgRequestParams.UseSse:
		if _, err := w.Write([]byte("event: next\ndata: ")); err != nil {
			logWriteErr(err, writeErrMsg)
			return
		}
		// The marshaled JSON contains no raw newlines, so it fits on one data line.
		if _, err := requestErrors.WriteResponse(w); err != nil {
			logWriteErr(err, writeErrMsg)
			return
		}
		// An SSE event is only dispatched once a blank line follows its data;
		// without it, clients discard the pending event when the stream ends.
		if _, err := w.Write([]byte("\n\n")); err != nil {
			logWriteErr(err, writeErrMsg)
		}
	case wgRequestParams.UseMultipart:
		if err := writeMultipartError(w, requestErrors, isSubscription); err != nil {
			logWriteErr(err, "error writing multipart response")
		}
	default:
		if _, err := requestErrors.WriteResponse(w); err != nil {
			logWriteErr(err, writeErrMsg)
		}
	}
}
// writeMultipartError writes requestErrors to w as one multipart chunk, in order:
//  1. the boundary prefix returned by GetWriterPrefix(false, true, true),
//  2. the marshaled GraphQL error response, wrapped by wrapMultipartMessage,
//  3. a CRLF for subscriptions, or CRLF followed by the closing "--graphql--"
//     delimiter otherwise.
//
// w is flushed afterwards when it implements http.Flusher. The first marshal,
// wrap, or write error is returned.
func writeMultipartError(
	w http.ResponseWriter,
	requestErrors graphqlerrors.RequestErrors,
	isSubscription bool,
) error {
	if _, err := w.Write([]byte(GetWriterPrefix(false, true, true))); err != nil {
		return err
	}

	response := graphqlerrors.Response{Errors: requestErrors}
	payload, err := response.Marshal()
	if err != nil {
		return err
	}
	body, err := wrapMultipartMessage(payload, isSubscription)
	if err != nil {
		return err
	}

	// Multipart framing requires CRLF (both \r and \n). Clients that strictly expect
	// CRLF previously failed to parse our chunks, so the terminator is explicit here.
	var trailer []byte
	if isSubscription {
		trailer = []byte("\r\n")
	} else {
		trailer = []byte("\r\n--graphql--")
	}
	if _, err := w.Write(append(body, trailer...)); err != nil {
		return err
	}

	if f, canFlush := w.(http.Flusher); canFlush {
		f.Flush()
	}
	return nil
}
// requestErrorsFromHttpError converts httpErr into a single-element
// RequestErrors whose message is httpErr.Error(). An extensions object carrying
// the code is attached only when httpErr.ExtensionCode() is non-empty.
func requestErrorsFromHttpError(httpErr HttpError) graphqlerrors.RequestErrors {
	var extensions *graphqlerrors.Extensions
	if code := httpErr.ExtensionCode(); code != "" {
		extensions = &graphqlerrors.Extensions{Code: code}
	}
	return graphqlerrors.RequestErrors{
		{
			Message:    httpErr.Error(),
			Extensions: extensions,
		},
	}
}
// writeOperationError writes err to w as a GraphQL error response, choosing the
// status code and payload from the error's type (checked in this order):
//   - HttpError: its own status code, message, and extension code.
//   - *persistedoperation.PersistentOperationNotFoundError: message
//     "PersistedQueryNotFound", code PERSISTED_QUERY_NOT_FOUND, status 200.
//   - ReportError: the report's internal errors are logged; its external errors
//     are written with the status code derived from the report. When there are no
//     external errors, or no report at all, a 500 internal server error is written.
//   - anything else: a 500 internal server error.
//
// err is logged at debug level first, so requestLogger must be non-nil.
func writeOperationError(r *http.Request, w http.ResponseWriter, requestLogger *zap.Logger, err error) {
	requestLogger.Debug("operation error", zap.Error(err))

	var (
		reportErr     ReportError
		httpErr       HttpError
		poNotFoundErr *persistedoperation.PersistentOperationNotFoundError
	)

	switch {
	case errors.As(err, &httpErr):
		writeRequestErrors(r, w, httpErr.StatusCode(), requestErrorsFromHttpError(httpErr), requestLogger)
	case errors.As(err, &poNotFoundErr):
		newErr := NewHttpGraphqlError("PersistedQueryNotFound", "PERSISTED_QUERY_NOT_FOUND", http.StatusOK)
		writeRequestErrors(r, w, http.StatusOK, requestErrorsFromHttpError(newErr), requestLogger)
	case errors.As(err, &reportErr):
		var requestErrors graphqlerrors.RequestErrors
		statusCode := http.StatusInternalServerError
		// Guard against a nil report: dereferencing it (here or while logging) would panic.
		if report := reportErr.Report(); report != nil {
			logInternalErrorsFromReport(report, requestLogger)
			statusCode, requestErrors = graphqlerrors.RequestErrorsFromOperationReportWithStatusCode(*report)
		}
		if len(requestErrors) == 0 {
			// There were no external errors to return to the user, so return an internal server error.
			statusCode = http.StatusInternalServerError
			requestErrors = graphqlerrors.RequestErrorsFromError(errInternalServer)
		}
		writeRequestErrors(r, w, statusCode, requestErrors, requestLogger)
	default:
		writeRequestErrors(r, w, http.StatusInternalServerError, graphqlerrors.RequestErrorsFromError(errInternalServer), requestLogger)
	}
}
// ExprWrapError holds an error in its Err field and reports that error's
// message as its own. No Unwrap method is defined in this file, so errors.Is
// and errors.As do not look through it here.
type ExprWrapError struct {
	// Err is the wrapped error; it may be nil.
	Err error
}

// Error returns the wrapped error's message, or "" when Err is nil.
// The receiver itself must be non-nil.
func (e *ExprWrapError) Error() string {
	if inner := e.Err; inner != nil {
		return inner.Error()
	}
	return ""
}