plugin/federation/federation.go (463 lines of code) (raw):
package federation
import (
_ "embed"
"errors"
"fmt"
"sort"
"strings"
"github.com/vektah/gqlparser/v2/ast"
"github.com/99designs/gqlgen/codegen"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/99designs/gqlgen/internal/rewrite"
"github.com/99designs/gqlgen/plugin/federation/fieldset"
)
// federationTemplate holds the contents of federation.gotpl, embedded at
// build time. GenerateCode passes it to templates.Render.
//
//go:embed federation.gotpl
var federationTemplate string
// explicitRequiresTemplate holds the contents of requires.gotpl, embedded at
// build time. generateExplicitRequires passes it to templates.Render.
//
//go:embed requires.gotpl
var explicitRequiresTemplate string
// Federation is the plugin state shared by the schema-injection,
// config-mutation and code-generation steps in this file.
type Federation struct {
// Entities is (re)built by InjectSourcesLate from the schema types that
// carry the @key directive, sorted by name.
Entities []*Entity
// PackageOptions holds the validated federation options; it is set by New
// and refreshed from the config in GenerateCode.
PackageOptions PackageOptions
// version is the federation version; New maps a value of 0 to 1.
version int
// true if @requires is used in the schema
usesRequires bool
}
// PackageOptions are the boolean switches read from the federation options
// of the config by buildPackageOptions.
type PackageOptions struct {
// ExplicitRequires will generate a function in the execution context
// to populate fields using the @requires directive into the entity.
//
// You can only set one of ExplicitRequires or ComputedRequires to true.
ExplicitRequires bool
// ComputedRequires generates resolver functions to compute values for
// fields using the @requires directive.
ComputedRequires bool
}
// New builds the federation plugin, which injects the federated
// directives and types into the schema.
//
// A version of 0 is treated as federation version 1. An error is returned
// when the federation package options found in cfg are invalid.
func New(version int, cfg *config.Config) (*Federation, error) {
	opts, err := buildPackageOptions(cfg)
	if err != nil {
		return nil, fmt.Errorf("invalid federation package options: %w", err)
	}

	plugin := &Federation{
		PackageOptions: opts,
		version:        version,
	}
	if plugin.version == 0 {
		plugin.version = 1
	}
	return plugin, nil
}
// buildPackageOptions reads the explicit_requires and computed_requires
// switches from cfg.Federation.Options and validates them.
//
// It returns an error when both switches are enabled, or when
// computed_requires is enabled without Federation 2 or without
// call_argument_directives_with_null.
func buildPackageOptions(cfg *config.Config) (PackageOptions, error) {
	raw := cfg.Federation.Options
	opts := PackageOptions{
		ExplicitRequires: raw["explicit_requires"],
		ComputedRequires: raw["computed_requires"],
	}

	switch {
	case opts.ExplicitRequires && opts.ComputedRequires:
		return PackageOptions{}, errors.New("only one of explicit_requires or computed_requires can be set to true")
	case !opts.ComputedRequires:
		// Nothing further to validate.
		return opts, nil
	case cfg.Federation.Version != 2:
		return PackageOptions{}, errors.New("when using federation.options.computed_requires you must be using Federation 2")
	case !cfg.CallArgumentDirectivesWithNull:
		// We rely on injecting a null argument with a directive for fields
		// with @requires, so we need to ensure our directive is always called.
		return PackageOptions{}, errors.New("when using federation.options.computed_requires, call_argument_directives_with_null must be set to true")
	}

	return opts, nil
}
// Name returns the name of this plugin.
func (f *Federation) Name() string {
	const pluginName = "federation"
	return pluginName
}
// MutateConfig registers the federation builtin models and directives on cfg.
//
// It fails if any builtin type name is already present in cfg.Models. When
// @requires is used in the schema and ComputedRequires is enabled, it also
// registers the populate-from-representations and entity-reference
// directives, adds the map scalar and rewrites the @requires fields.
func (f *Federation) MutateConfig(cfg *config.Config) error {
	for typeName, entry := range builtins {
		if cfg.Models.Exists(typeName) {
			return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
		}
		cfg.Models[typeName] = entry
	}

	// Every federation directive below is schema-only: no runtime hook.
	skipped := []string{
		"external",
		dirNameRequires,
		"provides",
		dirNameKey,
		"extends",
		dirNameEntityResolver,
	}
	if f.version == 2 {
		// Federation 2 specific directives
		skipped = append(skipped,
			"shareable",
			"link",
			"tag",
			"override",
			"inaccessible",
			"authenticated",
			"requiresScopes",
			"policy",
			"interfaceObject",
			"composeDirective",
		)
	}
	for _, directive := range skipped {
		cfg.Directives[directive] = config.DirectiveConfig{SkipRuntime: true}
	}

	if !f.usesRequires || !f.PackageOptions.ComputedRequires {
		return nil
	}

	populateName := dirPopulateFromRepresentations.Name
	cfg.Schema.Directives[populateName] = dirPopulateFromRepresentations
	cfg.Directives[populateName] = config.DirectiveConfig{Implementation: &populateFromRepresentationsImplementation}

	referenceName := dirEntityReference.Name
	cfg.Schema.Directives[referenceName] = dirEntityReference
	cfg.Directives[referenceName] = config.DirectiveConfig{SkipRuntime: true}

	f.addMapType(cfg)
	f.mutateSchemaForRequires(cfg.Schema, cfg)

	return nil
}
// InjectSourcesEarly returns a single built-in source holding the
// version-specific federation schema (directives and scalars). For a
// version other than 1 or 2 the source input is empty.
func (f *Federation) InjectSourcesEarly() ([]*ast.Source, error) {
	var sdl string
	switch f.version {
	case 1:
		sdl = federationVersion1Schema
	case 2:
		sdl = federationVersion2Schema
	}

	source := &ast.Source{
		Name:    dirGraphQLQFile,
		Input:   sdl,
		BuiltIn: true,
	}
	return []*ast.Source{source}, nil
}
// InjectSourcesLate rebuilds f.Entities from schema and returns one built-in
// source containing, in this order: the _Entity union (non-interface
// entities only), the input types of multi resolvers, the fake Entity type
// listing every entity resolver, the _Service type, and an extension of the
// query type with _service (plus _entities when any entity exists).
func (f *Federation) InjectSourcesLate(schema *ast.Schema) ([]*ast.Source, error) {
	f.Entities = f.buildEntities(schema, f.version)

	var (
		unionMembers []string
		resolverSDLs []string
		inputSDLs    []string
	)
	for _, entity := range f.Entities {
		switch {
		case entity.Def.Kind != ast.Interface:
			unionMembers = append(unionMembers, entity.Name)
		case len(schema.GetPossibleTypes(entity.Def)) == 0:
			fmt.Println(
				"skipping @key field on interface " + entity.Def.Name + " as no types implement it",
			)
		}

		for _, resolver := range entity.Resolvers {
			fieldSDL, inputSDL := buildResolverSDL(resolver, entity.Multi)
			resolverSDLs = append(resolverSDLs, fieldSDL)
			if inputSDL == "" {
				continue
			}
			inputSDLs = append(inputSDLs, inputSDL)
		}
	}

	var sections []string

	if len(unionMembers) > 0 {
		sections = append(sections, `# a union of all types that use the @key directive
union _Entity = `+strings.Join(unionMembers, " | "))
	}

	// resolvers can be empty if a service defines only "empty
	// extend" types. This should be rare.
	if len(resolverSDLs) > 0 {
		if len(inputSDLs) > 0 {
			sections = append(sections, strings.Join(inputSDLs, "\n\n"))
		}
		sections = append(sections, `# fake type to build resolver interfaces for users to implement
type Entity {
`+strings.Join(resolverSDLs, "\n")+`
}`)
	}

	sections = append(sections, `type _Service {
sdl: String
}`)

	// _service field is required in any case
	queryFields := ` _service: _Service!`
	// Quote from the Apollo Federation subgraph specification:
	// If no types are annotated with the key directive, then the
	// _Entity union and _entities field should be removed from the schema
	if len(f.Entities) > 0 {
		queryFields = ` _entities(representations: [_Any!]!): [_Entity]!
` + queryFields
	}

	sections = append(sections, `extend type `+schema.Query.Name+` {
`+queryFields+`
}`)

	source := &ast.Source{
		Name:    entityGraphQLQFile,
		BuiltIn: true,
		Input:   "\n" + strings.Join(sections, "\n\n") + "\n",
	}
	return []*ast.Source{source}, nil
}
// GenerateCode fills in the Go type information on f.Entities (entity types,
// key-field types, @requires field types and multi-resolver input types)
// from data, optionally renders the explicit-requires file, and finally
// renders the federation template.
//
// It returns an error when the federation package options are invalid, when
// the input object of a multi resolver cannot be found, or when rendering
// fails.
func (f *Federation) GenerateCode(data *codegen.Data) error {
// Imports needed by the explicit-requires file; extended below with the
// package path of every entity that has @requires fields.
requiresImports := make(map[string]bool, 0)
requiresImports["context"] = true
requiresImports["fmt"] = true
// Entities that have at least one usable @requires field, keyed by type name.
requiresEntities := make(map[string]*Entity, 0)
// Save package options on f for template use
packageOptions, err := buildPackageOptions(data.Config)
if err != nil {
return fmt.Errorf("invalid federation package options: %w", err)
}
f.PackageOptions = packageOptions
if len(f.Entities) > 0 {
// Mark the fake "Entity" object as a root object, if it was generated.
if data.Objects.ByName("Entity") != nil {
data.Objects.ByName("Entity").Root = true
}
for _, e := range f.Entities {
obj := data.Objects.ByName(e.Def.Name)
if e.Def.Kind == ast.Interface {
if len(data.Interfaces[e.Def.Name].Implementors) == 0 {
fmt.Println(
"skipping @key field on interface " + e.Def.Name + " as no types implement it",
)
continue
}
// For an interface entity, the first implementor's object is used to
// look up field types.
obj = data.Objects.ByName(data.Interfaces[e.Def.Name].Implementors[0].Name)
}
// fill in types for key fields
for _, r := range e.Resolvers {
populateKeyFieldTypes(r, obj, data.Objects, e.Def.Name)
}
// fill in types for requires fields
//
for _, reqField := range e.Requires {
if len(reqField.Field) == 0 {
fmt.Println("skipping @requires field " + reqField.Name + " in " + e.Def.Name)
continue
}
// keep track of which entities have requires
requiresEntities[e.Def.Name] = e
// make a proper import path: everything before the last "." of the
// type string
typeString := strings.Split(obj.Type.String(), ".")
requiresImports[strings.Join(typeString[:len(typeString)-1], ".")] = true
// Fields containing a "... on" fragment get no type reference.
if containsUnionField(reqField) {
continue
}
cgField := reqField.Field.TypeReference(obj, data.Objects)
reqField.Type = cgField.TypeReference
}
// add type info to entity
e.Type = obj.Type
}
}
// fill in types for resolver inputs
//
for _, entity := range f.Entities {
// Only multi resolvers take a generated input type.
if !entity.Multi {
continue
}
for _, resolver := range entity.Resolvers {
obj := data.Inputs.ByName(resolver.InputTypeName)
if obj == nil {
return fmt.Errorf("input object %s not found", resolver.InputTypeName)
}
resolver.InputType = obj.Type
}
}
// The explicit-requires file is only rendered when the option is enabled
// and at least one entity actually has @requires fields.
if f.PackageOptions.ExplicitRequires && len(requiresEntities) > 0 {
err := f.generateExplicitRequires(
data,
requiresEntities,
requiresImports,
)
if err != nil {
return err
}
}
// The template receives a copy of f plus the ResolversAlwaysReturnPointers
// config flag as UsePointers.
return templates.Render(templates.Options{
PackageName: data.Config.Federation.Package,
Filename: data.Config.Federation.Filename,
Data: struct {
Federation
UsePointers bool
}{*f, data.Config.ResolversAlwaysReturnPointers},
GeneratedHeader: true,
Packages: data.Config.Packages,
Template: federationTemplate,
})
}
// containsUnionField reports whether any element of reqField.Field starts
// with an inline-fragment marker ("... on").
func containsUnionField(reqField *Requires) bool {
	const fragmentPrefix = "... on"

	for _, segment := range reqField.Field {
		if !strings.HasPrefix(segment, fragmentPrefix) {
			continue
		}
		return true
	}
	return false
}
// populateKeyFieldTypes sets the Type of every key field of resolver from
// the type reference resolved against obj and allObjects. Key fields with
// an empty field path are skipped with a message on stdout; name is only
// used in that message.
func populateKeyFieldTypes(
	resolver *EntityResolver,
	obj *codegen.Object,
	allObjects codegen.Objects,
	name string,
) {
	for _, kf := range resolver.KeyFields {
		if len(kf.Field) > 0 {
			resolved := kf.Field.TypeReference(obj, allObjects)
			kf.Type = resolved.TypeReference
			continue
		}

		fmt.Println(
			"skipping @key field " + kf.Definition.Name + " in " + resolver.ResolverName + " in " + name,
		)
	}
}
// buildEntities returns the federated entities found among schema.Types,
// sorted by name so the result is stable across builds.
func (f *Federation) buildEntities(schema *ast.Schema, version int) []*Entity {
	found := make([]*Entity, 0)
	for _, def := range schema.Types {
		if e := f.buildEntity(def, schema, version); e != nil {
			found = append(found, e)
		}
	}

	// schema.Types is a map, so impose a fixed order.
	byName := func(a, b int) bool { return found[a].Name < found[b].Name }
	sort.Slice(found, byName)

	return found
}
// buildEntity converts schemaType into an Entity, or returns nil when the
// type is not a federated entity or is an interface without any possible
// types. It sets f.usesRequires when the entity has @requires fields.
func (f *Federation) buildEntity(
	schemaType *ast.Definition,
	schema *ast.Schema,
	version int,
) *Entity {
	keys, federated := isFederatedEntity(schemaType)
	if !federated {
		return nil
	}

	unusedInterface := schemaType.Kind == ast.Interface &&
		len(schema.GetPossibleTypes(schemaType)) == 0
	if unusedInterface {
		fmt.Printf("@key directive found on unused \"interface %s\". Will be ignored.\n", schemaType.Name)
		return nil
	}

	entity := &Entity{
		Name:  schemaType.Name,
		Def:   schemaType,
		Multi: isMultiEntity(schemaType),
	}

	// If our schema has a field with a type defined in
	// another service, then we need to define an "empty
	// extend" of that type in this service, so this service
	// knows what the type is like. But the graphql-server
	// will never ask us to actually resolve this "empty
	// extend", so we don't require a resolver function for
	// it. (Well, it will never ask in practice; it's
	// unclear whether the spec guarantees this. See
	// https://github.com/apollographql/apollo-server/issues/3852
	// ). Example:
	//    type MyType {
	//       myvar: TypeDefinedInOtherService
	//    }
	//    // Federation needs this type, but
	//    // it doesn't need a resolver for it!
	//    extend TypeDefinedInOtherService @key(fields: "id") {
	//       id: ID @external
	//    }
	if !entity.allFieldsAreExternal(version) {
		entity.Resolvers = buildResolvers(schemaType, schema, keys, entity.Multi)
		entity.Requires = buildRequires(schemaType)
		if len(entity.Requires) > 0 {
			f.usesRequires = true
		}
	}

	return entity
}
// isMultiEntity reports whether schemaType carries the entity-resolver
// directive with its "multi" argument set to true.
//
// It returns false when the directive or the argument is absent, when the
// argument value cannot be evaluated, or when the value is not a boolean.
func isMultiEntity(schemaType *ast.Definition) bool {
	dir := schemaType.Directives.ForName(dirNameEntityResolver)
	if dir == nil {
		return false
	}

	arg := dir.Arguments.ForName("multi")
	if arg == nil {
		return false
	}

	val, err := arg.Value.Value(nil)
	if err != nil {
		return false
	}

	// Comma-ok form: a nil or non-bool value must not panic here.
	multi, ok := val.(bool)
	return ok && multi
}
// buildResolvers creates one EntityResolver per @key directive in keys.
//
// The resolver name is "find<Type>By<Fields>" or, for multi resolvers,
// "findMany<Type>By<Fields>s"; the input type name is the same stem with
// an "Input" suffix. It panics when a directive has more than two arguments.
func buildResolvers(
	schemaType *ast.Definition,
	schema *ast.Schema,
	keys []*ast.Directive,
	multi bool,
) []*EntityResolver {
	prefix, plural := "find", ""
	if multi {
		// Pluralize for better API readability
		prefix, plural = "findMany", "s"
	}

	result := make([]*EntityResolver, 0, len(keys))
	for _, keyDir := range keys {
		if len(keyDir.Arguments) > 2 {
			panic("More than two arguments provided for @key declaration.")
		}

		keyFields, goNames := buildKeyFields(schemaType, schema, keyDir)
		stem := schemaType.Name + "By" + strings.Join(goNames, "And") + plural

		result = append(result, &EntityResolver{
			ResolverName:   prefix + stem,
			KeyFields:      keyFields,
			InputTypeName:  stem + "Input",
			ReturnTypeName: schemaType.Name,
		})
	}

	return result
}
// extractFields returns the raw value of the single "fields" argument of dir.
//
// An error is returned when the argument appears more than once, or when it
// is missing (or has no value); the missing case previously caused a nil
// pointer dereference.
func extractFields(
	dir *ast.Directive,
) (string, error) {
	var arg *ast.Argument

	// A directive can carry several arguments, so scan all of them and reject
	// a repeated one such as @key(fields: "" fields: "").
	for _, a := range dir.Arguments {
		if a.Name != DirArgFields {
			continue
		}
		if arg != nil {
			return "", errors.New("more than one \"fields\" argument provided for declaration")
		}
		arg = a
	}

	if arg == nil || arg.Value == nil {
		return "", errors.New("no \"fields\" argument provided for declaration")
	}

	return arg.Value.Raw, nil
}
// buildKeyFields parses the "fields" argument of the @key directive dir and
// returns one KeyField per parsed field together with the Go-style names of
// those fields, in the same order.
//
// It panics when the argument cannot be extracted or when a field has no
// definition on schemaType.
func buildKeyFields(
	schemaType *ast.Definition,
	schema *ast.Schema,
	dir *ast.Directive,
) ([]*KeyField, []string) {
	raw, err := extractFields(dir)
	if err != nil {
		panic("More than one `fields` argument provided for declaration.")
	}

	parsed := fieldset.New(raw, nil)

	keyFields := make([]*KeyField, 0, len(parsed))
	goNames := []string{}
	for _, field := range parsed {
		definition := field.FieldDefinition(schemaType, schema)
		if definition == nil {
			panic(fmt.Sprintf("no field for %v", field))
		}

		keyFields = append(keyFields, &KeyField{Definition: definition, Field: field})
		goNames = append(goNames, field.ToGo())
	}

	return keyFields, goNames
}
// buildRequires collects a Requires entry for every field listed in the
// @requires directive of any field of schemaType. It panics when the
// directive's "fields" argument cannot be extracted.
func buildRequires(schemaType *ast.Definition) []*Requires {
	collected := make([]*Requires, 0)

	for _, fieldDef := range schemaType.Fields {
		requiresDir := fieldDef.Directives.ForName(dirNameRequires)
		if requiresDir != nil {
			raw, err := extractFields(requiresDir)
			if err != nil {
				panic("Exactly one `fields` argument needed for @requires declaration.")
			}

			for _, required := range fieldset.New(raw, nil) {
				entry := &Requires{
					Name:  required.ToGoPrivate(),
					Field: required,
				}
				collected = append(collected, entry)
			}
		}
	}

	return collected
}
// isFederatedEntity returns the @key directives of schemaType and true when
// the type is an object or interface with at least one @key directive.
// For every other case it returns (nil, false).
//
// It panics for an interface that has no @key but uses @extends, which is
// not supported.
func isFederatedEntity(schemaType *ast.Definition) ([]*ast.Directive, bool) {
	kind := schemaType.Kind
	if kind != ast.Object && kind != ast.Interface {
		// ignore
		return nil, false
	}

	if keys := schemaType.Directives.ForNames(dirNameKey); len(keys) > 0 {
		return keys, true
	}

	// TODO: support @extends for interfaces
	if kind == ast.Interface && schemaType.Directives.ForName("extends") != nil {
		msg := fmt.Sprintf(
			"@extends directive is not currently supported for interfaces, use \"extend interface %s\" instead.",
			schemaType.Name,
		)
		panic(msg)
	}

	return nil, false
}
// generateExplicitRequires renders federation.requires.go in the federation
// package directory. The file holds one Populate<Entity>Requires method per
// entry of requiresEntities.
//
// For each method, the comment and body already present on executionContext
// are looked up through the rewriter and carried over; when no body is
// found, a "not implemented" panic stub is emitted instead.
//
// requiresImports is the set of import paths the generated file needs. The
// function deletes from that map the paths already imported without alias
// by the existing file, and appends the rest in sorted order so the
// template input does not depend on map iteration order.
func (f *Federation) generateExplicitRequires(
	data *codegen.Data,
	requiresEntities map[string]*Entity,
	requiresImports map[string]bool,
) error {
	// Populator describes one Populate<Entity>Requires method for the template.
	type Populator struct {
		FuncName string
		// NOTE(review): Exists is never set to true in this function, even
		// when an existing implementation is found; confirm whether the
		// template relies on it before changing that.
		Exists         bool
		Comment        string
		Implementation string
		Entity         *Entity
	}

	rewriter, err := rewrite.New(data.Config.Federation.Dir())
	if err != nil {
		return err
	}

	populators := make([]Populator, 0, len(requiresEntities))
	for name, entity := range requiresEntities {
		populator := Populator{
			FuncName: fmt.Sprintf("Populate%sRequires", name),
			Entity:   entity,
		}

		// Carry over whatever the user already wrote for this method.
		populator.Comment = strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment("executionContext", populator.FuncName), `\`))
		populator.Implementation = strings.TrimSpace(rewriter.GetMethodBody("executionContext", populator.FuncName))

		if populator.Implementation == "" {
			populator.Exists = false
			populator.Implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v\"))", populator.FuncName)
		}
		populators = append(populators, populator)
	}

	// requiresEntities is a map, so impose a fixed order on the output.
	sort.Slice(populators, func(i, j int) bool {
		return populators[i].FuncName < populators[j].FuncName
	})

	requiresFile := data.Config.Federation.Dir() + "/federation.requires.go"
	existingImports := rewriter.ExistingImports(requiresFile)
	for _, imp := range existingImports {
		if imp.Alias == "" {
			// import exists in both places, remove
			delete(requiresImports, imp.ImportPath)
		}
	}

	// Append the remaining imports in sorted order rather than in random
	// map-iteration order.
	missing := make([]string, 0, len(requiresImports))
	for importPath := range requiresImports {
		missing = append(missing, importPath)
	}
	sort.Strings(missing)
	for _, importPath := range missing {
		existingImports = append(existingImports, rewrite.Import{ImportPath: importPath})
	}

	// render requires populators
	return templates.Render(templates.Options{
		PackageName: data.Config.Federation.Package,
		Filename:    requiresFile,
		Data: struct {
			Federation
			ExistingImports []rewrite.Import
			Populators      []Populator
			OriginalSource  string
		}{*f, existingImports, populators, ""},
		GeneratedHeader: false,
		Packages:        data.Config.Packages,
		Template:        explicitRequiresTemplate,
	})
}
// buildResolverSDL returns the SDL field line declaring resolver on the fake
// Entity type.
//
// For a multi resolver the field takes a list of the resolver's input type
// and returns a list of the entity type; the SDL of that input type is
// returned as the second value. Otherwise the field takes one argument per
// key field, returns a non-null entity, and the second value is empty.
func buildResolverSDL(
	resolver *EntityResolver,
	multi bool,
) (resolverSDL, entityResolverInputSDL string) {
	if !multi {
		var args strings.Builder
		for _, kf := range resolver.KeyFields {
			fmt.Fprintf(&args, "%s: %s,", kf.Field.ToGoPrivate(), kf.Definition.Type.String())
		}
		single := fmt.Sprintf("\t%s(%s): %s!", resolver.ResolverName, args.String(), resolver.ReturnTypeName)
		return single, ""
	}

	inputSDL := buildEntityResolverInputDefinitionSDL(resolver)
	many := fmt.Sprintf("\t%s(reps: [%s]!): [%s]", resolver.ResolverName, resolver.InputTypeName, resolver.ReturnTypeName)
	return many, inputSDL
}
// buildEntityResolverInputDefinitionSDL returns the SDL of the input type
// named resolver.InputTypeName, with one tab-indented field per key field.
func buildEntityResolverInputDefinitionSDL(resolver *EntityResolver) string {
	var sdl strings.Builder

	sdl.WriteString("input " + resolver.InputTypeName + " {\n")
	for _, kf := range resolver.KeyFields {
		fmt.Fprintf(&sdl, "\t%s: %s\n", kf.Field.ToGo(), kf.Definition.Type.String())
	}
	sdl.WriteString("}")

	return sdl.String()
}
// addMapType registers the map scalar: it binds mapTypeName to gqlgen's
// graphql.Map model in cfg.Models and adds the scalar definition to
// cfg.Schema.Types.
func (f *Federation) addMapType(cfg *config.Config) {
	modelEntry := config.TypeMapEntry{
		Model: config.StringList{"github.com/99designs/gqlgen/graphql.Map"},
	}
	cfg.Models[mapTypeName] = modelEntry

	scalarDef := &ast.Definition{
		Kind:        ast.Scalar,
		Name:        mapTypeName,
		Description: "Maps an arbitrary GraphQL value to a map[string]any Go type.",
	}
	cfg.Schema.Types[mapTypeName] = scalarDef
}
// mutateSchemaForRequires visits every field in schema that carries the
// @requires directive. For each one it forces a resolver in cfg.Models and
// appends an extra argument (named fieldArgRequires, of the map scalar type)
// annotated with the populate-from-representations directive.
func (f *Federation) mutateSchemaForRequires(
	schema *ast.Schema,
	cfg *config.Config,
) {
	for _, def := range schema.Types {
		for _, field := range def.Fields {
			if field.Directives.ForName(dirNameRequires) == nil {
				continue
			}

			// ensure we always generate a resolver for any @requires field
			model := cfg.Models[def.Name]
			if model.Fields == nil {
				model.Fields = make(map[string]config.TypeMapField)
			}
			fieldCfg := model.Fields[field.Name]
			fieldCfg.Resolver = true
			model.Fields[field.Name] = fieldCfg
			cfg.Models[def.Name] = model

			populate := &ast.Directive{
				Name:       dirNamePopulateFromRepresentations,
				Definition: dirPopulateFromRepresentations,
			}
			field.Arguments = append(field.Arguments, &ast.ArgumentDefinition{
				Name:       fieldArgRequires,
				Type:       ast.NamedType(mapTypeName, nil),
				Directives: ast.DirectiveList{populate},
			})
		}
	}
}