router/core/context.go (475 lines of code) (raw):
package core
import (
"context"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/expr-lang/expr/vm"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
"github.com/wundergraph/astjson"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
graphqlmetrics "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/graphqlmetrics/v1"
"github.com/wundergraph/cosmo/router/internal/expr"
"github.com/wundergraph/cosmo/router/pkg/authentication"
"github.com/wundergraph/cosmo/router/pkg/config"
ctrace "github.com/wundergraph/cosmo/router/pkg/trace"
)
type contextKey int
const (
requestContextKey contextKey = iota
subgraphResolverContextKey
engineLoaderHooksContextKey
)
var _ RequestContext = (*requestContext)(nil)
type Subgraph struct {
Id string
Name string
Url *url.URL
UrlString string
}
type ClientInfo struct {
// Name contains the client name, derived from the request headers
Name string
// Version contains the client version, derived from the request headers
Version string
// WGRequestToken contains the token to authenticate the request from the platform
WGRequestToken string
}
func NewClientInfoFromRequest(r *http.Request, clientHeader config.ClientHeader) *ClientInfo {
clientName := ctrace.GetClientHeader(r.Header, []string{clientHeader.Name, "graphql-client-name", "apollographql-client-name"}, "unknown")
clientVersion := ctrace.GetClientHeader(r.Header, []string{clientHeader.Version, "graphql-client-version", "apollographql-client-version"}, "missing")
requestToken := r.Header.Get("X-WG-Token")
return &ClientInfo{
Name: clientName,
Version: clientVersion,
WGRequestToken: requestToken,
}
}
type RequestContext interface {
// ResponseWriter is the original response writer received by the router.
ResponseWriter() http.ResponseWriter
// Request is the original request received by the router.
Request() *http.Request
// Logger is the logger for the request
Logger() *zap.Logger
// Operation is the GraphQL operation
Operation() OperationContext
// Set is used to store a new key/value pair exclusively for this context.
Set(string, any)
// Get returns the value for the given key, ie: (value, true).
Get(string) (value any, exists bool)
// GetString returns the value associated with the key as a string.
GetString(string) string
// MustGet returns the value for the given key if it exists, otherwise it panics.
MustGet(string) any
// GetBool returns the value associated with the key as a boolean.
GetBool(string) bool
// GetInt returns the value associated with the key as an integer.
GetInt(string) int
// GetInt64 returns the value associated with the key as an integer.
GetInt64(string) int64
// GetUint returns the value associated with the key as an unsigned integer.
GetUint(string) uint
// GetUint64 returns the value associated with the key as an unsigned integer.
GetUint64(string) uint64
// GetFloat64 returns the value associated with the key as a float64.
GetFloat64(string) float64
// GetTime returns the value associated with the key as time.
GetTime(string) time.Time
// GetDuration returns the value associated with the key as a duration.
GetDuration(string) time.Duration
// GetStringSlice returns the value associated with the key as a slice of strings.
GetStringSlice(string) []string
// GetStringMap returns the value associated with the key as a map of interfaces.
GetStringMap(string) map[string]any
// GetStringMapString returns the value associated with the key as a map of strings.
GetStringMapString(string) map[string]string
// GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings.
GetStringMapStringSlice(string) map[string][]string
// ActiveSubgraph returns the current subgraph to which the request is made to
ActiveSubgraph(subgraphRequest *http.Request) *Subgraph
// SubgraphByID return the subgraph by its ID
SubgraphByID(subgraphID string) *Subgraph
// Authentication returns the authentication information for the request, if any
Authentication() authentication.Authentication
// SetAuthenticationScopes sets the scopes for the request on Authentication
// If Authentication is not set, it will be initialized with the scopes
SetAuthenticationScopes(scopes []string)
}
var metricAttrsPool = sync.Pool{
New: func() any {
v := make([]attribute.KeyValue, 0, 20)
return &v
},
}
type requestTelemetryAttributes struct {
// traceAttrs are the base attributes for traces only
traceAttrs []attribute.KeyValue
// metricAttrs are the attributes for metrics only
metricAttrs []attribute.KeyValue
// metricSetAttrs is map to quickly check if a metric attribute is set and to what key it is remapped
metricSetAttrs map[string]string
// metricSliceAttrs are the attributes for metrics that are string slices and needs to be exploded for prometheus
metricSliceAttrs []attribute.KeyValue
// mapper is an attribute mapper for context attributes.
// It is used to identify attributes that should not be included by default but can be included if they are
// configured in the custom attributes list. The mapper will potentially filter out attributes or include them.
// It will also remap the key if configured.
mapper *attributeMapper
// traceAttributeExpressions is a map of expressions that can be used to resolve dynamic attributes in traces
traceAttributeExpressions *attributeExpressions
// metricAttributeExpressions is a map of expressions that can be used to resolve dynamic attributes in metrics
metricAttributeExpressions *attributeExpressions
// metricsEnabled indicates if metrics are enabled. If false, no metrics attributes will be added
metricsEnabled bool
// traceEnabled indicates if traces are enabled, if false, no trace attributes will be added
traceEnabled bool
}
func (r *requestTelemetryAttributes) AcquireAttributes() *[]attribute.KeyValue {
if !r.metricsEnabled && !r.traceEnabled {
return &[]attribute.KeyValue{}
}
return metricAttrsPool.Get().(*[]attribute.KeyValue)
}
func (r *requestTelemetryAttributes) ReleaseAttributes(attrs *[]attribute.KeyValue) {
if !r.metricsEnabled && !r.traceEnabled {
return
}
// reset slice
*attrs = (*attrs)[:0]
// If the slice is too big, we don't pool it to avoid holding on to too much memory
if cap(*attrs) > 128 {
return
}
metricAttrsPool.Put(attrs)
}
func (r *requestTelemetryAttributes) AddCustomMetricStringSliceAttr(key string, values []string) {
if !r.metricsEnabled {
return
}
if remapKey, ok := r.metricSetAttrs[key]; ok && len(values) > 0 {
v := attribute.StringSlice(remapKey, values)
r.metricSliceAttrs = append(r.metricSliceAttrs, v)
}
}
func (r *requestTelemetryAttributes) addCustomMetricStringAttr(key string, value string) {
if !r.metricsEnabled {
return
}
if remapKey, ok := r.metricSetAttrs[key]; ok && value != "" {
v := attribute.String(remapKey, value)
r.metricAttrs = append(r.metricAttrs, v)
}
}
func (r *requestTelemetryAttributes) addCommonAttribute(vals ...attribute.KeyValue) {
r.addMetricAttribute(vals...)
r.addCommonTraceAttribute(vals...)
}
func (r *requestTelemetryAttributes) addCommonTraceAttribute(vals ...attribute.KeyValue) {
if !r.traceEnabled {
return
}
r.traceAttrs = append(r.traceAttrs, vals...)
}
func (r *requestTelemetryAttributes) addMetricAttribute(vals ...attribute.KeyValue) {
if !r.metricsEnabled {
return
}
r.metricAttrs = append(r.metricAttrs, r.mapper.mapAttributes(vals)...)
}
// requestContext is the default implementation of RequestContext
// It is accessible to custom modules in the request lifecycle
type requestContext struct {
logger *zap.Logger
// This mutex protects keys map.
mu sync.RWMutex
// keys is a key/value pair exclusively for the context of each request.
keys map[string]any
// responseWriter is the original response writer received by the router.
responseWriter http.ResponseWriter
// error indicates if the request / response has an error. Due to the nature of GraphQL
// Error can be set even if the response status code is 200.
error error
// request is the original request received by the router.
request *http.Request
// operation is the GraphQL operation context
operation *operationContext
// subgraphResolver can be used to resolve Subgraph by ID or by request
subgraphResolver *SubgraphResolver
// dataSourceNames the list of datasource involved in resolving the operation
dataSourceNames []string
// graphQLErrorServices are the services that produced the GraphQL errors
graphQLErrorServices []string
// graphQLErrorCodes are the error codes of the GraphQL errors
graphQLErrorCodes []string
// telemetry are the base telemetry information of the request
telemetry *requestTelemetryAttributes
// expressionContext is the context that will be provided to a compiled expression in order to retrieve data via dynamic expressions
expressionContext expr.Context
}
func (c *requestContext) SetError(err error) {
c.error = err
c.expressionContext.Request.Error = err
}
func (c *requestContext) ResolveAnyExpressionWithWrappedError(expression *vm.Program) (any, error) {
// If an error exists already, wrap it and resolve the expression with the copied context
if c.expressionContext.Request.Error != nil {
// This will create a copy of the base expressionContext which we can modify
copyContext := c.expressionContext
copyContext.Request.Error = &ExprWrapError{c.expressionContext.Request.Error}
return expr.ResolveAnyExpression(expression, copyContext)
}
return expr.ResolveAnyExpression(expression, c.expressionContext)
}
func (c *requestContext) ResolveStringExpression(expression *vm.Program) (string, error) {
return expr.ResolveStringExpression(expression, c.expressionContext)
}
func (c *requestContext) ResolveBoolExpression(expression *vm.Program) (bool, error) {
return expr.ResolveBoolExpression(expression, c.expressionContext)
}
func (c *requestContext) Operation() OperationContext {
return c.operation
}
func (c *requestContext) Request() *http.Request {
return c.request
}
func withRequestContext(ctx context.Context, operation *requestContext) context.Context {
return context.WithValue(ctx, requestContextKey, operation)
}
func getRequestContext(ctx context.Context) *requestContext {
if ctx == nil {
return nil
}
op := ctx.Value(requestContextKey)
if op == nil {
return nil
}
return op.(*requestContext)
}
func (c *requestContext) ResponseWriter() http.ResponseWriter {
return c.responseWriter
}
func (c *requestContext) Logger() *zap.Logger {
return c.logger
}
// Set is used to store a new key/value pair exclusively for this context.
// It also lazy initializes c.keys if it was not used previously.
func (c *requestContext) Set(key string, value any) {
c.mu.Lock()
defer c.mu.Unlock()
if c.keys == nil {
c.keys = make(map[string]any)
}
c.keys[key] = value
}
// Get returns the value for the given key, ie: (value, true).
// If the value does not exist it returns (nil, false)
func (c *requestContext) Get(key string) (value any, exists bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, exists = c.keys[key]
return
}
// MustGet returns the value for the given key if it exists, otherwise it panics.
func (c *requestContext) MustGet(key string) any {
if value, exists := c.Get(key); exists {
return value
}
panic("Key \"" + key + "\" does not exist")
}
// GetString returns the value associated with the key as a string.
func (c *requestContext) GetString(key string) (s string) {
if val, ok := c.Get(key); ok && val != nil {
s, _ = val.(string)
}
return
}
// GetBool returns the value associated with the key as a boolean.
func (c *requestContext) GetBool(key string) (b bool) {
if val, ok := c.Get(key); ok && val != nil {
b, _ = val.(bool)
}
return
}
// GetInt returns the value associated with the key as an integer.
func (c *requestContext) GetInt(key string) (i int) {
if val, ok := c.Get(key); ok && val != nil {
i, _ = val.(int)
}
return
}
// GetInt64 returns the value associated with the key as an integer.
func (c *requestContext) GetInt64(key string) (i64 int64) {
if val, ok := c.Get(key); ok && val != nil {
i64, _ = val.(int64)
}
return
}
// GetUint returns the value associated with the key as an unsigned integer.
func (c *requestContext) GetUint(key string) (ui uint) {
if val, ok := c.Get(key); ok && val != nil {
ui, _ = val.(uint)
}
return
}
// GetUint64 returns the value associated with the key as an unsigned integer.
func (c *requestContext) GetUint64(key string) (ui64 uint64) {
if val, ok := c.Get(key); ok && val != nil {
ui64, _ = val.(uint64)
}
return
}
// GetFloat64 returns the value associated with the key as a float64.
func (c *requestContext) GetFloat64(key string) (f64 float64) {
if val, ok := c.Get(key); ok && val != nil {
f64, _ = val.(float64)
}
return
}
// GetTime returns the value associated with the key as time.
func (c *requestContext) GetTime(key string) (t time.Time) {
if val, ok := c.Get(key); ok && val != nil {
t, _ = val.(time.Time)
}
return
}
// GetDuration returns the value associated with the key as a duration.
func (c *requestContext) GetDuration(key string) (d time.Duration) {
if val, ok := c.Get(key); ok && val != nil {
d, _ = val.(time.Duration)
}
return
}
// GetStringSlice returns the value associated with the key as a slice of strings.
func (c *requestContext) GetStringSlice(key string) (ss []string) {
if val, ok := c.Get(key); ok && val != nil {
ss, _ = val.([]string)
}
return
}
// GetStringMap returns the value associated with the key as a map of interfaces.
func (c *requestContext) GetStringMap(key string) (sm map[string]any) {
if val, ok := c.Get(key); ok && val != nil {
sm, _ = val.(map[string]any)
}
return
}
// GetStringMapString returns the value associated with the key as a map of strings.
func (c *requestContext) GetStringMapString(key string) (sms map[string]string) {
if val, ok := c.Get(key); ok && val != nil {
sms, _ = val.(map[string]string)
}
return
}
// GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings.
func (c *requestContext) GetStringMapStringSlice(key string) (smss map[string][]string) {
if val, ok := c.Get(key); ok && val != nil {
smss, _ = val.(map[string][]string)
}
return
}
func (c *requestContext) ActiveSubgraph(subgraphRequest *http.Request) *Subgraph {
if subgraphRequest == nil || subgraphRequest.URL == nil {
return nil
}
return c.subgraphResolver.BySubgraphURL(subgraphRequest.URL.String())
}
func (c *requestContext) SubgraphByID(subgraphID string) *Subgraph {
return c.subgraphResolver.ByID(subgraphID)
}
func (c *requestContext) Authentication() authentication.Authentication {
return authentication.FromContext(c.request.Context())
}
func (c *requestContext) SetAuthenticationScopes(scopes []string) {
auth := authentication.FromContext(c.request.Context())
if auth == nil {
auth = authentication.NewEmptyAuthentication()
c.request = c.request.WithContext(authentication.NewContext(c.request.Context(), auth))
}
auth.SetScopes(scopes)
}
type OperationContext interface {
// Name is the name of the operation
Name() string
// Type is the type of the operation (query, mutation, subscription)
Type() string
// Hash is the hash of the operation
Hash() uint64
// Content is the content of the operation
Content() string
// ClientInfo returns information about the client that initiated this operation
ClientInfo() ClientInfo
}
var _ OperationContext = (*operationContext)(nil)
type OperationType = string
const (
OperationTypeQuery OperationType = "query"
OperationTypeMutation OperationType = "mutation"
OperationTypeSubscription OperationType = "subscription"
)
// operationContext contains information about the current GraphQL operation
type operationContext struct {
// Name is the name of the operation
name string
// opType is the type of the operation (query, mutation, subscription)
opType OperationType
// hash is the hash of the operation with the normalized content and variables. Used for analytics.
hash uint64
// internalHash is the hash of the operation with the fully normalized content. Used for engine / executor caching.
// we can't use the hash for this due to engine limitations in handling variables with the normalized representation
internalHash uint64
// remapVariables is a map of variables that have been remapped to the new names
remapVariables map[string]string
// RawContent is the raw content of the operation
rawContent string
// Content is the normalized content of the operation
content string
variables *astjson.Value
files []*httpclient.FileUpload
clientInfo *ClientInfo
// preparedPlan is the prepared plan of the operation
preparedPlan *planWithMetaData
traceOptions resolve.TraceOptions
executionOptions resolve.ExecutionOptions
planCacheHit bool
initialPayload []byte
extensions []byte
persistedID string
// Hash on the original operation
sha256Hash string
protocol OperationProtocol
persistedOperationCacheHit bool
normalizationCacheHit bool
typeFieldUsageInfo []*graphqlmetrics.TypeFieldUsageInfo
argumentUsageInfo []*graphqlmetrics.ArgumentUsageInfo
inputUsageInfo []*graphqlmetrics.InputUsageInfo
parsingTime time.Duration
validationTime time.Duration
planningTime time.Duration
normalizationTime time.Duration
}
func (o *operationContext) Variables() *astjson.Value {
return o.variables
}
func (o *operationContext) Files() []*httpclient.FileUpload {
return o.files
}
func (o *operationContext) Name() string {
return o.name
}
func (o *operationContext) Type() string {
return o.opType
}
func (o *operationContext) Hash() uint64 {
return o.hash
}
func (o *operationContext) Content() string {
return o.content
}
func (o *operationContext) PersistedID() string {
return o.persistedID
}
func (o *operationContext) Protocol() OperationProtocol {
return o.protocol
}
func (o *operationContext) ClientInfo() ClientInfo {
return *o.clientInfo
}
// isMutationRequest returns true if the current request is a mutation request
func isMutationRequest(ctx context.Context) bool {
op := getRequestContext(ctx)
if op == nil {
return false
}
return op.Operation().Type() == "mutation"
}
type SubgraphResolver struct {
subgraphsByURL map[string]*Subgraph
subgraphsByID map[string]*Subgraph
}
func NewSubgraphResolver(subgraphs []Subgraph) *SubgraphResolver {
resolver := &SubgraphResolver{
subgraphsByURL: make(map[string]*Subgraph, len(subgraphs)),
subgraphsByID: make(map[string]*Subgraph, len(subgraphs)),
}
for i := range subgraphs {
sg := Subgraph{
Id: subgraphs[i].Id,
Name: subgraphs[i].Name,
Url: subgraphs[i].Url,
UrlString: subgraphs[i].UrlString,
}
if sg.UrlString != "" {
resolver.subgraphsByURL[sg.UrlString] = &sg
}
if sg.Id != "" {
resolver.subgraphsByID[sg.Id] = &sg
}
}
return resolver
}
func (s *SubgraphResolver) ByID(subgraphID string) *Subgraph {
return s.subgraphsByID[subgraphID]
}
func (s *SubgraphResolver) BySubgraphURL(u string) *Subgraph {
if len(u) >= 4 && u[:4] == "http" {
return s.subgraphsByURL[u]
}
if strings.HasPrefix(u, "wss") {
u = "https" + u[3:]
} else if strings.HasPrefix(u, "ws") {
u = "http" + u[2:]
}
return s.subgraphsByURL[u]
}
func withSubgraphResolver(ctx context.Context, resolver *SubgraphResolver) context.Context {
return context.WithValue(ctx, subgraphResolverContextKey, resolver)
}
func subgraphResolverFromContext(ctx context.Context) *SubgraphResolver {
resolver, _ := ctx.Value(subgraphResolverContextKey).(*SubgraphResolver)
return resolver
}
type requestContextOptions struct {
operationContext *operationContext
requestLogger *zap.Logger
metricSetAttributes map[string]string
metricsEnabled bool
traceEnabled bool
mapper *attributeMapper
metricAttributeExpressions *attributeExpressions
telemetryAttributeExpressions *attributeExpressions
w http.ResponseWriter
r *http.Request
}
func buildRequestContext(opts requestContextOptions) *requestContext {
rootCtx := expr.Context{
Request: expr.LoadRequest(opts.r),
}
return &requestContext{
logger: opts.requestLogger,
keys: map[string]any{},
responseWriter: opts.w,
request: opts.r,
operation: opts.operationContext,
telemetry: &requestTelemetryAttributes{
metricSetAttrs: opts.metricSetAttributes,
metricsEnabled: opts.metricsEnabled,
traceEnabled: opts.traceEnabled,
mapper: opts.mapper,
traceAttributeExpressions: opts.telemetryAttributeExpressions,
metricAttributeExpressions: opts.metricAttributeExpressions,
},
expressionContext: rootCtx,
subgraphResolver: subgraphResolverFromContext(opts.r.Context()),
}
}