input/elasticapm/internal/modeldecoder/generator/code.go (270 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package generator

import (
	"bytes"
	"fmt"
	"go/types"
	"io"
	"reflect"
	"sort"
	"strings"

	"github.com/pkg/errors"
)

const (
	// anonymousField is the name of blank (`_`) struct fields. Such fields
	// are not exported, but they are still considered when generating
	// validation code.
	anonymousField = "_"
)

// CodeGenerator creates the following struct methods
//
//   `IsSet() bool`
//   `validate() error`
//   `processNestedSource() error`
//
// on all exported and anonymous structs that are referenced
// by at least one of the root types.
type CodeGenerator struct {
	// buf accumulates the generated source code.
	buf bytes.Buffer
	// parsed holds the parsed package information (struct types, pattern
	// variables and package name) that code is generated from.
	parsed *Parsed
	// rootObjs are the struct types that code generation starts from.
	rootObjs []structType
	// keep track of already processed types in case one type is
	// referenced multiple times
	processedTypes map[string]struct{}
}

// validationGenerator writes validation code for a single struct field.
// It is called with the output writer, all fields of the enclosing struct,
// the field being validated, and whether the field's type (or, for slices
// and maps, its element type) is a custom struct handled by the generator.
type validationGenerator func(io.Writer, []structField, structField, bool) error

// NewCodeGenerator creates a generator for the given parsed package
// information. rootTypes names the struct types that code generation starts
// from. An error is returned if any of them is not among the parsed struct
// types. The generator creates methods only for types referenced directly or
// indirectly by any of the root types.
func NewCodeGenerator(parsed *Parsed, rootTypes []string) (*CodeGenerator, error) { g := CodeGenerator{ parsed: parsed, rootObjs: make([]structType, len(rootTypes)), processedTypes: make(map[string]struct{}), } for i := 0; i < len(rootTypes); i++ { rootStruct, ok := parsed.structTypes[rootTypes[i]] if !ok { return nil, fmt.Errorf("object with root key %s not found", rootTypes[i]) } g.rootObjs[i] = rootStruct } return &g, nil } // Generate generates the code for given root structs and all // dependencies and returns it as bytes.Buffer func (g *CodeGenerator) Generate() (bytes.Buffer, error) { fmt.Fprintf(&g.buf, ` // Code generated by "modeldecoder/generator". DO NOT EDIT. package %s import ( "fmt" "encoding/json" "github.com/pkg/errors" "regexp" "unicode/utf8" ) var ( `[1:], g.parsed.pkgName) for _, name := range sortKeys(g.parsed.patternVariables) { fmt.Fprintf(&g.buf, ` %sRegexp = regexp.MustCompile(%s) `[1:], name, name) } fmt.Fprint(&g.buf, ` ) `[1:]) // run generator code for _, rootObj := range g.rootObjs { if err := g.generate(rootObj, ""); err != nil { return g.buf, errors.Wrap(err, "code generator") } } return g.buf, nil } // create flattened field keys by recursively iterating through the struct types; // there is only struct local knowledge and no knowledge about the parent, // deriving the absolute key is not possible in scenarios where one struct // type is referenced as a field in multiple struct types func (g *CodeGenerator) generate(st structType, key string) error { if _, ok := g.processedTypes[st.name]; ok { return nil } g.processedTypes[st.name] = struct{}{} if err := g.generateIsSet(st, key); err != nil { return err } if err := g.generateValidation(st, key); err != nil { return err } if err := g.generateNestedSourceProcessor(st, key); err != nil { return err } if key != "" { key += "." 
} for _, field := range st.fields { var childTyp types.Type switch fieldTyp := field.Type().Underlying().(type) { case *types.Map: childTyp = fieldTyp.Elem() case *types.Slice: childTyp = fieldTyp.Elem() default: childTyp = field.Type() } if child, ok := g.customStruct(childTyp); ok { if err := g.generate(child, fmt.Sprintf("%s%s", key, jsonName(field))); err != nil { return err } } } return nil } // generateIsSet creates `IsSet` methods for struct fields, // indicating if the fields have been initialized; // it only considers exported fields, aligned with standard marshal behavior func (g *CodeGenerator) generateIsSet(structTyp structType, key string) error { if len(structTyp.fields) == 0 { return fmt.Errorf("unhandled struct %s (does not have any exported fields)", structTyp.name) } fmt.Fprintf(&g.buf, ` func (val *%s) IsSet() bool { return`, structTyp.name) if key != "" { key += "." } prefix := ` ` for i := 0; i < len(structTyp.fields); i++ { f := structTyp.fields[i] if !f.Exported() { continue } g.buf.WriteString(prefix) if err := generateIsSet(&g.buf, f, "val."); err != nil { return errors.Wrapf(err, "error generating IsSet() for '%s%s'", key, jsonName(f)) } prefix = ` || ` } fmt.Fprint(&g.buf, ` } `) return nil } func generateIsSet(w io.Writer, field structField, fieldSelectorPrefix string) error { switch typ := field.Type().Underlying(); typ.(type) { case *types.Slice, *types.Map: fmt.Fprintf(w, "(len(%s%s) > 0)", fieldSelectorPrefix, field.Name()) return nil case *types.Struct: fmt.Fprintf(w, "%s%s.IsSet()", fieldSelectorPrefix, field.Name()) return nil default: return fmt.Errorf("unhandled type %T generating IsSet() for '%s'", typ, jsonName(field)) } } // generateNestedSourceProcessor generates code for processing fully nested map of fields // into flat fields as identified by their JSON tag. The nested source is identified by // the tag `nested="true"`. 
If a struct has the nested source tag then all the other struct // fields are eligible for override from the nested source if they have the correct JSON tag. func (g *CodeGenerator) generateNestedSourceProcessor(structTyp structType, key string) error { fmt.Fprintf(&g.buf, ` func (val *%s) processNestedSource() error { `, structTyp.name) var foundNestedSource bool var nestedSource structField for i := 0; i < len(structTyp.fields); i++ { f := structTyp.fields[i] hasTag := f.tag.Get("nested") == "true" if hasTag && f.Type().String() != "map[string]interface{}" { return errors.New("invalid nested source specified, should always be map[string]interface{}") } if hasTag && foundNestedSource { return errors.New("only one nested source per struct is allowed") } if hasTag && !foundNestedSource { foundNestedSource = hasTag nestedSource = f // return if the map is empty fmt.Fprintf(&g.buf, ` if len(val.%s) == 0 { return nil } `[1:], nestedSource.Name()) } } for i := 0; i < len(structTyp.fields); i++ { f := structTyp.fields[i] if foundNestedSource { if f.Name() == nestedSource.Name() { continue } if err := generateParseFlatFieldCode(&g.buf, f, nestedSource); err != nil { return err } } switch f.Type().Underlying().(type) { case *types.Struct: if _, isCustom := g.customStruct(f.Type()); isCustom { fmt.Fprintf(&g.buf, ` if err := val.%s.processNestedSource(); err != nil { return errors.Wrapf(err, "%s") } `[1:], f.Name(), jsonName(f)) } } } fmt.Fprint(&g.buf, ` return nil } `[1:]) return nil } // generateValidation creates `validate` methods for struct fields // it only considers exported and anonymous fields func (g *CodeGenerator) generateValidation(structTyp structType, key string) error { fmt.Fprintf(&g.buf, ` func (val *%s) validate() error { `, structTyp.name) var validation validationGenerator for i := 0; i < len(structTyp.fields); i++ { f := structTyp.fields[i] // according to https://golang.org/pkg/go/types/#Var.Anonymous // f.Anonymous() actually checks if f is 
embedded, not anonymous, // so we need to do a name check instead if !f.Exported() && f.Name() != anonymousField { continue } var custom bool switch f.Type().String() { case nullableTypeString: validation = generateNullableStringValidation case nullableTypeInt, nullableTypeInt64: validation = generateNullableIntValidation case nullableTypeFloat64: // right now we can reuse the validation rules for int // and only introduce dedicated rules for float64 when they diverge validation = generateNullableIntValidation case nullableTypeInterface: validation = generateNullableInterfaceValidation default: switch t := f.Type().Underlying().(type) { case *types.Slice: validation = generateSliceValidation _, custom = g.customStruct(t.Elem()) case *types.Map: validation = generateMapValidation _, custom = g.customStruct(t.Elem()) case *types.Struct: validation = generateStructValidation _, custom = g.customStruct(f.Type()) default: return errors.Wrap(fmt.Errorf("unhandled type %T", t), flattenName(key, f)) } } if err := validation(&g.buf, structTyp.fields, f, custom); err != nil { return errors.Wrap(err, flattenName(key, f)) } } fmt.Fprint(&g.buf, ` return nil } `[1:]) return nil } func (g *CodeGenerator) customStruct(typ types.Type) (t structType, ok bool) { t, ok = g.parsed.structTypes[typ.String()] return } func flattenName(key string, f structField) string { if key != "" { key += "." } return fmt.Sprintf("%s%s", key, jsonName(f)) } func jsonName(f structField) string { parts := parseTag(f.tag, "json") if len(parts) == 0 { return strings.ToLower(f.Name()) } return parts[0] } func parseTag(structTag reflect.StructTag, tagName string) []string { tag, ok := structTag.Lookup(tagName) if !ok { return []string{} } if tag == "-" { return nil } return strings.Split(tag, ",") } func sortKeys(input map[string]string) []string { keys := make(sort.StringSlice, 0, len(input)) for k := range input { keys = append(keys, k) } keys.Sort() return keys }