router/core/header_rule_engine.go (531 lines of code) (raw):
package core
import (
"context"
"fmt"
"io"
"net/http"
"regexp"
"slices"
"strings"
"sync"
"time"
cachedirective "github.com/pquerna/cachecontrol/cacheobject"
nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1"
"github.com/wundergraph/cosmo/router/pkg/config"
"github.com/wundergraph/cosmo/router/pkg/otel"
rtrace "github.com/wundergraph/cosmo/router/pkg/trace"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// Package-level values shared by the header propagation rules in this file.
var (
// Compile-time assertion that *HeaderPropagation satisfies EnginePreOriginHandler.
_ EnginePreOriginHandler = (*HeaderPropagation)(nil)
// ignoredHeaders lists the header names that propagate rules skip. The code
// in this file checks membership with slices.Contains, i.e. by exact string
// comparison, so the entries are spelled in canonical header form.
ignoredHeaders = []string{
"Alt-Svc",
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
// Hop-by-hop headers
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
// Content Negotiation. We must never propagate the client headers to the upstream
// The router has to decide on its own what to send to the upstream
"Content-Type",
"Accept-Encoding",
"Accept-Charset",
"Accept",
// Web Socket negotiation headers. We must never propagate the client headers to the upstream.
"Sec-Websocket-Extensions",
"Sec-Websocket-Key",
"Sec-Websocket-Protocol",
"Sec-Websocket-Version",
}
// Header names and directive value used by the most-restrictive
// cache control algorithm.
cacheControlKey = "Cache-Control"
expiresKey = "Expires"
noCache = "no-cache"
// caseInsensitiveRegexp is prepended to every configured matching pattern
// before it is compiled, so header names are matched case-insensitively.
caseInsensitiveRegexp = "(?i)"
)
// responseHeaderPropagationKey is the context key under which
// WithResponseHeaderPropagation stores the *responseHeaderPropagation state.
type responseHeaderPropagationKey struct{}
// responseHeaderPropagation accumulates the headers that response rules
// collect from origin responses, so they can later be copied onto the
// client response by headerPropagationWriter.
type responseHeaderPropagation struct {
// header holds the headers gathered so far.
header http.Header
// m is locked by the response rule algorithms while they modify header.
m *sync.Mutex
// previousCacheControl is the policy returned by the last run of the
// most-restrictive cache control algorithm; nil until it has run.
previousCacheControl *cachedirective.Object
// setCacheControl is set to true when a set rule writes Cache-Control; the
// cache control algorithm then returns without changing anything.
setCacheControl bool
}
// WithResponseHeaderPropagation returns ctx rebound to a context that carries
// a fresh, empty responseHeaderPropagation state (empty header map and its
// own mutex) under responseHeaderPropagationKey.
func WithResponseHeaderPropagation(ctx *resolve.Context) *resolve.Context {
	state := &responseHeaderPropagation{
		m:      &sync.Mutex{},
		header: make(http.Header),
	}
	carrier := context.WithValue(ctx.Context(), responseHeaderPropagationKey{}, state)
	return ctx.WithContext(carrier)
}
// getResponseHeaderPropagation returns the propagation state stored in ctx
// under responseHeaderPropagationKey, or nil when the context carries none.
func getResponseHeaderPropagation(ctx context.Context) *responseHeaderPropagation {
	if stored := ctx.Value(responseHeaderPropagationKey{}); stored != nil {
		return stored.(*responseHeaderPropagation)
	}
	return nil
}
// HeaderPropagationWriter wraps w in a writer that copies the headers
// collected in ctx's propagation state onto w before the first write.
// When ctx carries no propagation state, w itself is returned.
func HeaderPropagationWriter(w http.ResponseWriter, ctx context.Context) io.Writer {
	if state := getResponseHeaderPropagation(ctx); state != nil {
		return &headerPropagationWriter{
			propagateHeaders:  true,
			headerPropagation: state,
			writer:            w,
		}
	}
	return w
}
// headerPropagationWriter is an io.Writer over an http.ResponseWriter that
// adds the collected propagation headers to the response on its first Write.
type headerPropagationWriter struct {
// writer is the wrapped response writer that receives headers and body.
writer http.ResponseWriter
// headerPropagation holds the headers to copy onto writer.
headerPropagation *responseHeaderPropagation
// propagateHeaders is true until the headers have been copied once.
propagateHeaders bool
}
// Write forwards p to the wrapped response writer. On the first call it
// first adds every collected propagation header value to the response
// headers; later calls only write the payload.
func (h *headerPropagationWriter) Write(p []byte) (n int, err error) {
	if !h.propagateHeaders {
		return h.writer.Write(p)
	}
	for name, values := range h.headerPropagation.header {
		for i := range values {
			h.writer.Header().Add(name, values[i])
		}
	}
	h.propagateHeaders = false
	return h.writer.Write(p)
}
// HeaderPropagation is a pre-origin handler that can be used to propagate and
// manipulate headers from the client request to the upstream. Its
// OnOriginResponse hook additionally applies the configured response rules
// to the responses coming back from the origin.
type HeaderPropagation struct {
// regex maps a rule's matching pattern (as configured) to its compiled,
// case-insensitive regular expression.
regex map[string]*regexp.Regexp
// rules is the header rule configuration this handler applies.
rules *config.HeaderRules
// hasRequestRules reports whether any request rule is configured.
hasRequestRules bool
// hasResponseRules reports whether any response rule is configured.
hasResponseRules bool
}
// initHeaderRules fills in the nil parts of rules in place: a missing
// subgraph map becomes an empty map and a missing global rule set becomes
// an empty GlobalHeaderRule. rules itself must be non-nil.
func initHeaderRules(rules *config.HeaderRules) {
	if rules.Subgraphs == nil {
		rules.Subgraphs = map[string]*config.GlobalHeaderRule{}
	}
	if rules.All == nil {
		rules.All = new(config.GlobalHeaderRule)
	}
}
// NewHeaderPropagation builds a HeaderPropagation for the given rules and
// compiles the regular expressions of all matching-based rules up front.
// It returns (nil, nil) when rules is nil and an error when a rule uses an
// unsupported operation or an invalid pattern. Nil parts of rules are
// initialized in place.
func NewHeaderPropagation(rules *config.HeaderRules) (*HeaderPropagation, error) {
	if rules == nil {
		return nil, nil
	}
	initHeaderRules(rules)

	propagation := &HeaderPropagation{
		regex: make(map[string]*regexp.Regexp),
		rules: rules,
	}

	requestRules, responseRules := propagation.getAllRules()
	propagation.hasRequestRules = len(requestRules) != 0
	propagation.hasResponseRules = len(responseRules) != 0

	err := propagation.collectRuleMatchers(requestRules, responseRules)
	if err != nil {
		return nil, err
	}
	return propagation, nil
}
// AddCacheControlPolicyToRules appends response rules that run the
// most-restrictive cache control algorithm to rules: one global rule when the
// policy is enabled and one rule per subgraph listed in the policy, each
// using the configured value as its default. When rules is nil, a new rule
// set is created, unless the policy is disabled and lists no subgraphs, in
// which case nil is returned. A non-nil rules argument is modified in place.
func AddCacheControlPolicyToRules(rules *config.HeaderRules, cacheControl config.CacheControlPolicy) *config.HeaderRules {
	if rules == nil {
		if !cacheControl.Enabled && cacheControl.Subgraphs == nil {
			return nil
		}
		rules = &config.HeaderRules{}
	}
	initHeaderRules(rules)

	// restrictiveRule builds the propagate rule carrying the given default.
	restrictiveRule := func(defaultValue string) *config.ResponseHeaderRule {
		return &config.ResponseHeaderRule{
			Operation: config.HeaderRuleOperationPropagate,
			Algorithm: config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl,
			Default:   defaultValue,
		}
	}

	if cacheControl.Enabled {
		rules.All.Response = append(rules.All.Response, restrictiveRule(cacheControl.Value))
	}

	for _, graph := range cacheControl.Subgraphs {
		target, found := rules.Subgraphs[graph.Name]
		if !found {
			target = &config.GlobalHeaderRule{Response: []*config.ResponseHeaderRule{}}
		}
		target.Response = append(target.Response, restrictiveRule(graph.Value))
		rules.Subgraphs[graph.Name] = target
	}
	return rules
}
// getAllRules returns every configured request rule and every configured
// response rule: the global rules first, followed by the rules of each
// subgraph (in map iteration order, which is unspecified).
//
// The returned slices are fresh copies. Appending directly onto
// rules.All.Request / rules.All.Response could write into the spare capacity
// of the configuration's own backing arrays, so the global slices are cloned
// before the subgraph rules are appended. Nil subgraph entries are skipped.
func (hf *HeaderPropagation) getAllRules() ([]*config.RequestHeaderRule, []*config.ResponseHeaderRule) {
	requestRules := slices.Clone(hf.rules.All.Request)
	responseRules := slices.Clone(hf.rules.All.Response)
	for _, subgraph := range hf.rules.Subgraphs {
		if subgraph == nil {
			continue
		}
		requestRules = append(requestRules, subgraph.Request...)
		responseRules = append(responseRules, subgraph.Response...)
	}
	return requestRules, responseRules
}
// processRule validates a single rule and, for propagate rules that carry a
// matching pattern, compiles the pattern case-insensitively and stores it in
// hf.regex keyed by the pattern as configured. Set rules need no preparation.
// index is only used in the error message for an invalid pattern. Any other
// operation is reported as an error.
func (hf *HeaderPropagation) processRule(rule config.HeaderRule, index int) error {
	operation := rule.GetOperation()
	if operation == config.HeaderRuleOperationSet {
		return nil
	}
	if operation != config.HeaderRuleOperationPropagate {
		return fmt.Errorf("unhandled operation '%s' for header rule %+v", operation, rule)
	}

	pattern := rule.GetMatching()
	if pattern == "" {
		return nil
	}
	compiled, err := regexp.Compile(caseInsensitiveRegexp + pattern)
	if err != nil {
		return fmt.Errorf("invalid regex '%s' for header rule %d: %w", pattern, index, err)
	}
	hf.regex[pattern] = compiled
	return nil
}
// collectRuleMatchers runs processRule over all request rules and then all
// response rules, stopping at and returning the first error. The index
// passed to processRule restarts at zero for the response rules.
func (hf *HeaderPropagation) collectRuleMatchers(rhrs []*config.RequestHeaderRule, rhrrs []*config.ResponseHeaderRule) error {
	for idx := range rhrs {
		if err := hf.processRule(rhrs[idx], idx); err != nil {
			return err
		}
	}
	for idx := range rhrrs {
		if err := hf.processRule(rhrrs[idx], idx); err != nil {
			return err
		}
	}
	return nil
}
// HasRequestRules reports whether at least one request header rule is
// configured. It is safe to call on a nil receiver, which reports false.
func (h *HeaderPropagation) HasRequestRules() bool {
	return h != nil && h.hasRequestRules
}
// HasResponseRules reports whether at least one response header rule is
// configured. It is safe to call on a nil receiver, which reports false.
func (h *HeaderPropagation) HasResponseRules() bool {
	return h != nil && h.hasResponseRules
}
// OnOriginRequest applies the request header rules to the outgoing origin
// request: first the global rules, then the rules of the active subgraph
// when one is known and has rules configured. It always returns the same
// request and a nil response.
func (h *HeaderPropagation) OnOriginRequest(request *http.Request, ctx RequestContext) (*http.Request, *http.Response) {
	globalRules := h.rules.All.Request
	for i := range globalRules {
		h.applyRequestRule(ctx, request, globalRules[i])
	}

	active := ctx.ActiveSubgraph(request)
	if active == nil {
		return request, nil
	}
	scoped, found := h.rules.Subgraphs[active.Name]
	if !found {
		return request, nil
	}
	for i := range scoped.Request {
		h.applyRequestRule(ctx, request, scoped.Request[i])
	}
	return request, nil
}
// OnOriginResponse applies the response header rules to an origin response,
// recording the results in the propagation state found in the context of
// resp.Request. Global rules run first, then the rules of the active
// subgraph. A nil response yields nil; a response whose request context has
// no propagation state is returned untouched.
func (h *HeaderPropagation) OnOriginResponse(resp *http.Response, ctx RequestContext) *http.Response {
	// An error response may arrive as nil.
	if resp == nil {
		return nil
	}
	state := getResponseHeaderPropagation(resp.Request.Context())
	if state == nil {
		return resp
	}

	globalRules := h.rules.All.Response
	for i := range globalRules {
		h.applyResponseRule(state, resp, globalRules[i])
	}

	active := ctx.ActiveSubgraph(resp.Request)
	if active == nil {
		return resp
	}
	scoped, found := h.rules.Subgraphs[active.Name]
	if !found {
		return resp
	}
	for i := range scoped.Response {
		h.applyResponseRule(state, resp, scoped.Response[i])
	}
	return resp
}
// applyResponseRule applies a single response header rule to the shared
// propagation state, using the origin response res as the header source.
//
//   - set: stores rule.Value under rule.Name. When the name is Cache-Control
//     (compared in canonical form), setCacheControl is raised so that the
//     most-restrictive cache control algorithm no longer overrides it.
//   - propagate by name: forwards the named header's values, or rule.Default
//     when the header is absent and a default is configured.
//   - propagate by pattern: forwards every response header whose name matches
//     the compiled pattern.
//   - propagate with only the most-restrictive cache control algorithm set:
//     runs that algorithm on the response.
//
// Headers on the ignore list are never propagated. The list is in canonical
// form, so names are canonicalized before the lookup; otherwise a rule
// configured as e.g. "content-type" would bypass it.
func (h *HeaderPropagation) applyResponseRule(propagation *responseHeaderPropagation, res *http.Response, rule *config.ResponseHeaderRule) {
	if rule.Operation == config.HeaderRuleOperationSet {
		// The propagation state is shared between origin responses, so the
		// write is guarded by the same mutex the propagate algorithms use.
		propagation.m.Lock()
		propagation.header.Set(rule.Name, rule.Value)
		if http.CanonicalHeaderKey(rule.Name) == cacheControlKey {
			// Handle the case where the cache control header is set explicitly
			propagation.setCacheControl = true
		}
		propagation.m.Unlock()
		return
	}
	if rule.Operation != config.HeaderRuleOperationPropagate {
		return
	}

	switch {
	case rule.Named != "":
		if slices.Contains(ignoredHeaders, http.CanonicalHeaderKey(rule.Named)) {
			return
		}
		values := res.Header.Values(rule.Named)
		if len(values) > 0 {
			h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, values)
		} else if rule.Default != "" {
			h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, []string{rule.Default})
		}
	case rule.Matching != "":
		regex, ok := h.regex[rule.Matching]
		if !ok {
			return
		}
		for name := range res.Header {
			if !regex.MatchString(name) {
				continue
			}
			if slices.Contains(ignoredHeaders, http.CanonicalHeaderKey(name)) {
				continue
			}
			h.applyResponseRuleKeyValue(res, propagation, rule, name, res.Header.Values(name))
		}
	case rule.Algorithm == config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl:
		// Explicitly apply the CacheControl algorithm on the headers
		h.applyResponseRuleKeyValue(res, propagation, rule, "", []string{""})
	}
}
// applyResponseRuleKeyValue merges values for the header key into the
// propagation state according to rule.Algorithm:
//
//   - first write: stores values only if the header has no non-empty value yet
//   - last write: replaces whatever is stored
//   - append: adds values after the stored ones
//   - most restrictive cache control: delegates to the cache control algorithm
//
// Any other algorithm leaves the state untouched. The header map is modified
// while holding the propagation mutex.
func (h *HeaderPropagation) applyResponseRuleKeyValue(res *http.Response, propagation *responseHeaderPropagation, rule *config.ResponseHeaderRule, key string, values []string) {
	// The map is indexed directly below, so the key has to be canonical.
	canonical := http.CanonicalHeaderKey(key)

	// locked runs mutate while holding the propagation mutex.
	locked := func(mutate func()) {
		propagation.m.Lock()
		defer propagation.m.Unlock()
		mutate()
	}

	switch rule.Algorithm {
	case config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl:
		h.applyResponseRuleMostRestrictiveCacheControl(res, propagation, rule)
	case config.ResponseHeaderRuleAlgorithmAppend:
		locked(func() {
			propagation.header[canonical] = append(propagation.header[canonical], values...)
		})
	case config.ResponseHeaderRuleAlgorithmLastWrite:
		locked(func() {
			propagation.header[canonical] = values
		})
	case config.ResponseHeaderRuleAlgorithmFirstWrite:
		locked(func() {
			if propagation.header.Get(canonical) == "" {
				propagation.header[canonical] = values
			}
		})
	}
}
// applyRequestRule applies a single request header rule to the outgoing
// origin request, reading client headers from ctx.Request().
//
//   - set: writes rule.Value under rule.Name; when ValueFrom names a context
//     field, the formatted dynamic value is written instead (and nothing is
//     written if it formats to an empty string).
//   - propagate with Named and Rename: copies the first value of the named
//     client header (or rule.Default when absent) to the renamed header and
//     deletes the original name from the outgoing request.
//   - propagate with Named: copies all values of the named client header, or
//     sets rule.Default when the header is absent.
//   - propagate with Matching: handles every client header whose name matches
//     the compiled pattern, either renaming it or copying all of its values.
//
// Headers on the ignore list are never propagated. The list is in canonical
// form, so names are canonicalized before the lookup; otherwise a rule
// configured as e.g. "content-type" would bypass it.
func (h *HeaderPropagation) applyRequestRule(ctx RequestContext, request *http.Request, rule *config.RequestHeaderRule) {
	if rule.Operation == config.HeaderRuleOperationSet {
		if rule.ValueFrom != nil && rule.ValueFrom.ContextField != "" {
			val := getCustomDynamicAttributeValue(rule.ValueFrom, getRequestContext(request.Context()), nil)
			// NOTE(review): if val can be a nil interface, this formats as
			// "<nil>" and is written as the header value — confirm.
			value := fmt.Sprintf("%v", val)
			if value != "" {
				request.Header.Set(rule.Name, value)
			}
			return
		}
		request.Header.Set(rule.Name, rule.Value)
		return
	}
	if rule.Operation != config.HeaderRuleOperationPropagate {
		return
	}

	clientHeader := ctx.Request().Header

	// Rename the header before propagating and delete the original.
	if rule.Rename != "" && rule.Named != "" {
		// Ignore the rule when the target header is in the ignored list
		if slices.Contains(ignoredHeaders, http.CanonicalHeaderKey(rule.Rename)) {
			return
		}
		value := clientHeader.Get(rule.Named)
		if value == "" {
			value = rule.Default
		}
		if value != "" {
			request.Header.Set(rule.Rename, value)
			request.Header.Del(rule.Named)
		}
		return
	}

	// Propagate the named header as is.
	if rule.Named != "" {
		if slices.Contains(ignoredHeaders, http.CanonicalHeaderKey(rule.Named)) {
			return
		}
		if values := clientHeader.Values(rule.Named); len(values) > 0 {
			request.Header[http.CanonicalHeaderKey(rule.Named)] = values
		} else if rule.Default != "" {
			request.Header.Set(rule.Named, rule.Default)
		}
		return
	}

	// Matching based on regex.
	regex, ok := h.regex[rule.Matching]
	if !ok {
		return
	}
	for name := range clientHeader {
		// Headers are case-insensitive, but Go canonicalize them
		// Issue: https://github.com/golang/go/issues/37834
		if !regex.MatchString(name) {
			continue
		}

		// Rename the header before propagating and delete the original.
		// rule.Named is always empty at this point.
		if rule.Rename != "" {
			if slices.Contains(ignoredHeaders, http.CanonicalHeaderKey(rule.Rename)) {
				continue
			}
			if value := clientHeader.Get(name); value != "" {
				request.Header.Set(rule.Rename, value)
				request.Header.Del(name)
			} else if rule.Default != "" {
				request.Header.Set(rule.Rename, rule.Default)
				request.Header.Del(name)
			}
			continue
		}

		// Propagate the header as is.
		if slices.Contains(ignoredHeaders, http.CanonicalHeaderKey(name)) {
			continue
		}
		// Copy every value, consistent with named propagation above; copying
		// only the first value would drop the rest of a multi-valued header.
		// The values are cloned so the outgoing request does not share a
		// backing array with the client request.
		request.Header[http.CanonicalHeaderKey(name)] = slices.Clone(clientHeader[name])
	}
}
// applyResponseRuleMostRestrictiveCacheControl folds the cache policy of the
// origin response res into the propagation state so that the client response
// ends up with the most restrictive Cache-Control (and, where applicable,
// Expires) seen so far.
//
// It does nothing when a set rule has already written Cache-Control. When
// resolve.SingleFlightDisallowed reports true for the request context, the
// header is forced to "no-cache". Otherwise the response policy, the rule's
// default (if any) and the previously computed policy (if any) are combined
// by createMostRestrictivePolicy.
//
// The propagation state is shared between origin responses, so the mutex is
// held for the whole call: the reads of setCacheControl and
// previousCacheControl and the write of the combined result must not
// interleave with another response running the same algorithm.
//
// A Cache-Control value that fails to parse is treated as if no directives
// were present, and an unparseable default is ignored, so that a nil
// directive set is never handed to the policy merge.
func (h *HeaderPropagation) applyResponseRuleMostRestrictiveCacheControl(res *http.Response, propagation *responseHeaderPropagation, rule *config.ResponseHeaderRule) {
	propagation.m.Lock()
	defer propagation.m.Unlock()

	if propagation.setCacheControl {
		// Handle the case where the cache control header is set explicitly using the set propagation rule
		return
	}

	ctx := res.Request.Context()
	tracer := rtrace.TracerFromContext(ctx)
	commonAttributes := []attribute.KeyValue{
		otel.WgOperationProtocol.String(OperationProtocolHTTP.String()),
	}
	_, span := tracer.Start(ctx, "HeaderPropagation - RestrictiveCacheControl",
		trace.WithSpanKind(trace.SpanKindInternal),
		trace.WithAttributes(commonAttributes...),
	)
	defer span.End()

	// Set no-cache for all mutations, to ensure that requests to mutate data always work as expected (without returning cached data)
	if resolve.SingleFlightDisallowed(ctx) {
		propagation.header.Set(cacheControlKey, noCache)
		return
	}

	reqCacheHeader := res.Request.Header.Get(cacheControlKey)
	resCacheHeader := res.Header.Get(cacheControlKey)
	// Unparseable or missing dates yield the zero time, which the code below
	// treats as "header not present".
	expiresHeader, _ := http.ParseTime(res.Header.Get(expiresKey))
	dateHeader, _ := http.ParseTime(res.Header.Get("Date"))
	lastModifiedHeader, _ := http.ParseTime(res.Header.Get("Last-Modified"))

	if propagation.previousCacheControl == nil && reqCacheHeader == "" && resCacheHeader == "" && expiresHeader.IsZero() && rule.Default == "" {
		// There is no default/previous value to set, and since no cache control headers have been set, exit early
		return
	}

	// A malformed request header is deliberately ignored; reqDir may be nil.
	reqDir, _ := cachedirective.ParseRequestCacheControl(reqCacheHeader)
	resDir, err := cachedirective.ParseResponseCacheControl(resCacheHeader)
	if err != nil || resDir == nil {
		// Fall back to the directive set of an empty header so the policy
		// objects below always carry non-nil response directives.
		resDir, _ = cachedirective.ParseResponseCacheControl("")
		if resDir == nil {
			resDir = &cachedirective.ResponseCacheDirectives{}
		}
	}

	obj := &cachedirective.Object{
		RespDirectives:         resDir,
		RespHeaders:            res.Header,
		RespStatusCode:         res.StatusCode,
		RespExpiresHeader:      expiresHeader,
		RespDateHeader:         dateHeader,
		RespLastModifiedHeader: lastModifiedHeader,
		ReqDirectives:          reqDir,
		ReqHeaders:             res.Request.Header,
		NowUTC:                 time.Now().UTC(),
	}
	rv := cachedirective.ObjectResults{}
	cachedirective.CachableObject(obj, &rv)
	cachedirective.ExpirationObject(obj, &rv)
	span.SetAttributes(
		otel.WgResponseCacheControlReasons.String(fmt.Sprint(rv.OutReasons)),
		otel.WgResponseCacheControlWarnings.String(fmt.Sprint(rv.OutWarnings)),
		otel.WgResponseCacheControlExpiration.String(rv.OutExpirationTime.String()),
	)

	// Add each cache control object to the policies list
	policies := []*cachedirective.Object{obj}
	if rule.Default != "" {
		defaultResponseCache, defaultErr := cachedirective.ParseResponseCacheControl(rule.Default)
		if defaultErr == nil && defaultResponseCache != nil {
			policies = append(policies, &cachedirective.Object{RespDirectives: defaultResponseCache})
		}
	}
	if propagation.previousCacheControl != nil {
		policies = append(policies, propagation.previousCacheControl)
	}

	// Determine the most restrictive cache policy and cache control header
	restrictivePolicy, cacheControlHeader := createMostRestrictivePolicy(policies)
	propagation.previousCacheControl = restrictivePolicy
	if cacheControlHeader != "" {
		propagation.header.Set(cacheControlKey, cacheControlHeader)
	}

	// Update the Expires header if applicable
	if !expiresHeader.IsZero() && !restrictivePolicy.RespExpiresHeader.IsZero() {
		propagation.header.Set(expiresKey, restrictivePolicy.RespExpiresHeader.Format(http.TimeFormat))
	}
}
// createMostRestrictivePolicy combines the given cache policies into a
// single policy and the Cache-Control header value that expresses it.
//
// Rules, in order of precedence:
//   - any policy with no-store wins immediately and yields "no-store"
//   - no-cache on any policy yields "no-cache"
//   - otherwise the smallest positive max-age is emitted, if there is one
//   - "private" on any policy wins over "public"
//
// The returned header is empty when none of the policies contributes a
// directive. Nil policies and policies without response directives (for
// example after a failed header parse) are skipped instead of dereferenced.
func createMostRestrictivePolicy(policies []*cachedirective.Object) (*cachedirective.Object, string) {
	result := cachedirective.Object{
		RespDirectives: &cachedirective.ResponseCacheDirectives{},
	}
	var minMaxAge cachedirective.DeltaSeconds = -1
	isPrivate := false
	isPublic := false

	for _, policy := range policies {
		if policy == nil || policy.RespDirectives == nil {
			continue
		}
		directives := policy.RespDirectives

		// Check no-store and no-cache first
		if directives.NoStore {
			result.RespDirectives.NoStore = true
			return &result, "no-store"
		}
		if directives.NoCachePresent {
			result.RespDirectives.NoCachePresent = true
		}

		// Determine the shortest max-age if available
		// NOTE(review): max-age=0 is not considered here (only values > 0),
		// although it is the strictest freshness lifetime — confirm intent.
		if directives.MaxAge > 0 && (minMaxAge == -1 || directives.MaxAge < minMaxAge) {
			minMaxAge = directives.MaxAge
		}

		// Track if any policy specifies "private"
		if directives.PrivatePresent {
			isPrivate = true
		} else if directives.Public {
			isPublic = true
		}

		// Handle expires header comparisons
		// NOTE(review): a policy without an Expires value carries the zero
		// time, which sorts before every real time and so resets the result
		// to "no Expires" — confirm this is intended.
		if policy.RespExpiresHeader.Before(result.RespExpiresHeader) || result.RespExpiresHeader.IsZero() {
			result.RespExpiresHeader = policy.RespExpiresHeader
		}
	}

	// Set the calculated max-age and privacy level on the result
	if minMaxAge > 0 {
		result.RespDirectives.MaxAge = minMaxAge
	}
	result.RespDirectives.PrivatePresent = isPrivate

	// Format the final Cache-Control header
	var headerParts []string
	if result.RespDirectives.NoCachePresent {
		headerParts = append(headerParts, noCache)
	} else if minMaxAge > 0 {
		headerParts = append(headerParts, fmt.Sprintf("max-age=%d", minMaxAge))
	}
	if isPrivate {
		headerParts = append(headerParts, "private")
	} else if isPublic {
		headerParts = append(headerParts, "public")
	}
	return &result, strings.Join(headerParts, ", ")
}
// SubgraphRules returns the request header rules that apply to the subgraph
// with the given name: the global request rules followed by the rules
// configured specifically for that subgraph. It returns nil when rules is nil.
func SubgraphRules(rules *config.HeaderRules, subgraphName string) []*config.RequestHeaderRule {
	if rules == nil {
		return nil
	}

	var combined []*config.RequestHeaderRule
	if global := rules.All; global != nil {
		combined = append(combined, global.Request...)
	}
	if rules.Subgraphs == nil {
		return combined
	}
	if specific, found := rules.Subgraphs[subgraphName]; found {
		combined = append(combined, specific.Request...)
	}
	return combined
}
// FetchURLRules returns the request header rules for the first subgraph whose
// routing URL equals routingURL. When no subgraph matches, the lookup is done
// with an empty subgraph name.
func FetchURLRules(rules *config.HeaderRules, subgraphs []*nodev1.Subgraph, routingURL string) []*config.RequestHeaderRule {
	matchesURL := func(candidate *nodev1.Subgraph) bool {
		return candidate.RoutingUrl == routingURL
	}

	name := ""
	if idx := slices.IndexFunc(subgraphs, matchesURL); idx >= 0 {
		name = subgraphs[idx].Name
	}
	return SubgraphRules(rules, name)
}
// PropagatedHeaders reports which headers the given request rules touch:
// the literal header names of set rules and named propagate rules, and the
// compiled case-insensitive patterns of matching-based propagate rules.
// It returns an error for a set rule without a name or without a value
// source, a propagate rule with neither a name nor a pattern, a pattern that
// does not compile, or an unknown operation.
func PropagatedHeaders(rules []*config.RequestHeaderRule) (headerNames []string, headerNameRegexps []*regexp.Regexp, err error) {
	for _, rule := range rules {
		if rule.Operation == config.HeaderRuleOperationSet {
			hasNoValue := rule.Value == "" && rule.ValueFrom == nil
			if rule.Name == "" || hasNoValue {
				return nil, nil, fmt.Errorf("invalid header set rule %+v, no header name/value combination", rule)
			}
			headerNames = append(headerNames, rule.Name)
			continue
		}
		if rule.Operation != config.HeaderRuleOperationPropagate {
			return nil, nil, fmt.Errorf("invalid header rule operation %q in rule %+v", rule.Operation, rule)
		}

		switch {
		case rule.Matching != "":
			// Header Names are case insensitive: https://www.w3.org/Protocols/rfc2616/rfc2616.html
			compiled, compileErr := regexp.Compile(caseInsensitiveRegexp + rule.Matching)
			if compileErr != nil {
				return nil, nil, fmt.Errorf("error compiling regular expression %q in header rule %+v: %w", rule.Matching, rule, compileErr)
			}
			headerNameRegexps = append(headerNameRegexps, compiled)
		case rule.Named != "":
			headerNames = append(headerNames, rule.Named)
		default:
			return nil, nil, fmt.Errorf("invalid header propagation rule %+v, no header name nor regular expression", rule)
		}
	}
	return headerNames, headerNameRegexps, nil
}