func()

in plugin/modelgen/models.go [102:329]


// MutateConfig generates Go model types for every schema type that is not
// already bound by the user (cfg.Models.UserDefined). It registers those types
// in cfg.Models under cfg.Model.ImportPath(), renders them into
// cfg.Model.Filename and then reloads all packages.
//
// Interfaces and unions become Interface values; their fields are populated
// unless cfg.OmitGetters is set. Objects and input objects become Object
// values, enums become Enum values, and scalars are bound to
// github.com/99designs/gqlgen/graphql.String. Root types (cfg.IsRoot) never get
// generated fields: at most an empty model is emitted, and only when
// cfg.OmitRootModels is false.
//
// If nothing needs to be generated, MutateConfig returns nil without rendering
// a file or reloading packages. m.MutateHook, when set, runs after the type
// names have been registered in cfg.Models and before rendering.
//
// It returns any error from field generation or template rendering.
func (m *Plugin) MutateConfig(cfg *config.Config) error {
	b := &ModelBuild{
		PackageName: cfg.Model.Package,
	}

	for _, schemaType := range cfg.Schema.Types {
		// Types the user already bound to their own Go types are not generated.
		if cfg.Models.UserDefined(schemaType.Name) {
			continue
		}
		switch schemaType.Kind {
		case ast.Interface, ast.Union:
			var fields []*Field
			var err error
			if !cfg.OmitGetters {
				fields, err = m.generateFields(cfg, schemaType)
				if err != nil {
					return err
				}
			}

			// Implements is a copy of the schema's slice: appending "_Entity" below,
			// or a MutateHook editing Implements in place, must never write into the
			// backing array owned by the parsed schema.
			it := &Interface{
				Description: schemaType.Description,
				Name:        schemaType.Name,
				Implements:  append([]string(nil), schemaType.Interfaces...),
				Fields:      fields,
				OmitCheck:   cfg.OmitInterfaceChecks,
			}

			// if the interface has a key directive as an entity interface, allow it to implement _Entity
			if schemaType.Directives.ForName("key") != nil {
				it.Implements = append(it.Implements, "_Entity")
			}

			b.Interfaces = append(b.Interfaces, it)
		case ast.Object, ast.InputObject:
			// Root types never get generated fields; at most an empty model is emitted.
			if cfg.IsRoot(schemaType) {
				if !cfg.OmitRootModels {
					b.Models = append(b.Models, &Object{
						Description: schemaType.Description,
						Name:        schemaType.Name,
					})
				}
				continue
			}

			fields, err := m.generateFields(cfg, schemaType)
			if err != nil {
				return err
			}

			it := &Object{
				Description: schemaType.Description,
				Name:        schemaType.Name,
				Fields:      fields,
			}

			// If Interface A implements interface B, and Interface C also implements interface B
			// then both A and C have methods of B.
			// The reason for checking unique is to prevent the same method B from being generated twice.
			uniqueMap := map[string]bool{}
			for _, implementor := range cfg.Schema.GetImplements(schemaType) {
				if !uniqueMap[implementor.Name] {
					it.Implements = append(it.Implements, implementor.Name)
					uniqueMap[implementor.Name] = true
				}
				// for interface implements
				for _, iface := range implementor.Interfaces {
					if !uniqueMap[iface] {
						it.Implements = append(it.Implements, iface)
						uniqueMap[iface] = true
					}
				}
			}

			b.Models = append(b.Models, it)
		case ast.Enum:
			it := &Enum{
				Name:        schemaType.Name,
				Description: schemaType.Description,
			}

			for _, v := range schemaType.EnumValues {
				it.Values = append(it.Values, &EnumValue{
					Name:        v.Name,
					Description: v.Description,
				})
			}

			b.Enums = append(b.Enums, it)
		case ast.Scalar:
			b.Scalars = append(b.Scalars, schemaType.Name)
		}
	}

	// Sort by name so the generated file does not depend on the iteration order
	// of cfg.Schema.Types.
	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })

	// if we are not just turning all struct-type fields in generated structs into pointers, we need to at least
	// check for cyclical relationships and recursive structs
	if !cfg.StructFieldsAlwaysPointers {
		findAndHandleCyclicalRelationships(b)
	}

	for _, it := range b.Enums {
		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
	}
	for _, it := range b.Models {
		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
	}

	// On each interface, keep a reference to every model that implements it
	// (in b.Models order). Implementations of the "_Entity" entity interface are
	// not recorded as models. A name index avoids rescanning every model for
	// every interface.
	interfacesByName := make(map[string]*Interface, len(b.Interfaces))
	for _, it := range b.Interfaces {
		interfacesByName[it.Name] = it
	}
	for _, model := range b.Models {
		for _, impl := range model.Implements {
			if impl == "_Entity" {
				continue
			}
			if it, ok := interfacesByName[impl]; ok {
				it.Models = append(it.Models, model)
			}
		}
	}
	for _, it := range b.Interfaces {
		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
	}
	for _, it := range b.Scalars {
		cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
	}

	if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
		return nil
	}

	if m.MutateHook != nil {
		b = m.MutateHook(b)
	}

	// getInterfaceByName lets the template look up an interface so it can emit
	// getters for each of its fields. It scans b at call time, so it sees any
	// interfaces the MutateHook added or replaced.
	getInterfaceByName := func(name string) *Interface {
		for _, i := range b.Interfaces {
			if i.Name == name {
				return i
			}
		}

		return nil
	}

	// gettersGenerated records, per model name, which getters were already
	// emitted. A model can reach the same field through several interfaces.
	gettersGenerated := make(map[string]map[string]struct{})
	generateGetter := func(model *Object, field *Field) string {
		if model == nil || field == nil {
			return ""
		}

		typeGetters, ok := gettersGenerated[model.Name]
		if !ok {
			typeGetters = make(map[string]struct{})
			gettersGenerated[model.Name] = typeGetters
		}
		if _, done := typeGetters[field.GoName]; done {
			return ""
		}
		typeGetters[field.GoName] = struct{}{}

		_, interfaceFieldTypeIsPointer := field.Type.(*types.Pointer)
		structFieldTypeIsPointer := false
		for _, f := range model.Fields {
			if f.GoName == field.GoName {
				_, structFieldTypeIsPointer = f.Type.(*types.Pointer)
				break
			}
		}

		// Bridge pointer-ness between the interface's declared type and the
		// struct's own field type: take the address or dereference as needed.
		var conv string
		switch {
		case interfaceFieldTypeIsPointer && !structFieldTypeIsPointer:
			conv = "&"
		case !interfaceFieldTypeIsPointer && structFieldTypeIsPointer:
			conv = "*"
		}

		receiver := templates.ToGo(model.Name)
		goType := templates.CurrentImports.LookupType(field.Type)
		if !strings.HasPrefix(goType, "[]") {
			return fmt.Sprintf("func (this %s) Get%s() %s { return %sthis.%s }", receiver, field.GoName, goType, conv, field.GoName)
		}

		// Slice fields are copied element by element, because a []Concrete is not
		// assignable to the interface's slice type.
		var sb strings.Builder
		fmt.Fprintf(&sb, "func (this %s) Get%s() %s {\n", receiver, field.GoName, goType)
		fmt.Fprintf(&sb, "\tif this.%s == nil { return nil }\n", field.GoName)
		fmt.Fprintf(&sb, "\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName)
		fmt.Fprintf(&sb, "\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, %sconcrete) }\n", field.GoName, conv)
		sb.WriteString("\treturn interfaceSlice\n")
		sb.WriteString("}")
		return sb.String()
	}

	funcMap := template.FuncMap{
		"getInterfaceByName": getInterfaceByName,
		"generateGetter":     generateGetter,
	}
	newModelTemplate := modelTemplate
	if cfg.Model.ModelTemplate != "" {
		newModelTemplate = readModelTemplate(cfg.Model.ModelTemplate)
	}

	if err := templates.Render(templates.Options{
		PackageName:     cfg.Model.Package,
		Filename:        cfg.Model.Filename,
		Data:            b,
		GeneratedHeader: true,
		Packages:        cfg.Packages,
		Template:        newModelTemplate,
		Funcs:           funcMap,
	}); err != nil {
		return err
	}

	// We may have generated code in a package we already loaded, so we reload all packages
	// to allow packages to be compared correctly
	cfg.ReloadAllPackages()

	return nil
}