router/core/factoryresolver.go (511 lines of code) (raw):

// Package core contains the router's engine wiring. This file translates the
// serialized engine configuration into a planner configuration and resolves
// the data source factories used by it.
package core

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"slices"

	"github.com/buger/jsonparser"
	"github.com/jensneuse/abstractlogger"
	"github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/common"
	nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1"
	"github.com/wundergraph/cosmo/router/pkg/config"
	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/argument_templates"
	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource"
	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource"
	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/staticdatasource"
	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan"
	"go.uber.org/zap"
)

// Loader converts a nodev1.EngineConfiguration into a plan.Configuration,
// obtaining the planner factories for each data source kind from its resolver.
type Loader struct {
	resolver FactoryResolver
	// includeInfo controls whether additional information like type usage and field usage is included in the plan
	includeInfo bool
}

// FactoryResolver provides the planner factory for each supported data source
// kind (GraphQL, static and pubsub).
type FactoryResolver interface {
	ResolveGraphqlFactory(subgraphName string) (plan.PlannerFactory[graphql_datasource.Configuration], error)
	ResolveStaticFactory() (plan.PlannerFactory[staticdatasource.Configuration], error)
	ResolvePubsubFactory() (plan.PlannerFactory[pubsub_datasource.Configuration], error)
}

// ApiTransportFactory wraps a base http.RoundTripper and exposes the default
// HTTP proxy URL.
type ApiTransportFactory interface {
	RoundTripper(enableSingleFlight bool, transport http.RoundTripper) http.RoundTripper
	DefaultHTTPProxyURL() *url.URL
}

// DefaultFactoryResolver is the FactoryResolver built by
// NewDefaultFactoryResolver. It holds the shared HTTP clients, the
// subscription client and the pre-built static and pubsub factories.
type DefaultFactoryResolver struct {
	baseTransport    http.RoundTripper
	transportFactory ApiTransportFactory
	transportOptions *TransportOptions

	static *staticdatasource.Factory[staticdatasource.Configuration]
	pubsub *pubsub_datasource.Factory[pubsub_datasource.Configuration]
	log    *zap.Logger

	engineCtx          context.Context
	httpClient         *http.Client
	enableSingleFlight bool
	streamingClient    *http.Client
	subscriptionClient graphql_datasource.GraphQLSubscriptionClient

	factoryLogger abstractlogger.Logger
}

// NewDefaultFactoryResolver builds a DefaultFactoryResolver.
//
// ctx is handed to the subscription client and the pubsub factory and is
// stored for the GraphQL factories created later. transportOptions must be
// non-nil: its SubgraphTransportOptions.RequestTimeout becomes the timeout of
// the default HTTP client. log may be nil, in which case no factory logger is
// created. enableNetPoll toggles the subscription client's net poll
// configuration after its defaults have been applied.
func NewDefaultFactoryResolver(
	ctx context.Context,
	transportOptions *TransportOptions,
	baseTransport http.RoundTripper,
	log *zap.Logger,
	enableSingleFlight bool,
	enableNetPoll bool,
	natsPubSubBySourceID map[string]pubsub_datasource.NatsPubSub,
	kafkaPubSubBySourceID map[string]pubsub_datasource.KafkaPubSub,
) *DefaultFactoryResolver {
	transportFactory := NewTransport(transportOptions)

	defaultHttpClient := &http.Client{
		Timeout:   transportOptions.SubgraphTransportOptions.RequestTimeout,
		Transport: transportFactory.RoundTripper(enableSingleFlight, baseTransport),
	}
	// The streaming client deliberately carries no Timeout field; presumably
	// because streamed responses are long-lived.
	streamingClient := &http.Client{
		Transport: transportFactory.RoundTripper(enableSingleFlight, baseTransport),
	}

	var factoryLogger abstractlogger.Logger
	if log != nil {
		factoryLogger = abstractlogger.NewZapLogger(log, abstractlogger.DebugLevel)
	}

	var netPollConfig graphql_datasource.NetPollConfiguration
	netPollConfig.ApplyDefaults()
	netPollConfig.Enable = enableNetPoll

	subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient(
		defaultHttpClient,
		streamingClient,
		ctx,
		graphql_datasource.WithLogger(factoryLogger),
		graphql_datasource.WithNetPollConfiguration(netPollConfig),
	)

	return &DefaultFactoryResolver{
		baseTransport:      baseTransport,
		transportFactory:   transportFactory,
		transportOptions:   transportOptions,
		static:             &staticdatasource.Factory[staticdatasource.Configuration]{},
		pubsub:             pubsub_datasource.NewFactory(ctx, natsPubSubBySourceID, kafkaPubSubBySourceID),
		log:                log,
		factoryLogger:      factoryLogger,
		engineCtx:          ctx,
		httpClient:         defaultHttpClient,
		enableSingleFlight: enableSingleFlight,
		streamingClient:    streamingClient,
		subscriptionClient: subscriptionClient,
	}
}

// ResolveGraphqlFactory returns a GraphQL planner factory for the given
// subgraph. When the subgraph has an entry in the per-subgraph transport
// options, a dedicated HTTP client with that entry's request timeout is
// created; otherwise the shared default client is used.
func (d *DefaultFactoryResolver) ResolveGraphqlFactory(subgraphName string) (plan.PlannerFactory[graphql_datasource.Configuration], error) {
	if subgraphName != "" {
		if subgraphOptions := d.transportOptions.SubgraphTransportOptions.SubgraphMap[subgraphName]; subgraphOptions != nil {
			// make a new http client
			subgraphClient := &http.Client{
				Transport: d.transportFactory.RoundTripper(d.enableSingleFlight, d.baseTransport),
				Timeout:   subgraphOptions.RequestTimeout,
			}
			factory, err := graphql_datasource.NewFactory(d.engineCtx, subgraphClient, d.subscriptionClient)
			return factory, err
		}
	}

	factory, err := graphql_datasource.NewFactory(d.engineCtx, d.httpClient, d.subscriptionClient)
	return factory, err
}

// ResolveStaticFactory returns the shared static data source factory.
func (d *DefaultFactoryResolver) ResolveStaticFactory() (plan.PlannerFactory[staticdatasource.Configuration], error) {
	return d.static, nil
}

// ResolvePubsubFactory returns the shared pubsub data source factory.
func (d *DefaultFactoryResolver) ResolvePubsubFactory() (plan.PlannerFactory[pubsub_datasource.Configuration], error) {
	return d.pubsub, nil
}

// NewLoader creates a Loader that uses resolver to obtain planner factories.
func NewLoader(includeInfo bool, resolver FactoryResolver) *Loader {
	return &Loader{
		resolver:    resolver,
		includeInfo: includeInfo,
	}
}

// LoadInternedString looks up the string referenced by str in the engine
// configuration's string storage. It returns an error when the key is absent.
func (l *Loader) LoadInternedString(engineConfig *nodev1.EngineConfiguration, str *nodev1.InternedString) (string, error) {
	key := str.GetKey()
	s, ok := engineConfig.StringStorage[key]
	if !ok {
		return "", fmt.Errorf("no string found for key %q", key)
	}
	return s, nil
}

// RouterEngineConfiguration bundles the router-level settings that accompany
// the engine configuration. Load only reads Headers from it.
type RouterEngineConfiguration struct {
	Execution                config.EngineExecutionConfiguration
	Headers                  *config.HeaderRules
	Events                   config.EventsConfiguration
	SubgraphErrorPropagation config.SubgraphErrorPropagationConfiguration
}

// mapProtoFilterToPlanFilter recursively copies a proto subscription filter
// into output and returns output.
//
// Exactly one of And, In, Not, Or is honored, checked in that order. It
// returns nil when input is nil, when none of the four is set, or when the
// JSON array of an In condition cannot be parsed.
func mapProtoFilterToPlanFilter(input *nodev1.SubscriptionFilterCondition, output *plan.SubscriptionFilterCondition) *plan.SubscriptionFilterCondition {
	if input == nil {
		return nil
	}

	if input.And != nil {
		output.And = make([]plan.SubscriptionFilterCondition, len(input.And))
		for i := range input.And {
			mapProtoFilterToPlanFilter(input.And[i], &output.And[i])
		}
		return output
	}

	if input.In != nil {
		var values []string
		_, err := jsonparser.ArrayEach([]byte(input.In.Json), func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
			// if the value is not a string, just append it as is because this is the JSON
			// representation of the value. If it contains a template, we want to keep it as
			// is to explode it later with the actual values
			if dataType != jsonparser.String || argument_templates.ContainsArgumentTemplateString(value) {
				values = append(values, string(value))
				return
			}
			// stringify values to prevent its actual type from being lost
			// during the transport to the engine as bytes
			// NOTE(review): value is marshaled as received; if jsonparser hands back
			// string contents with escape sequences still encoded, they would be
			// escaped a second time here — confirm.
			marshaledValue, mErr := json.Marshal(string(value))
			if mErr != nil {
				// Best effort: a value that cannot be marshaled is skipped.
				return
			}
			values = append(values, string(marshaledValue))
		})
		if err != nil {
			return nil
		}
		output.In = &plan.SubscriptionFieldCondition{
			FieldPath: input.In.FieldPath,
			Values:    values,
		}
		return output
	}

	if input.Not != nil {
		output.Not = mapProtoFilterToPlanFilter(input.Not, &plan.SubscriptionFilterCondition{})
		return output
	}

	if input.Or != nil {
		// make already yields zero-valued elements, so they can be filled in place.
		output.Or = make([]plan.SubscriptionFilterCondition, len(input.Or))
		for i := range input.Or {
			mapProtoFilterToPlanFilter(input.Or[i], &output.Or[i])
		}
		return output
	}

	return nil
}

// Load builds the planner configuration from engineConfig.
//
// It copies field and type configurations and then creates one plan.DataSource
// per entry in engineConfig.DatasourceConfigurations, using the factory that
// matches the entry's kind. subgraphs is used to resolve data source names and
// header rules; routerEngineConfig must be non-nil because its Headers are
// read for every GraphQL data source. An error is returned for an unknown data
// source kind or when any data source cannot be configured.
func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nodev1.Subgraph, routerEngineConfig *RouterEngineConfiguration) (*plan.Configuration, error) {
	var outConfig plan.Configuration

	// carry over the default flush interval from the engine configuration
	outConfig.DefaultFlushIntervalMillis = engineConfig.DefaultFlushInterval

	for _, configuration := range engineConfig.FieldConfigurations {
		var args []plan.ArgumentConfiguration
		for _, argumentConfiguration := range configuration.ArgumentsConfiguration {
			arg := plan.ArgumentConfiguration{
				Name: argumentConfiguration.Name,
			}
			// Any other source type leaves arg.SourceType at its zero value.
			switch argumentConfiguration.SourceType {
			case nodev1.ArgumentSource_FIELD_ARGUMENT:
				arg.SourceType = plan.FieldArgumentSource
			case nodev1.ArgumentSource_OBJECT_FIELD:
				arg.SourceType = plan.ObjectFieldSource
			}
			args = append(args, arg)
		}

		fieldConfig := plan.FieldConfiguration{
			TypeName:                    configuration.TypeName,
			FieldName:                   configuration.FieldName,
			Arguments:                   args,
			HasAuthorizationRule:        l.fieldHasAuthorizationRule(configuration),
			SubscriptionFilterCondition: mapProtoFilterToPlanFilter(configuration.SubscriptionFilterCondition, &plan.SubscriptionFilterCondition{}),
		}
		outConfig.Fields = append(outConfig.Fields, fieldConfig)
	}

	for _, configuration := range engineConfig.TypeConfigurations {
		outConfig.Types = append(outConfig.Types, plan.TypeConfiguration{
			TypeName: configuration.TypeName,
			RenameTo: configuration.RenameTo,
		})
	}

	for _, in := range engineConfig.DatasourceConfigurations {
		var out plan.DataSource

		switch in.Kind {
		case nodev1.DataSourceKind_STATIC:
			factory, err := l.resolver.ResolveStaticFactory()
			if err != nil {
				return nil, err
			}

			out, err = plan.NewDataSourceConfiguration[staticdatasource.Configuration](
				in.Id,
				factory,
				l.dataSourceMetaData(in),
				staticdatasource.Configuration{
					Data: config.LoadStringVariable(in.CustomStatic.Data),
				},
			)
			if err != nil {
				return nil, fmt.Errorf("error creating data source configuration for data source %s: %w", in.Id, err)
			}

		case nodev1.DataSourceKind_GRAPHQL:
			header := http.Header{}
			for s, httpHeader := range in.CustomGraphql.Fetch.Header {
				for _, value := range httpHeader.Values {
					header.Add(s, config.LoadStringVariable(value))
				}
			}

			fetchUrl := config.LoadStringVariable(in.CustomGraphql.Fetch.GetUrl())

			// Subscriptions fall back to the fetch URL when no dedicated URL is set.
			subscriptionUrl := config.LoadStringVariable(in.CustomGraphql.Subscription.Url)
			if subscriptionUrl == "" {
				subscriptionUrl = fetchUrl
			}

			customScalarTypeFields := make([]graphql_datasource.SingleTypeField, len(in.CustomGraphql.CustomScalarTypeFields))
			for i, v := range in.CustomGraphql.CustomScalarTypeFields {
				customScalarTypeFields[i] = graphql_datasource.SingleTypeField{
					TypeName:  v.TypeName,
					FieldName: v.FieldName,
				}
			}

			graphqlSchema, err := l.LoadInternedString(engineConfig, in.CustomGraphql.GetUpstreamSchema())
			if err != nil {
				return nil, fmt.Errorf("could not load GraphQL schema for data source %s: %w", in.Id, err)
			}

			var subscriptionUseSSE bool
			var subscriptionSSEMethodPost bool
			if in.CustomGraphql.Subscription.Protocol != nil {
				switch *in.CustomGraphql.Subscription.Protocol {
				case common.GraphQLSubscriptionProtocol_GRAPHQL_SUBSCRIPTION_PROTOCOL_WS:
					subscriptionUseSSE = false
					subscriptionSSEMethodPost = false
				case common.GraphQLSubscriptionProtocol_GRAPHQL_SUBSCRIPTION_PROTOCOL_SSE:
					subscriptionUseSSE = true
					subscriptionSSEMethodPost = false
				case common.GraphQLSubscriptionProtocol_GRAPHQL_SUBSCRIPTION_PROTOCOL_SSE_POST:
					subscriptionUseSSE = true
					subscriptionSSEMethodPost = true
				}
			} else {
				// Old style config
				if in.CustomGraphql.Subscription.UseSSE != nil {
					subscriptionUseSSE = *in.CustomGraphql.Subscription.UseSSE
				}
			}

			wsSubprotocol := "auto"
			if in.CustomGraphql.Subscription.WebsocketSubprotocol != nil {
				switch *in.CustomGraphql.Subscription.WebsocketSubprotocol {
				case common.GraphQLWebsocketSubprotocol_GRAPHQL_WEBSOCKET_SUBPROTOCOL_WS:
					wsSubprotocol = "graphql-ws"
				case common.GraphQLWebsocketSubprotocol_GRAPHQL_WEBSOCKET_SUBPROTOCOL_TRANSPORT_WS:
					wsSubprotocol = "graphql-transport-ws"
				case common.GraphQLWebsocketSubprotocol_GRAPHQL_WEBSOCKET_SUBPROTOCOL_AUTO:
					wsSubprotocol = "auto"
				}
			}

			// Header rules are looked up by the subscription URL.
			dataSourceRules := FetchURLRules(routerEngineConfig.Headers, subgraphs, subscriptionUrl)
			forwardedClientHeaders, forwardedClientRegexps, err := PropagatedHeaders(dataSourceRules)
			if err != nil {
				return nil, fmt.Errorf("error parsing header rules for data source %s: %w", in.Id, err)
			}

			schemaConfiguration, err := graphql_datasource.NewSchemaConfiguration(
				graphqlSchema,
				&graphql_datasource.FederationConfiguration{
					Enabled:    in.CustomGraphql.Federation.Enabled,
					ServiceSDL: in.CustomGraphql.Federation.ServiceSdl,
				},
			)
			if err != nil {
				return nil, fmt.Errorf("error creating schema configuration for data source %s: %w", in.Id, err)
			}

			customConfiguration, err := graphql_datasource.NewConfiguration(graphql_datasource.ConfigurationInput{
				Fetch: &graphql_datasource.FetchConfiguration{
					URL:    fetchUrl,
					Method: in.CustomGraphql.Fetch.Method.String(),
					Header: header,
				},
				Subscription: &graphql_datasource.SubscriptionConfiguration{
					URL:                                     subscriptionUrl,
					UseSSE:                                  subscriptionUseSSE,
					SSEMethodPost:                           subscriptionSSEMethodPost,
					ForwardedClientHeaderNames:              forwardedClientHeaders,
					ForwardedClientHeaderRegularExpressions: forwardedClientRegexps,
					WsSubProtocol:                           wsSubprotocol,
				},
				SchemaConfiguration:    schemaConfiguration,
				CustomScalarTypeFields: customScalarTypeFields,
			})
			if err != nil {
				return nil, fmt.Errorf("error creating custom configuration for data source %s: %w", in.Id, err)
			}

			dataSourceName := l.subgraphName(subgraphs, in.Id)

			factory, err := l.resolver.ResolveGraphqlFactory(dataSourceName)
			if err != nil {
				return nil, err
			}

			out, err = plan.NewDataSourceConfigurationWithName[graphql_datasource.Configuration](
				in.Id,
				dataSourceName,
				factory,
				l.dataSourceMetaData(in),
				customConfiguration,
			)
			if err != nil {
				return nil, fmt.Errorf("error creating data source configuration for data source %s: %w", in.Id, err)
			}

		case nodev1.DataSourceKind_PUBSUB:
			var eventConfigurations []pubsub_datasource.EventConfiguration

			for _, eventConfiguration := range in.GetCustomEvents().GetNats() {
				engineEvent := eventConfiguration.EngineEventConfiguration
				eventType, err := pubsub_datasource.EventTypeFromString(engineEvent.Type.String())
				if err != nil {
					return nil, fmt.Errorf("invalid event type %q for data source %q: %w", engineEvent.Type.String(), in.Id, err)
				}

				// The stream configuration is optional and stays nil when absent.
				var streamConfiguration *pubsub_datasource.NatsStreamConfiguration
				if eventConfiguration.StreamConfiguration != nil {
					streamConfiguration = &pubsub_datasource.NatsStreamConfiguration{
						Consumer:                  eventConfiguration.StreamConfiguration.GetConsumerName(),
						StreamName:                eventConfiguration.StreamConfiguration.GetStreamName(),
						ConsumerInactiveThreshold: eventConfiguration.StreamConfiguration.GetConsumerInactiveThreshold(),
					}
				}

				eventConfigurations = append(eventConfigurations, pubsub_datasource.EventConfiguration{
					Metadata: &pubsub_datasource.EventMetadata{
						ProviderID: engineEvent.GetProviderId(),
						Type:       eventType,
						TypeName:   engineEvent.GetTypeName(),
						FieldName:  engineEvent.GetFieldName(),
					},
					Configuration: &pubsub_datasource.NatsEventConfiguration{
						StreamConfiguration: streamConfiguration,
						Subjects:            eventConfiguration.GetSubjects(),
					},
				})
			}

			for _, eventConfiguration := range in.GetCustomEvents().GetKafka() {
				engineEvent := eventConfiguration.EngineEventConfiguration
				eventType, err := pubsub_datasource.EventTypeFromString(engineEvent.Type.String())
				if err != nil {
					return nil, fmt.Errorf("invalid event type %q for data source %q: %w", engineEvent.Type.String(), in.Id, err)
				}

				eventConfigurations = append(eventConfigurations, pubsub_datasource.EventConfiguration{
					Metadata: &pubsub_datasource.EventMetadata{
						ProviderID: engineEvent.GetProviderId(),
						Type:       eventType,
						TypeName:   engineEvent.GetTypeName(),
						FieldName:  engineEvent.GetFieldName(),
					},
					Configuration: &pubsub_datasource.KafkaEventConfiguration{
						Topics: eventConfiguration.GetTopics(),
					},
				})
			}

			factory, err := l.resolver.ResolvePubsubFactory()
			if err != nil {
				return nil, err
			}

			out, err = plan.NewDataSourceConfiguration[pubsub_datasource.Configuration](
				in.Id,
				factory,
				l.dataSourceMetaData(in),
				pubsub_datasource.Configuration{
					Events: eventConfigurations,
				},
			)
			if err != nil {
				return nil, fmt.Errorf("error creating data source configuration for data source %s: %w", in.Id, err)
			}

		default:
			return nil, fmt.Errorf("unknown data source type %q", in.Kind)
		}

		outConfig.DataSources = append(outConfig.DataSources, out)
	}

	return &outConfig, nil
}

// subgraphName returns the name of the subgraph whose Id equals dataSourceID,
// or the empty string when no such subgraph exists.
func (l *Loader) subgraphName(subgraphs []*nodev1.Subgraph, dataSourceID string) string {
	i := slices.IndexFunc(subgraphs, func(s *nodev1.Subgraph) bool {
		return s.Id == dataSourceID
	})
	if i != -1 {
		return subgraphs[i].Name
	}
	return ""
}

// dataSourceMetaData converts the root/child nodes, directives and federation
// settings (keys, provides, requires, entity interfaces, interface objects) of
// a data source configuration into planner metadata.
func (l *Loader) dataSourceMetaData(in *nodev1.DataSourceConfiguration) *plan.DataSourceMetadata {
	var d plan.DirectiveConfigurations = make([]plan.DirectiveConfiguration, 0, len(in.Directives))

	out := &plan.DataSourceMetadata{
		RootNodes:  make([]plan.TypeField, 0, len(in.RootNodes)),
		ChildNodes: make([]plan.TypeField, 0, len(in.ChildNodes)),
		Directives: &d,
		FederationMetaData: plan.FederationMetaData{
			Keys:     make([]plan.FederationFieldConfiguration, 0, len(in.Keys)),
			Requires: make([]plan.FederationFieldConfiguration, 0, len(in.Requires)),
			Provides: make([]plan.FederationFieldConfiguration, 0, len(in.Provides)),
		},
	}

	for _, node := range in.RootNodes {
		out.RootNodes = append(out.RootNodes, plan.TypeField{
			TypeName:           node.TypeName,
			FieldNames:         node.FieldNames,
			ExternalFieldNames: node.ExternalFieldNames,
		})
	}

	for _, node := range in.ChildNodes {
		out.ChildNodes = append(out.ChildNodes, plan.TypeField{
			TypeName:           node.TypeName,
			FieldNames:         node.FieldNames,
			ExternalFieldNames: node.ExternalFieldNames,
		})
	}

	for _, directive := range in.Directives {
		// RenameTo is set to the directive's own name, so directives keep their name.
		// NOTE(review): confirm this is intended rather than a separate rename target.
		*out.Directives = append(*out.Directives, plan.DirectiveConfiguration{
			DirectiveName: directive.DirectiveName,
			RenameTo:      directive.DirectiveName,
		})
	}

	for _, keyConfiguration := range in.Keys {
		// conditions stays nil when the key has no conditions.
		var conditions []plan.KeyCondition
		if len(keyConfiguration.Conditions) > 0 {
			conditions = make([]plan.KeyCondition, 0, len(keyConfiguration.Conditions))
			for _, condition := range keyConfiguration.Conditions {
				coordinates := make([]plan.KeyConditionCoordinate, 0, len(condition.FieldCoordinatesPath))
				for _, coordinate := range condition.FieldCoordinatesPath {
					coordinates = append(coordinates, plan.KeyConditionCoordinate{
						TypeName:  coordinate.TypeName,
						FieldName: coordinate.FieldName,
					})
				}
				conditions = append(conditions, plan.KeyCondition{
					Coordinates: coordinates,
					FieldPath:   condition.FieldPath,
				})
			}
		}

		out.FederationMetaData.Keys = append(out.FederationMetaData.Keys, plan.FederationFieldConfiguration{
			TypeName:              keyConfiguration.TypeName,
			FieldName:             keyConfiguration.FieldName,
			SelectionSet:          keyConfiguration.SelectionSet,
			DisableEntityResolver: keyConfiguration.DisableEntityResolver,
			Conditions:            conditions,
		})
	}

	for _, providesConfiguration := range in.Provides {
		out.FederationMetaData.Provides = append(out.FederationMetaData.Provides, plan.FederationFieldConfiguration{
			TypeName:     providesConfiguration.TypeName,
			FieldName:    providesConfiguration.FieldName,
			SelectionSet: providesConfiguration.SelectionSet,
		})
	}

	for _, requiresConfiguration := range in.Requires {
		out.FederationMetaData.Requires = append(out.FederationMetaData.Requires, plan.FederationFieldConfiguration{
			TypeName:     requiresConfiguration.TypeName,
			FieldName:    requiresConfiguration.FieldName,
			SelectionSet: requiresConfiguration.SelectionSet,
		})
	}

	for _, entityInterfacesConfiguration := range in.EntityInterfaces {
		out.FederationMetaData.EntityInterfaces = append(out.FederationMetaData.EntityInterfaces, plan.EntityInterfaceConfiguration{
			InterfaceTypeName: entityInterfacesConfiguration.InterfaceTypeName,
			ConcreteTypeNames: entityInterfacesConfiguration.ConcreteTypeNames,
		})
	}

	for _, interfaceObjectConfiguration := range in.InterfaceObjects {
		out.FederationMetaData.InterfaceObjects = append(out.FederationMetaData.InterfaceObjects, plan.EntityInterfaceConfiguration{
			InterfaceTypeName: interfaceObjectConfiguration.InterfaceTypeName,
			ConcreteTypeNames: interfaceObjectConfiguration.ConcreteTypeNames,
		})
	}

	return out
}

// fieldHasAuthorizationRule reports whether the field requires authentication
// or lists at least one required scope. A nil field configuration or a nil
// authorization configuration yields false.
func (l *Loader) fieldHasAuthorizationRule(fieldConfiguration *nodev1.FieldConfiguration) bool {
	if fieldConfiguration == nil || fieldConfiguration.AuthorizationConfiguration == nil {
		return false
	}
	authorization := fieldConfiguration.AuthorizationConfiguration
	return authorization.RequiresAuthentication || len(authorization.RequiredOrScopes) > 0
}