cmd/generate-fastjson/main.go (432 lines of code) (raw):
// Copyright 2018 Elasticsearch BV
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/format"
"go/token"
"go/types"
"io"
"log"
"os"
"reflect"
"sort"
"strings"
"golang.org/x/tools/go/packages"
)
// Names shared between the generator and the code it emits.
const (
// fastjsonPath is the import path written into the import block of the
// generated file's header.
fastjsonPath = "go.elastic.co/fastjson"
// isZeroMethod is the method name looked up on named field types when
// building an omitempty condition; if the type defines it, the generated
// condition calls it instead of comparing against a zero literal.
isZeroMethod = "isZero"
// marshalMethod is the name of the generated method. Types that already
// define it are skipped by main, and values whose named type defines it
// are marshalled by calling it.
marshalMethod = "MarshalFastJSON"
)
// Command-line options; they are bound to flags in init.
var (
// force is set by -f and allows an existing output file to be removed
// before generation starts.
force bool
// outfile is set by -o and names the file the generated source is written
// to. The default "-" selects standard output.
outfile string
)
func init() {
flag.BoolVar(&force, "f", false, "remove the output file if it exists")
flag.StringVar(&outfile, "o", "-", "file to which output will be written")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s <package>\n", os.Args[0])
flag.PrintDefaults()
}
}
// main is the generator's entry point.
//
// It expects exactly one positional argument, a package pattern, which is
// loaded with type information. For every exported, non-alias named type
// declared in that package that does not already define MarshalFastJSON, a
// method is generated. The result is formatted with go/format and written to
// the destination selected by -o ("-" means standard output).
//
// Exit status is 1 for usage or load errors and 2 when the output file
// already exists and -f was not given; other failures go through log.Fatal.
func main() {
	flag.Parse()
	if flag.NArg() != 1 {
		flag.Usage()
		os.Exit(1)
	}

	// Refuse to clobber an existing output file unless -f was given. The file
	// is removed before the package is loaded.
	// NOTE(review): presumably so that stale generated methods in the target
	// package are not seen by the loader — confirm.
	if outfile != "-" {
		if _, err := os.Stat(outfile); err == nil {
			if !force {
				fmt.Fprintf(os.Stderr, "%s already exists, and -f not specified; aborting\n", outfile)
				os.Exit(2)
			}
			if err := os.Remove(outfile); err != nil {
				log.Fatal(err)
			}
		}
	}

	cfg := &packages.Config{
		Mode: packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo,
	}
	pkgs, err := packages.Load(cfg, flag.Arg(0))
	if err != nil {
		fmt.Fprintf(os.Stderr, "load: %v\n", err)
		os.Exit(1)
	}
	if packages.PrintErrors(pkgs) > 0 {
		os.Exit(1)
	}
	if len(pkgs) == 0 {
		// Guard the pkgs[0] access below instead of panicking with an
		// index-out-of-range error.
		fmt.Fprintf(os.Stderr, "load: no packages matched %q\n", flag.Arg(0))
		os.Exit(1)
	}
	pkg := pkgs[0]

	// File header. The blank line after the "Code generated" marker keeps the
	// marker from being attached to the package clause as its doc comment.
	var buf bytes.Buffer
	fmt.Fprintf(&buf, `
// Code generated by "generate-fastjson". DO NOT EDIT.

package %s
import (
"errors"
"math"
%q
)
var (
_ = errors.New
_ = math.IsNaN
)
`[1:], pkg.Types.Name(), fastjsonPath)

	var generated int
	for _, f := range pkg.Syntax {
		for _, decl := range f.Decls {
			genDecl, ok := decl.(*ast.GenDecl)
			if !ok || genDecl.Tok != token.TYPE {
				continue
			}
			for _, spec := range genDecl.Specs {
				typeSpec, ok := spec.(*ast.TypeSpec)
				if !ok {
					continue
				}
				typeName, ok := pkg.TypesInfo.Defs[typeSpec.Name].(*types.TypeName)
				if !ok || !typeName.Exported() {
					continue
				}
				// An alias declaration (type A = B) introduces no new type to
				// hang a method on: generating for it would either duplicate
				// the method of a local type or target a foreign type.
				if typeName.IsAlias() {
					continue
				}
				named, ok := typeName.Type().(*types.Named)
				if !ok || hasMethod(named, marshalMethod) {
					continue
				}
				generate(&buf, named)
				generated++
			}
		}
	}

	formatted, err := format.Source(buf.Bytes())
	if err != nil {
		// Dump the unformatted source to help locate the syntax error.
		fmt.Println(buf.String())
		log.Fatal(err)
	}

	var out io.Writer = os.Stdout
	var file *os.File
	if outfile != "-" {
		file, err = os.Create(outfile)
		if err != nil {
			log.Fatal(err)
		}
		out = file
	}
	if _, err := out.Write(formatted); err != nil {
		log.Fatal(err)
	}
	if file != nil {
		// Close explicitly: a deferred Close would be skipped by log.Fatal and
		// its error (which can report a failed write) would be lost.
		if err := file.Close(); err != nil {
			log.Fatal(err)
		}
		// The summary is only meaningful when a real file was written.
		fmt.Fprintf(os.Stderr, "generated %d methods in %q\n", generated, outfile)
	}
}
// generate appends to w the source of a MarshalFastJSON method for named.
//
// The underlying type of named must be a struct; anything else panics. Fields
// are taken from makeStructField and emitted with non-omitempty fields first
// and, within each group, ordered by JSON name. When every field is omitempty
// (and there is more than one), the generated code tracks at run time whether
// a field has been written yet so that commas are placed correctly; otherwise
// the separators are decided here, at generation time.
func generate(w *bytes.Buffer, named *types.Named) {
	structType, isStruct := named.Underlying().(*types.Struct)
	if !isStruct {
		panic(fmt.Errorf("unhandled type %T", named.Underlying()))
	}

	// The method body is rendered into a scratch buffer first, because the
	// wrapper written around it depends on what the body turned out to contain.
	dst, body := w, new(bytes.Buffer)
	defer func() {
		fmt.Fprintf(dst, "\nfunc (v *%s) %s(w *fastjson.Writer) error {\n", named.Obj().Name(), marshalMethod)

		// Only declare and return firstErr when the body mentions it. A field
		// or type name containing "firstErr" would trigger this needlessly,
		// which is harmless: the variable is then simply never assigned.
		mayError := strings.Contains(body.String(), "firstErr")
		if mayError {
			fmt.Fprintln(dst, "var firstErr error")
		}
		fmt.Fprintln(dst, `w.RawByte('{')`)
		body.WriteTo(dst)
		fmt.Fprintln(dst, `w.RawByte('}')`)

		result := "return nil"
		if mayError {
			result = "return firstErr"
		}
		fmt.Fprintln(dst, result)
		fmt.Fprintln(dst, "}")
	}()

	fields := make([]structField, 0, structType.NumFields())
	for i := 0; i < structType.NumFields(); i++ {
		if sf, ok := makeStructField(structType, i); ok {
			fields = append(fields, sf)
		}
	}

	// Non-omitempty fields sort first so that, whenever one exists, the first
	// key is known statically and no run-time "first" tracking is needed.
	sort.Slice(fields, func(i, j int) bool {
		a, b := fields[i], fields[j]
		if a.omitempty != b.omitempty {
			return !a.omitempty
		}
		return a.jsonName < b.jsonName
	})

	trackFirst := len(fields) > 1 && fields[0].omitempty
	if trackFirst {
		fmt.Fprintln(body, "first := true")
	}

	for i, f := range fields {
		fieldExpr := "v." + f.fieldName
		if f.omitempty {
			fmt.Fprintf(body, "if %s {", isNonZero(fieldExpr, f.fieldType))
		}

		// prefix is the quoted key with a leading comma; the comma is dropped
		// for whichever field ends up being written first.
		prefix := fmt.Sprintf(",%q:", f.jsonName)
		switch {
		case trackFirst:
			fmt.Fprintf(body, `
const prefix = %q
if first {
first = false
w.RawString(prefix[1:])
} else {
w.RawString(prefix)
}
`[1:], prefix)
		case i == 0:
			fmt.Fprintf(body, "w.RawString(%q)\n", prefix[1:])
		default:
			fmt.Fprintf(body, "w.RawString(%q)\n", prefix)
		}

		// A field that is always written but may be nil gets an explicit
		// "null" branch; omitempty fields are already guarded above.
		nillable := false
		if !f.omitempty {
			switch f.fieldType.Underlying().(type) {
			case *types.Pointer, *types.Slice, *types.Map, *types.Interface:
				nillable = true
			}
			if nillable {
				fmt.Fprintf(body, `
if v.%s == nil {
w.RawString("null")
} else {
`[1:], f.fieldName)
			}
		}

		generateValue(body, fieldExpr, f.fieldType)
		if f.omitempty || nillable {
			fmt.Fprintln(body, "}")
		}
	}
}
// generateValue appends to w the statements that marshal the value denoted by
// expr, whose static type is exprType.
//
// A named type that defines MarshalFastJSON is marshalled by calling that
// method, recording the first error in the generated firstErr variable. Every
// other type is dispatched on its underlying kind to the matching generate*
// helper. Types of a kind with no helper cause a panic.
func generateValue(w *bytes.Buffer, expr string, exprType types.Type) {
	// Look through alias declarations so that the checks below see the type
	// the alias stands for.
	exprType = types.Unalias(exprType)

	if named, isNamed := exprType.(*types.Named); isNamed {
		if hasMethod(named, marshalMethod) {
			const call = `
if err := %s.%s(w); err != nil && firstErr == nil {
firstErr = err
}
`
			fmt.Fprintf(w, call[1:], expr, marshalMethod)
			return
		}
		exprType = named.Underlying()
	}

	switch t := exprType.(type) {
	case *types.Basic:
		generateBasicValue(w, expr, t)
	case *types.Pointer:
		generatePointerValue(w, expr, t)
	case *types.Struct:
		generateStructValue(w, expr, t)
	case *types.Slice:
		generateSliceValue(w, expr, t)
	case *types.Map:
		generateMapValue(w, expr, t)
	case *types.Interface:
		generateInterfaceValue(w, expr, t)
	default:
		panic(fmt.Errorf("unhandled type %T", t))
	}
}
// generatePointerValue appends to w the statements that marshal the pointer
// value denoted by expr. The caller is expected to have emitted a nil check
// already, since basic values are written by dereferencing expr.
//
// Pointers to basic types are dereferenced and written as basic values;
// pointers to structs are written by calling the struct's MarshalFastJSON
// method on the pointer. Any other element type panics, and the message names
// both the unhandled element kind and the full pointer type so the offending
// field can be identified.
func generatePointerValue(w *bytes.Buffer, expr string, exprType *types.Pointer) {
	switch t := exprType.Elem().Underlying().(type) {
	case *types.Basic:
		generateBasicValue(w, "*"+expr, t)
	case *types.Struct:
		generateStructValue(w, expr, t)
	default:
		// Report the element type: exprType itself is always *types.Pointer,
		// which says nothing about what was unsupported.
		panic(fmt.Errorf("unhandled type %T (element of %s)", t, exprType))
	}
}
// generateBasicValue appends to w the statement(s) that write the basic-typed
// value denoted by expr using the fastjson.Writer method for its kind.
//
// Narrow signed and unsigned integers are converted to int64 and uint64
// respectively. Floating-point values are preceded by guards that make the
// generated method return an error for NaN and infinities. Kinds without a
// matching writer method cause a panic.
func generateBasicValue(w *bytes.Buffer, expr string, exprType *types.Basic) {
	// rejectNonFinite emits the NaN/Inf guards. operand is the expression
	// handed to the math functions, while the error text always quotes expr.
	rejectNonFinite := func(operand string) {
		fmt.Fprintf(w, `
if math.IsNaN(%s) {
return errors.New("json: '%s': unsupported value: NaN")
}
if math.IsInf(%s, 0) {
return errors.New("json: '%s': unsupported value: Inf")
}
`[1:], operand, expr, operand, expr)
	}

	method, arg := "", expr
	switch kind := exprType.Kind(); kind {
	case types.Bool:
		method = "Bool"
	case types.String:
		method = "String"
	case types.Int64:
		method = "Int64"
	case types.Int, types.Int8, types.Int16, types.Int32:
		method, arg = "Int64", "int64("+expr+")"
	case types.Uint64:
		method = "Uint64"
	case types.Uint, types.Uint8, types.Uint16, types.Uint32:
		method, arg = "Uint64", "uint64("+expr+")"
	case types.Float32:
		// math.IsNaN and math.IsInf take float64, so widen for the guards only.
		rejectNonFinite("float64(" + expr + ")")
		method = "Float32"
	case types.Float64:
		rejectNonFinite(expr)
		method = "Float64"
	default:
		panic(fmt.Errorf("unhandled basic kind %q", types.Typ[kind]))
	}
	fmt.Fprintf(w, "w.%s(%s)\n", method, arg)
}
// generateStructValue appends to w a call to the MarshalFastJSON method of the
// struct value denoted by expr. The generated code stores the call's error in
// firstErr unless an earlier error is already recorded. exprType is not
// inspected.
func generateStructValue(w *bytes.Buffer, expr string, exprType *types.Struct) {
	const call = `
if err := %s.%s(w); err != nil && firstErr == nil {
firstErr = err
}
`
	// Drop the template's leading newline, which exists only for layout.
	fmt.Fprintf(w, call[1:], expr, marshalMethod)
}
// generateInterfaceValue appends to w a call to fastjson.Marshal for the
// interface value denoted by expr. The generated code stores the call's error
// in firstErr unless an earlier error is already recorded. exprType is not
// inspected.
func generateInterfaceValue(w *bytes.Buffer, expr string, exprType *types.Interface) {
	const call = `
if err := fastjson.Marshal(w, %s); err != nil && firstErr == nil {
firstErr = err
}
`
	// Drop the template's leading newline, which exists only for layout.
	fmt.Fprintf(w, call[1:], expr)
}
// generateSliceValue appends to w a loop that writes the slice denoted by expr
// as a JSON array: an opening bracket, each element (as produced by
// generateValue for the element type) separated by commas, and a closing
// bracket. Inside the generated loop the element is named v and its index i.
func generateSliceValue(w *bytes.Buffer, expr string, exprType *types.Slice) {
	const (
		loopOpen = `
w.RawByte('[')
for i, v := range %s {
if i != 0 {
w.RawByte(',')
}
`
		loopClose = `
}
w.RawByte(']')`
	)
	// The leading newline of each template exists only for layout.
	fmt.Fprintf(w, loopOpen[1:], expr)
	generateValue(w, "v", exprType.Elem())
	w.WriteString(loopClose[1:] + "\n")
}
// generateMapValue appends to w a loop that writes the map denoted by expr as
// a JSON object. The generated code ranges over the map with variables k and
// v, writes a comma before every entry but the first, and emits key, ':' and
// value, where key and value are produced by generateValue for the map's key
// and element types. The loop sits in its own block so its "first" variable
// does not clash with one in the enclosing generated code.
//
// NOTE(review): keys are emitted exactly like values, and entries follow Go's
// map iteration order — confirm that callers only use string-keyed maps and
// do not rely on key order.
func generateMapValue(w *bytes.Buffer, expr string, exprType *types.Map) {
	const (
		loopOpen = `
w.RawByte('{')
{
first := true
for k, v := range %s {
if first {
first = false
} else {
w.RawByte(',')
}
`
		loopClose = `
}
}
w.RawByte('}')`
	)
	// The leading newline of each template exists only for layout.
	fmt.Fprintf(w, loopOpen[1:], expr)
	generateValue(w, "k", exprType.Key())
	w.WriteString("w.RawByte(':')\n")
	generateValue(w, "v", exprType.Elem())
	w.WriteString(loopClose[1:] + "\n")
}
// isNonZero returns a Go boolean expression, as source text, that is true when
// the value denoted by expr (of type t) should not be omitted by omitempty.
//
// If t is a named type defining an isZero method, the expression is the
// negated call of that method. Otherwise the value is compared against the
// zero literal of its underlying kind: nil for pointers, slices, maps and
// interfaces, "" for strings, false for booleans and 0 for other basic types.
// Aliases are resolved to the type they stand for before any of these checks.
// Any other kind of type causes a panic.
func isNonZero(expr string, t types.Type) string {
	// Resolve aliases first. Without this an alias would be compared against
	// nil regardless of what it stands for, yielding code that does not
	// compile for aliases of basic types and ignoring isZero on aliased types.
	t = types.Unalias(t)

	if named, ok := t.(*types.Named); ok {
		if hasMethod(named, isZeroMethod) {
			return fmt.Sprintf("!%s.%s()", expr, isZeroMethod)
		}
		t = named.Underlying()
	}

	zero := "nil"
	switch t := t.(type) {
	case *types.Pointer, *types.Slice, *types.Map, *types.Interface:
		// The zero value is nil.
	case *types.Basic:
		switch t.Kind() {
		case types.String:
			zero = `""`
		case types.Bool:
			zero = "false"
		default:
			zero = "0"
		}
	default:
		panic(fmt.Errorf("unhandled type %T", t))
	}
	return fmt.Sprintf("%s != %s", expr, zero)
}
// structField describes one struct field for which marshalling code is
// generated. Values are produced by makeStructField.
type structField struct {
// fieldName is the Go name of the field; generated code accesses it as
// "v.<fieldName>".
fieldName string
// jsonName is the JSON object key: the name given in the field's `json`
// tag, or fieldName when the tag is absent or names nothing.
jsonName string
// fieldType is the declared type of the field.
fieldType types.Type
// omitempty reports whether the `json` tag carries the "omitempty" option.
omitempty bool
}
// makeStructField describes the i'th field of structType for code generation.
//
// The second result is false, and the first is the zero value, when the field
// must be skipped: it is unexported or its tag is exactly `json:"-"`. The JSON
// name defaults to the Go field name and is overridden by a non-empty name in
// the `json` tag. The only tag option understood is "omitempty"; an empty
// option list (as in `json:"-,"`, which names the key "-") is accepted, and
// anything else panics.
func makeStructField(structType *types.Struct, i int) (structField, bool) {
	field := structType.Field(i)
	if !field.Exported() {
		return structField{}, false
	}

	result := structField{
		fieldName: field.Name(),
		jsonName:  field.Name(),
		fieldType: field.Type(),
	}

	tag, tagged := reflect.StructTag(structType.Tag(i)).Lookup("json")
	if !tagged {
		return result, true
	}
	if tag == "-" {
		return structField{}, false
	}

	// Split "name,options" at the first comma.
	name, options, hasOptions := strings.Cut(tag, ",")
	if hasOptions {
		switch options {
		case "":
			// Trailing comma with no options, e.g. `json:"-,"`.
		case "omitempty":
			result.omitempty = true
		default:
			panic("unhandled json tag: " + tag)
		}
	}
	if name != "" {
		result.jsonName = name
	}
	return result, true
}
// hasMethod reports whether a method called method is declared explicitly on
// named. Only the methods enumerated by named itself are considered.
func hasMethod(named *types.Named, method string) bool {
	for i, n := 0, named.NumMethods(); i < n; i++ {
		if named.Method(i).Name() == method {
			return true
		}
	}
	return false
}