plugin/federation/fieldset/fieldset.go (156 lines of code) (raw):
package fieldset
import (
"fmt"
"strings"
"github.com/vektah/gqlparser/v2/ast"
"github.com/99designs/gqlgen/codegen"
"github.com/99designs/gqlgen/codegen/templates"
)
type (
	// Set represents a FieldSet that is used in federation directives @key and @requires.
	// Nested selections are flattened, so every element is the full path to one
	// selected field.
	// Would be happier to reuse FieldSet parsing from gqlparser, but this suits for now.
	Set []Field

	// Field represents a single field in a FieldSet as the path of field names
	// leading to it; for example `a { b }` yields the Field ["a", "b"].
	Field []string
)
// New parses a raw FieldSet string (the `fields` argument of @key or
// @requires) into a Set, prepending prefix to every resulting Field.
//
// Nested selections are flattened: `a { b c } d` yields the Fields
// [a b], [a c] and [d]. The field written immediately before a `{` becomes
// the prefix of the selections inside the braces instead of a Field of its
// own. New panics (via extractSubs) when a `{` has no matching `}`.
func New(raw string, prefix []string) Set {
	// Fast path: no nested selections at all.
	if !strings.Contains(raw, "{") {
		return parseUnnestedKeyFieldSet(raw, prefix)
	}

	var (
		ret       = Set{}
		subPrefix = prefix
	)

	before, during, after := extractSubs(raw)

	if before != "" {
		befores := New(before, prefix)
		if len(befores) > 0 {
			// The last field before "{" owns the nested selection, so it is
			// used as the prefix for the selection's contents rather than
			// being emitted as a Field itself.
			subPrefix = befores[len(befores)-1]
			ret = append(ret, befores[:len(befores)-1]...)
		}
	}

	if during != "" {
		ret = append(ret, New(during, subPrefix)...)
	}

	if after != "" {
		// Text after the closing brace is a sibling of the outer selection.
		ret = append(ret, New(after, prefix)...)
	}

	return ret
}
// FieldDefinition looks up the definition of the (possibly nested) field f,
// starting at schemaType and following each part of f through the types
// declared in schema.
//
// It returns a shallow copy of the innermost field definition whose Name is
// replaced by f.ToGoPrivate(). It returns nil when f is empty or when any part
// of the path is not a field of the type it is looked up in. It panics when a
// non-final part of f is declared on a type whose Kind is not ast.Object, or
// when a field's type is not declared in schema.
func (f Field) FieldDefinition(schemaType *ast.Definition, schema *ast.Schema) *ast.FieldDefinition {
	if len(f) == 0 {
		return nil
	}

	objType := schemaType
	def := objType.Fields.ForName(f[0])
	for _, part := range f[1:] {
		// A missing intermediate field would otherwise be dereferenced below;
		// report it the same way as a missing innermost field.
		if def == nil {
			return nil
		}
		// NOTE(review): this inspects the type that declares the previous part,
		// not the type being descended into. With an interface-typed field `a`,
		// [a b] resolves but [a b c] panics. Confirm whether interfaces should
		// be accepted here.
		if objType.Kind != ast.Object {
			panic(fmt.Sprintf(`invalid sub-field reference "%s" in %v: %s is not an object type`, part, f, objType.Name))
		}
		typeName := def.Type.Name()
		objType = schema.Types[typeName]
		if objType == nil {
			panic("invalid schema type: " + typeName)
		}
		def = objType.Fields.ForName(part)
	}
	if def == nil {
		return nil
	}

	ret := *def // shallow copy, so renaming below does not touch the schema
	ret.Name = f.ToGoPrivate()
	return &ret
}
// TypeReference resolves the (possibly nested) field f against obj. Each
// nested part is looked up on the codegen object named by the previous part's
// type. It returns the codegen field for the innermost part, or nil when f is
// empty. It panics when a part cannot be found, or when a non-final part's
// type has no matching entry in objects.
func (f Field) TypeReference(obj *codegen.Object, objects codegen.Objects) *codegen.Field {
	var def *codegen.Field
	for i, part := range f {
		def = fieldByName(obj, part)
		if def == nil {
			// Report the path up to the failing part; for a single-part field
			// this is just that part's name.
			panic("unable to find field " + strings.Join(f[:i+1], "."))
		}
		if i == len(f)-1 {
			// The innermost field has been found; no need to resolve its type.
			break
		}
		obj = objects.ByName(def.TypeReference.Definition.Name)
		if obj == nil {
			// Without this check fieldByName would dereference a nil object
			// on the next iteration.
			panic("unable to find object type for field " + strings.Join(f[:i+1], "."))
		}
	}
	return def
}
// ToGo converts a (possibly nested) field into a proper public Go name by
// concatenating the templates.ToGo form of every part, in order.
func (f Field) ToGo() string {
	var sb strings.Builder
	for _, part := range f {
		sb.WriteString(templates.ToGo(part))
	}
	return sb.String()
}
// ToGoPrivate converts a (possibly nested) field into a proper private Go name.
// The first part has any argument list removed and is rendered with
// templates.ToGoPrivate. Every later part is appended in its templates.ToGo
// form. An empty field yields "".
func (f Field) ToGoPrivate() string {
	if len(f) == 0 {
		return ""
	}

	var sb strings.Builder
	sb.WriteString(templates.ToGoPrivate(trimArgumentFromFieldName(f[0])))
	for _, part := range f[1:] {
		sb.WriteString(templates.ToGo(part))
	}
	return sb.String()
}
// Join returns the parts of f joined with str between each pair, e.g.
// Field{"a", "b"}.Join(".") is "a.b". Useful in templates.
func (f Field) Join(str string) string {
	parts := []string(f)
	return strings.Join(parts, str)
}
// JoinGo converts every part of f to its public Go name with templates.ToGo
// and joins the results with str between each pair. Useful in templates.
func (f Field) JoinGo(str string) string {
	goNames := make([]string, len(f))
	for i := range f {
		goNames[i] = templates.ToGo(f[i])
	}
	return strings.Join(goNames, str)
}
// LastIndex returns the index of the last part of f, i.e. len(f)-1, which is
// -1 for an empty Field.
func (f Field) LastIndex() int {
	return len(f) - 1
}
// local functions
// parseUnnestedKeyFieldSet handles the simple case where none of the fields
// are nested (raw contains no "{"). Each whitespace-separated name becomes a
// Field made of prefix followed by that name. An inline-fragment type
// condition "... on T" is kept together as the single name "... on T".
// The token "on" is only treated as a keyword directly after "..."; anywhere
// else it is an ordinary field name, which GraphQL permits.
func parseUnnestedKeyFieldSet(raw string, prefix []string) Set {
	ret := Set{}
	afterSpread := false // previous token was "..."
	inFragment := false  // previous tokens were "... on"
	for _, s := range strings.Fields(raw) {
		switch {
		case s == "...":
			afterSpread = true
			continue
		case s == "on" && afterSpread:
			afterSpread = false
			inFragment = true
			continue
		case inFragment:
			s = "... on " + s
			inFragment = false
		}
		afterSpread = false
		// Capping the prefix at its length forces append to allocate, so
		// Fields never share (and overwrite) one backing array.
		next := append(prefix[0:len(prefix):len(prefix)], s) //nolint:gocritic // set cap=len in order to force slice reallocation
		ret = append(ret, next)
	}
	return ret
}
// extractSubs splits str around its first "{...}" selection. It returns the
// trimmed text before the opening brace, inside the braces, and after the
// matching closing brace. Argument lists such as `(arg: 1)` are removed from
// the "before" part. It panics if str has no "{" or the first "{" is never
// closed.
func extractSubs(str string) (string, string, string) {
	start := strings.Index(str, "{")
	end := matchingBracketIndex(str, start)
	if start < 0 || end < 0 {
		panic("invalid key fieldSet: " + str)
	}
	before := stripArgumentLists(strings.TrimSpace(str[:start]))
	during := strings.TrimSpace(str[start+1 : end])
	after := strings.TrimSpace(str[end+1:])
	return before, during, after
}

// stripArgumentLists removes every parenthesised argument list from a
// whitespace-separated sequence of field names, so "a(x: 1) b" becomes "a b".
// Unlike truncating at the first "(", fields written after an argument list
// are kept. An unclosed "(" drops the rest of the string. Input without "("
// is returned unchanged.
func stripArgumentLists(s string) string {
	if !strings.Contains(s, "(") {
		return s
	}
	var b strings.Builder
	depth := 0
	for _, r := range s {
		switch {
		case r == '(':
			if depth == 0 {
				// Keep the names on either side of the list separate.
				b.WriteByte(' ')
			}
			depth++
		case r == ')' && depth > 0:
			depth--
		case depth == 0:
			b.WriteRune(r)
		}
	}
	// Collapse the whitespace left behind by removed lists.
	return strings.Join(strings.Fields(b.String()), " ")
}

// matchingBracketIndex returns the index of the "}" that closes the "{" at
// start, or -1 when start is negative, nothing follows it, or no matching
// "}" exists.
func matchingBracketIndex(str string, start int) int {
	if start < 0 || start+1 >= len(str) {
		return -1
	}
	depth := 0
	// Byte-wise scan is safe: '{' and '}' are ASCII and never appear inside a
	// multi-byte UTF-8 sequence.
	for i := start + 1; i < len(str); i++ {
		switch str[i] {
		case '{':
			depth++
		case '}':
			if depth == 0 {
				return i
			}
			depth--
		}
	}
	return -1
}
// fieldByName returns the field of obj whose name equals name, with any
// argument list such as "(arg: 1)" ignored on both sides of the comparison.
// It returns nil when no field matches. obj and its fields are not modified.
func fieldByName(obj *codegen.Object, name string) *codegen.Field {
	want := trimArgumentFromFieldName(name)
	for _, field := range obj.Fields {
		// Compare a trimmed copy instead of rewriting field.Name: this is a
		// lookup and must not mutate the shared codegen object.
		if trimArgumentFromFieldName(field.Name) == want {
			return field
		}
	}
	return nil
}
// trimArgumentFromFieldName removes any arguments from a field name by
// cutting raw at the first argument-open character `(`, so "foo(arg: 1)"
// becomes "foo". raw is returned unchanged when it contains no `(`.
func trimArgumentFromFieldName(raw string) string {
	if idx := strings.IndexByte(raw, '('); idx >= 0 {
		return raw[:idx]
	}
	return raw
}