codegen/thrift.go (186 lines of code) (raw):
// Copyright (c) 2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package codegen
import (
"fmt"
"github.com/pkg/errors"
"go.uber.org/thriftrw/compile"
)
// GoType returns the Go type expression used in generated code for the given
// compiled thrift type.
//
// Primitive types map to Go builtins (binary becomes []byte). Element, key and
// value types of containers come from GoReferenceType, so struct elements are
// rendered as pointers. A map whose key type is not hashable (see isHashable)
// is rendered as a slice of {Key, Value} structs. A set is rendered as a
// slice when its element type is not hashable or when it is annotated
// go.type = "slice", and as map[T]struct{} otherwise. Enums, structs and
// typedefs are delegated to GoCustomType. Any other spec kind panics.
func GoType(p PackageNameResolver, spec compile.TypeSpec) (string, error) {
	switch t := spec.(type) {
	case *compile.StringSpec:
		return "string", nil
	case *compile.BinarySpec:
		return "[]byte", nil
	case *compile.BoolSpec:
		return "bool", nil
	case *compile.I8Spec:
		return "int8", nil
	case *compile.I16Spec:
		return "int16", nil
	case *compile.I32Spec:
		return "int32", nil
	case *compile.I64Spec:
		return "int64", nil
	case *compile.DoubleSpec:
		return "float64", nil
	case *compile.EnumSpec, *compile.StructSpec, *compile.TypedefSpec:
		return GoCustomType(p, spec)
	case *compile.ListSpec:
		elem, err := GoReferenceType(p, t.ValueSpec)
		if err != nil {
			return "", err
		}
		return "[]" + elem, nil
	case *compile.SetSpec:
		elem, err := GoReferenceType(p, t.ValueSpec)
		if err != nil {
			return "", err
		}
		// Only hashable elements can be map keys, and the go.type=slice
		// annotation explicitly opts out of the map representation.
		if isHashable(t.ValueSpec) && !isSliceSetType(t) {
			return "map[" + elem + "]struct{}", nil
		}
		return "[]" + elem, nil
	case *compile.MapSpec:
		keyType, err := GoReferenceType(p, t.KeySpec)
		if err != nil {
			return "", err
		}
		valueType, err := GoReferenceType(p, t.ValueSpec)
		if err != nil {
			return "", err
		}
		if isHashable(t.KeySpec) {
			return "map[" + keyType + "]" + valueType, nil
		}
		// Non-hashable keys cannot be Go map keys; use a key/value pair slice.
		return "[]struct{Key " + keyType + "; Value " + valueType + "}", nil
	}
	panic(fmt.Sprintf("Unknown type (%T) for %s", spec, spec.ThriftName()))
}
// GoReferenceType returns the Go type used when the given thrift type is held
// by reference, e.g. as a container element.
//
// It is GoType's result, prefixed with "*" when the type resolves (through
// typedefs) to a struct. Other types, including slices and maps, are returned
// as GoType produced them. Errors from GoType are returned unchanged.
func GoReferenceType(p PackageNameResolver, spec compile.TypeSpec) (string, error) {
	goType, err := GoType(p, spec)
	switch {
	case err != nil:
		return "", err
	case IsStructType(spec):
		return "*" + goType, nil
	default:
		return goType, nil
	}
}
// GoCustomType returns the package-qualified Go name ("pkg.Name") for a
// user-defined thrift type.
//
// The package is looked up from the thrift file that declares the type via
// p.TypePackageName, and the type name is the PascalCase form of its thrift
// name. It fails for specs with no declaring thrift file (native types), or
// when the package lookup fails; the latter error is wrapped with context.
func GoCustomType(p PackageNameResolver, spec compile.TypeSpec) (string, error) {
	thriftFile := spec.ThriftFile()
	if len(thriftFile) == 0 {
		return "", fmt.Errorf("GoCustomType called with native type (%T) %v", spec, spec)
	}

	pkgName, err := p.TypePackageName(thriftFile)
	if err != nil {
		return "", errors.Wrapf(err, "failed to get package for custom type (%T) %v", spec, spec)
	}

	typeName := PascalCase(spec.ThriftName())
	return pkgName + "." + typeName, nil
}
// IsStructType reports whether spec, once any typedefs are resolved, is a
// thrift struct (*compile.StructSpec).
func IsStructType(spec compile.TypeSpec) bool {
	switch compile.RootTypeSpec(spec).(type) {
	case *compile.StructSpec:
		return true
	default:
		return false
	}
}
// isHashable reports whether thriftrw considers values of the given type
// hashable, i.e. usable as Go map keys in generated code.
//
// Typedefs are resolved first. The resolved type is hashable only if it is
// bool, i8, i16, i32, i64, double, string or an enum. binary is deliberately
// excluded because it is represented as []byte in Go, which cannot be a map
// key.
func isHashable(spec compile.TypeSpec) bool {
	switch compile.RootTypeSpec(spec).(type) {
	case *compile.BoolSpec,
		*compile.I8Spec, *compile.I16Spec, *compile.I32Spec, *compile.I64Spec,
		*compile.DoubleSpec,
		*compile.StringSpec,
		*compile.EnumSpec:
		return true
	}
	return false
}
// isSliceSetType reports whether the given thrift type resolves (through any
// typedefs) to a set annotated with go.type = "slice", meaning the set is
// generated as a Go slice rather than a map.
//
// Non-set types always yield false.
func isSliceSetType(spec compile.TypeSpec) bool {
	spec = compile.RootTypeSpec(spec)
	if _, isSet := spec.(*compile.SetSpec); !isSet {
		return false
	}
	// Indexing a nil annotation map yields "", so unannotated sets are false.
	return spec.ThriftAnnotations()["go.type"] == "slice"
}
// pointerMethodType returns the helper name used when allocating a pointer to
// a value of the given primitive or enum thrift type, e.g. "Int32" for i32.
// Enums also map to "Int32".
//
// Typedefs are NOT resolved here, so a *compile.TypedefSpec (or any other
// non-primitive spec, including binary) panics. Callers presumably pass an
// already-resolved root type — confirm at call sites.
func pointerMethodType(typeSpec compile.TypeSpec) string {
	switch typeSpec.(type) {
	case *compile.BoolSpec:
		return "Bool"
	case *compile.I8Spec:
		return "Int8"
	case *compile.I16Spec:
		return "Int16"
	case *compile.I32Spec, *compile.EnumSpec:
		return "Int32"
	case *compile.I64Spec:
		return "Int64"
	case *compile.DoubleSpec:
		return "Float64"
	case *compile.StringSpec:
		return "String"
	}
	panic(fmt.Sprintf(
		"Unknown type (%T) for %s for allocating a pointer",
		typeSpec, typeSpec.ThriftName(),
	))
}
// walkFieldVisitor is the callback walkFieldGroups invokes for each field it
// visits. goPrefix and thriftPrefix describe the enclosing struct fields:
// each nesting level appends "." plus the field name (PascalCase for Go, raw
// for thrift), so they are empty for top-level fields and otherwise start with
// a dot. Returning true aborts the walk.
type walkFieldVisitor func(goPrefix, thriftPrefix string, field *compile.FieldSpec) bool
// walkFieldGroups calls visitField for every field in fields, recursing into
// fields whose type resolves to a struct. Top-level fields are visited with
// empty Go and thrift prefixes, and each field spec is visited at most once
// per walk. It returns true if visitField stopped the walk early by returning
// true, and false otherwise.
func walkFieldGroups(
	fields compile.FieldGroup,
	visitField walkFieldVisitor,
) bool {
	visited := make(map[*compile.FieldSpec]bool)
	return walkFieldGroupsInternal("", "", fields, visitField, visited)
}
// walkFieldGroupsInternal performs the recursive walk behind walkFieldGroups.
//
// goPrefix and thriftPrefix are the paths of the struct fields enclosing
// fields. Each nested struct field appends "." plus its PascalCase (Go) or raw
// (thrift) name. visitField is called once per field before the field's type
// is inspected. When it returns true the walk stops immediately and true is
// propagated to the caller. The function returns false once every reachable
// field has been visited.
//
// Only fields whose type resolves (through typedefs) to a struct are
// descended into. Lists, sets and maps are not walked, even when their
// elements are structs.
//
// seen is keyed by *compile.FieldSpec and shared across the whole walk, so a
// field spec is visited at most once; this prevents infinite recursion on
// self-referential structs. NOTE(review): assuming compile shares FieldSpec
// pointers between references to the same struct type, a struct reachable
// through two different fields only has its fields visited under the first
// path encountered — confirm callers rely on this.
func walkFieldGroupsInternal(
	goPrefix string,
	thriftPrefix string,
	fields compile.FieldGroup,
	visitField walkFieldVisitor,
	seen map[*compile.FieldSpec]bool,
) bool {
	for _, field := range fields {
		if seen[field] {
			// Skip this field if we have already considered it.
			continue
		}
		seen[field] = true

		if visitField(goPrefix, thriftPrefix, field) {
			return true
		}

		switch t := compile.RootTypeSpec(field.Type).(type) {
		case *compile.BinarySpec, *compile.StringSpec, *compile.BoolSpec,
			*compile.DoubleSpec, *compile.I8Spec, *compile.I16Spec,
			*compile.I32Spec, *compile.I64Spec, *compile.EnumSpec,
			*compile.ListSpec, *compile.SetSpec:
			// Not descended into: primitives, enums, lists and sets.
		case *compile.StructSpec:
			if walkFieldGroupsInternal(
				goPrefix+"."+PascalCase(field.Name),
				thriftPrefix+"."+field.Name,
				t.Fields,
				visitField,
				seen,
			) {
				return true
			}
		case *compile.MapSpec:
			// TODO: implement
		default:
			// Include the concrete type and field name, like the other
			// "unknown type" panics in this file, so the failure is actionable.
			panic(fmt.Sprintf("unknown Spec (%T) for field %s", t, field.Name))
		}
	}
	return false
}