router/core/flushwriter.go (223 lines of code) (raw):

package core import ( "bytes" "context" "io" "mime" "net/http" "strconv" "strings" "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) const ( WgPrefix = "wg_" WgSseParam = WgPrefix + "sse" WgSubscribeOnceParam = WgPrefix + "subscribe_once" multipartBoundary = "graphql" multipartMime = "multipart/mixed" jsonContent = "application/json" sseMimeType = "text/event-stream" heartbeat = "{}" multipartContent = multipartMime + "; boundary=" + multipartBoundary multipartStart = "\r\n--" + multipartBoundary ) type HttpFlushWriter struct { ctx context.Context cancel context.CancelFunc writer io.Writer flusher http.Flusher subscribeOnce bool sse bool multipart bool buf *bytes.Buffer firstMessage bool // apolloSubscriptionMultipartPrintBoundary if set to true will send the multipart boundary at the end of the message to allow // misbehaving client (like apollo client) to read the message just sent before the next one or the heartbeat apolloSubscriptionMultipartPrintBoundary bool } func (f *HttpFlushWriter) Complete() { if f.ctx.Err() != nil { return } if f.sse { _, _ = f.writer.Write([]byte("event: complete")) } else if f.multipart { // Write the final boundary in the multipart response if f.apolloSubscriptionMultipartPrintBoundary { _, _ = f.writer.Write([]byte("--\r\n")) } else { _, _ = f.writer.Write([]byte("--" + multipartBoundary + "--\r\n")) } } f.Close() } func (f *HttpFlushWriter) Write(p []byte) (n int, err error) { if err = f.ctx.Err(); err != nil { return } return f.buf.Write(p) } func (f *HttpFlushWriter) Close() { if f.ctx.Err() != nil { return } f.cancel() } func (f *HttpFlushWriter) Flush() (err error) { if err = f.ctx.Err(); err != nil { return err } resp := f.buf.Bytes() f.buf.Reset() flushBreak := GetWriterPrefix(f.sse, f.multipart, !f.apolloSubscriptionMultipartPrintBoundary || f.firstMessage) if f.firstMessage { f.firstMessage = false } if f.multipart && len(resp) > 0 { var err error resp, err = wrapMultipartMessage(resp, true) if err != nil { return err } } separation := "\n\n" if f.multipart { if !f.apolloSubscriptionMultipartPrintBoundary { separation = "\r\n" } else { separation = "\r\n" + multipartStart } } else if f.subscribeOnce { separation = "" } full := flushBreak + string(resp) + separation _, err = f.writer.Write([]byte(full)) if err != nil { return err } f.flusher.Flush() if f.subscribeOnce { defer f.Close() } return nil } func GetSubscriptionResponseWriter(ctx *resolve.Context, r *http.Request, w http.ResponseWriter, apolloSubscriptionMultipartPrintBoundary bool) (*resolve.Context, resolve.SubscriptionResponseWriter, bool) { type withFlushWriter interface { SubscriptionResponseWriter() resolve.SubscriptionResponseWriter } if wfw, ok := w.(withFlushWriter); ok { return ctx, wfw.SubscriptionResponseWriter(), true } wgParams := NegotiateSubscriptionParams(r, false) flusher, ok := w.(http.Flusher) if !ok { return ctx, nil, false } setSubscriptionHeaders(wgParams, r, w) flusher.Flush() flushWriter := &HttpFlushWriter{ writer: w, flusher: flusher, sse: wgParams.UseSse, multipart: wgParams.UseMultipart, subscribeOnce: wgParams.SubscribeOnce, buf: &bytes.Buffer{}, firstMessage: true, apolloSubscriptionMultipartPrintBoundary: apolloSubscriptionMultipartPrintBoundary, } flushWriter.ctx, flushWriter.cancel = context.WithCancel(ctx.Context()) ctx = ctx.WithContext(flushWriter.ctx) if wgParams.UseMultipart { ctx.ExecutionOptions.SendHeartbeat = true } return ctx, flushWriter, true } func wrapMultipartMessage(resp []byte, wrapPayload bool) ([]byte, error) { if string(resp) == heartbeat { return resp, nil } respValuePreMerge, err := astjson.ParseBytes(resp) if err != nil { return nil, err } if !wrapPayload { return respValuePreMerge.MarshalTo(nil), nil } // Per the Apollo docs, multipart messages are supposed to be json, wrapped in `"payload"` // for subscriptions payloadWrapper, err := astjson.Parse(`{"payload": {}}`) if err != nil { return nil, err } respValue, _, err := astjson.MergeValuesWithPath(payloadWrapper, respValuePreMerge, "payload") if err != nil { return nil, err } return respValue.MarshalTo(nil), nil } // setSubscriptionHeaders sets the headers for the subscription response. Only used for non-websocket subscriptions. func setSubscriptionHeaders(wgParams SubscriptionParams, r *http.Request, w http.ResponseWriter) { if wgParams.SubscribeOnce { return } if wgParams.UseMultipart { w.Header().Set("Content-Type", multipartContent) if r.ProtoMajor == 1 { w.Header().Set("Transfer-Encoding", "chunked") } } else if wgParams.UseSse { w.Header().Set("Content-Type", sseMimeType) } w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") // allow unbuffered responses, it's used when it's necessary just to pass response through // setting this to “yes” will allow the response to be cached w.Header().Set("X-Accel-Buffering", "no") } func NegotiateSubscriptionParams(r *http.Request, preferJson bool) SubscriptionParams { q := r.URL.Query() acceptHeaders := r.Header.Get("Accept") elements := strings.Split(acceptHeaders, ",") // Per RFC 9110, Accept header can be in the form`text/event-stream,application/json`, with an optional q-value to // specify preference. We want to parse this and find the best option to use, and default to the first option if no // q-value is provided. // Eventually a solution will be in the stdlib: see https://github.com/golang/go/issues/19307, at which point we should // remove this var ( useMultipart = false useSse = q.Has(WgSseParam) bestType = "" bestQ = -1.0 // Default to lowest possible q-value ) for _, acceptHeader := range elements { mediaType, params, _ := mime.ParseMediaType(acceptHeader) qValue := 1.0 // Default quality factor if qStr, exists := params["q"]; exists { // If a quality factor exists, parse it and prefer it if parsedQ, err := strconv.ParseFloat(qStr, 64); err == nil { qValue = parsedQ } } // We also have an exception where we prioritize json over higher priority media types if preferJson && strings.EqualFold(mediaType, jsonContent) { bestType = mediaType break } // Find the media type with the highest q-value. If none is provided, it will default to the first option // in the header, per https://www.rfc-editor.org/rfc/rfc9110.html#name-accept if qValue > bestQ { bestQ = qValue bestType = mediaType } } subscribeOnce := q.Has(WgSubscribeOnceParam) useSse = useSse || bestType == sseMimeType useMultipart = bestType == multipartMime return SubscriptionParams{ UseSse: useSse, SubscribeOnce: subscribeOnce, UseMultipart: useMultipart, } } type SubscriptionParams struct { UseSse bool SubscribeOnce bool UseMultipart bool } func GetWriterPrefix(sse bool, multipart bool, firstMessage bool) string { flushBreak := "" if sse { flushBreak = "event: next\ndata: " } else if multipart { messageStart := "" if firstMessage { messageStart = multipartStart } flushBreak = messageStart + "\r\nContent-Type: " + jsonContent + "\r\n\r\n" } return flushBreak }