cmd/generate-fastjson/main.go (432 lines of code) (raw):

// Copyright 2018 Elasticsearch BV // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "bytes" "flag" "fmt" "go/ast" "go/format" "go/token" "go/types" "io" "log" "os" "reflect" "sort" "strings" "golang.org/x/tools/go/packages" ) const ( fastjsonPath = "go.elastic.co/fastjson" isZeroMethod = "isZero" marshalMethod = "MarshalFastJSON" ) var ( force bool outfile string ) func init() { flag.BoolVar(&force, "f", false, "remove the output file if it exists") flag.StringVar(&outfile, "o", "-", "file to which output will be written") flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s <package>\n", os.Args[0]) flag.PrintDefaults() } } func main() { flag.Parse() if flag.NArg() != 1 { flag.Usage() os.Exit(1) } if outfile != "-" { if _, err := os.Stat(outfile); err == nil { if force { if err := os.Remove(outfile); err != nil { log.Fatal(err) } } else { fmt.Fprintf(os.Stderr, "%s already exists, and -f not specified; aborting\n", outfile) os.Exit(2) } } } cfg := &packages.Config{ Mode: packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo, } pkgs, err := packages.Load(cfg, flag.Arg(0)) if err != nil { fmt.Fprintf(os.Stderr, "load: %v\n", err) os.Exit(1) } if packages.PrintErrors(pkgs) > 0 { os.Exit(1) } pkg := pkgs[0] var buf bytes.Buffer fmt.Fprintf(&buf, ` // Code generated by "generate-fastjson". DO NOT EDIT. package %s import ( "errors" "math" %q ) var ( _ = errors.New _ = math.IsNaN ) `[1:], pkg.Types.Name(), fastjsonPath) var generated int for _, f := range pkg.Syntax { for _, decl := range f.Decls { genDecl, ok := decl.(*ast.GenDecl) if !ok || genDecl.Tok != token.TYPE { continue } for _, spec := range genDecl.Specs { typeSpec, ok := spec.(*ast.TypeSpec) if !ok { continue } obj := pkg.TypesInfo.Defs[typeSpec.Name] if obj == nil || !obj.Exported() { continue } typeName := obj.(*types.TypeName) named := typeName.Type().(*types.Named) if !hasMethod(named, marshalMethod) { generate(&buf, named) generated++ } } } } formatted, err := format.Source(buf.Bytes()) if err != nil { fmt.Println(buf.String()) log.Fatal(err) } var out io.Writer = os.Stdout if outfile != "-" { f, err := os.Create(outfile) if err != nil { log.Fatal(err) } defer f.Close() out = f } if _, err := out.Write(formatted); err != nil { log.Fatal(err) } if outfile != "" { fmt.Fprintf(os.Stderr, "generated %d methods in %q\n", generated, outfile) } } func generate(w *bytes.Buffer, named *types.Named) { structType, ok := named.Underlying().(*types.Struct) if !ok { panic(fmt.Errorf("unhandled type %T", named.Underlying())) } origw := w w = new(bytes.Buffer) defer func() { fmt.Fprintf(origw, "\nfunc (v *%s) %s(w *fastjson.Writer) error {\n", named.Obj().Name(), marshalMethod) // Hypothetically you could create a type whose names contains // "firstErr" which would force this. No big deal if the var is // never written to, this is just for aesthetics. mayError := strings.Contains(w.String(), "firstErr") if mayError { fmt.Fprintln(origw, "var firstErr error") } fmt.Fprintln(origw, `w.RawByte('{')`) w.WriteTo(origw) fmt.Fprintln(origw, `w.RawByte('}')`) if mayError { fmt.Fprintln(origw, "return firstErr") } else { fmt.Fprintln(origw, "return nil") } fmt.Fprintln(origw, "}") }() numFields := structType.NumFields() structFields := make([]structField, 0, numFields) for i := 0; i < numFields; i++ { structField, ok := makeStructField(structType, i) if !ok { continue } structFields = append(structFields, structField) } sort.Slice(structFields, func(i, j int) bool { // Put non-omitempty fields first, so we can elide // the runtime "first" tracking. switch { case !structFields[i].omitempty && structFields[j].omitempty: return true case structFields[i].omitempty && !structFields[j].omitempty: return false } return structFields[i].jsonName < structFields[j].jsonName }) checkFirst := len(structFields) > 1 && structFields[0].omitempty if checkFirst { fmt.Fprintln(w, "first := true") } for i, f := range structFields { if f.omitempty { fmt.Fprintf(w, "if %s {", isNonZero("v."+f.fieldName, f.fieldType)) } prefix := fmt.Sprintf(",%q:", f.jsonName) if checkFirst { fmt.Fprintf(w, ` const prefix = %q if first { first = false w.RawString(prefix[1:]) } else { w.RawString(prefix) } `[1:], prefix) } else { if i == 0 { prefix = prefix[1:] } fmt.Fprintf(w, "w.RawString(%q)\n", prefix) } var nillable bool if !f.omitempty { // For nillable types (pointer, slice, map, interface), // emit a null check to write "null". switch f.fieldType.Underlying().(type) { case *types.Pointer: nillable = true case *types.Slice: nillable = true case *types.Map: nillable = true case *types.Interface: nillable = true } if nillable { fmt.Fprintf(w, ` if v.%s == nil { w.RawString("null") } else { `[1:], f.fieldName) } } generateValue(w, "v."+f.fieldName, f.fieldType) if f.omitempty || nillable { fmt.Fprintln(w, "}") } } } func generateValue(w *bytes.Buffer, expr string, exprType types.Type) { if named, ok := exprType.(*types.Named); ok { if hasMethod(named, marshalMethod) { fmt.Fprintf(w, ` if err := %s.%s(w); err != nil && firstErr == nil { firstErr = err } `[1:], expr, marshalMethod) return } exprType = named.Underlying() } switch t := exprType.(type) { case *types.Pointer: generatePointerValue(w, expr, t) case *types.Slice: generateSliceValue(w, expr, t) case *types.Basic: generateBasicValue(w, expr, t) case *types.Map: generateMapValue(w, expr, t) case *types.Interface: generateInterfaceValue(w, expr, t) case *types.Struct: generateStructValue(w, expr, t) case *types.Alias: unaliasType := types.Unalias(t) generateValue(w, expr, unaliasType) default: panic(fmt.Errorf("unhandled type %T", t)) } } func generatePointerValue(w *bytes.Buffer, expr string, exprType *types.Pointer) { elem := exprType.Elem() switch t := elem.Underlying().(type) { case *types.Basic: generateBasicValue(w, "*"+expr, t) case *types.Struct: generateStructValue(w, expr, t) default: panic(fmt.Errorf("unhandled type %T", exprType)) } } func generateBasicValue(w *bytes.Buffer, expr string, exprType *types.Basic) { convert := func(t string) { expr = fmt.Sprintf("%s(%s)", t, expr) } var method string switch k := exprType.Kind(); k { case types.Bool: method = "Bool" case types.Int, types.Int8, types.Int16, types.Int32: convert("int64") method = "Int64" case types.Int64: method = "Int64" case types.Uint, types.Uint8, types.Uint16, types.Uint32: convert("uint64") method = "Uint64" case types.Uint64: method = "Uint64" case types.Float32: method = "Float32" fmt.Fprintf(w, ` if math.IsNaN(float64(%s)) { return errors.New("json: '%s': unsupported value: NaN") } if math.IsInf(float64(%s), 0) { return errors.New("json: '%s': unsupported value: Inf") } `[1:], expr, expr, expr, expr) case types.Float64: method = "Float64" fmt.Fprintf(w, ` if math.IsNaN(%s) { return errors.New("json: '%s': unsupported value: NaN") } if math.IsInf(%s, 0) { return errors.New("json: '%s': unsupported value: Inf") } `[1:], expr, expr, expr, expr) case types.String: method = "String" default: panic(fmt.Errorf("unhandled basic kind %q", types.Typ[k])) } fmt.Fprintf(w, "w.%s(%s)\n", method, expr) } func generateStructValue(w *bytes.Buffer, expr string, exprType *types.Struct) { fmt.Fprintf(w, ` if err := %s.%s(w); err != nil && firstErr == nil { firstErr = err } `[1:], expr, marshalMethod) } func generateInterfaceValue(w *bytes.Buffer, expr string, exprType *types.Interface) { fmt.Fprintf(w, ` if err := fastjson.Marshal(w, %s); err != nil && firstErr == nil { firstErr = err } `[1:], expr) } func generateSliceValue(w *bytes.Buffer, expr string, exprType *types.Slice) { fmt.Fprintf(w, ` w.RawByte('[') for i, v := range %s { if i != 0 { w.RawByte(',') } `[1:], expr) generateValue(w, "v", exprType.Elem()) fmt.Fprintln(w, ` } w.RawByte(']')`[1:]) } func generateMapValue(w *bytes.Buffer, expr string, exprType *types.Map) { fmt.Fprintf(w, ` w.RawByte('{') { first := true for k, v := range %s { if first { first = false } else { w.RawByte(',') } `[1:], expr) generateValue(w, "k", exprType.Key()) fmt.Fprintln(w, "w.RawByte(':')") generateValue(w, "v", exprType.Elem()) fmt.Fprintln(w, ` } } w.RawByte('}')`[1:]) } func isNonZero(expr string, t types.Type) string { if named, ok := t.(*types.Named); ok { if hasMethod(named, isZeroMethod) { return fmt.Sprintf("!%s.%s()", expr, isZeroMethod) } t = named.Underlying() } zero := "nil" switch t := t.(type) { case *types.Pointer: case *types.Slice: case *types.Map: case *types.Interface: case *types.Basic: switch t.Kind() { case types.String: zero = `""` case types.Bool: zero = "false" default: zero = "0" } case *types.Alias: unaliasType := types.Unalias(t) isNonZero(expr, unaliasType) default: panic(fmt.Errorf("unhandled type %T", t)) } return fmt.Sprintf("%s != %s", expr, zero) } type structField struct { fieldName string jsonName string fieldType types.Type omitempty bool } func makeStructField(structType *types.Struct, i int) (structField, bool) { fieldVar := structType.Field(i) if !fieldVar.Exported() { return structField{}, false } var omitempty bool fieldName := fieldVar.Name() jsonName := fieldName fieldTag := reflect.StructTag(structType.Tag(i)) jsonTag, ok := fieldTag.Lookup("json") if ok { if jsonTag == "-" { return structField{}, false } name := jsonTag comma := strings.IndexRune(jsonTag, ',') if comma >= 0 { name = jsonTag[:comma] switch jsonTag[comma+1:] { case "": // special case for `json:"-,"` case "omitempty": omitempty = true default: panic("unhandled json tag: " + jsonTag) } } if name != "" { jsonName = name } } return structField{ fieldName: fieldName, jsonName: jsonName, fieldType: fieldVar.Type(), omitempty: omitempty, }, true } func hasMethod(named *types.Named, method string) bool { for i := named.NumMethods() - 1; i >= 0; i-- { if named.Method(i).Name() == method { return true } } return false }