thrift/thrift-gen/typestate.go (90 lines of code) (raw):
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package main
import (
"strings"
"github.com/samuel/go-thrift/parser"
)
// State is global Thrift state for a file with type information.
// It is built by newState; the includes field starts out nil and is filled in
// afterwards by setIncludes.
type State struct {
// typedefs is a map from a typedef name to the underlying type.
// Enum names are stored here as well, each mapped to an "i64" type.
typedefs map[string]*parser.Type
// includes is a map from Thrift base name to the include.
// When resolving a qualified type name, it is looked up using the part of
// the name before the first ".".
includes map[string]*Include
// all is the collection of parse states shared between files. It is
// indexed by an include's file field to reach the State of the included
// file, which is how types from other files are resolved.
all map[string]parseState
}
// newState builds the State for one parsed Thrift file.
//
// Every typedef in v is recorded against its underlying type, and every enum
// in v is recorded as if it were a typedef to i64 (enums are registered after
// typedefs, so an enum wins if the two share a name). The includes field is
// left nil here; all is stored as given.
func newState(v *parser.Thrift, all map[string]parseState) *State {
	// Enums are typedefs to an int64; a single shared type value serves all of them.
	enumType := &parser.Type{Name: "i64"}

	known := make(map[string]*parser.Type)
	for name, def := range v.Typedefs {
		known[name] = def.Type
	}
	for name := range v.Enums {
		known[name] = enumType
	}

	return &State{
		typedefs: known,
		includes: nil,
		all:      all,
	}
}
// setIncludes fills in the includes of the global State of every entry in
// all, computing each one from that entry's AST with createIncludes.
func setIncludes(all map[string]parseState) {
	for _, file := range all {
		resolved := createIncludes(file.ast, all)
		// file is a copy of the map value, but global is shared through a
		// pointer, so this assignment is visible through the map.
		file.global.includes = resolved
	}
}
// isBasicType reports whether thriftType has an entry in the thriftToGo
// table. It does not follow typedefs or includes.
func (s *State) isBasicType(thriftType string) bool {
	_, isBasic := thriftToGo[thriftType]
	return isBasic
}
// rootType follows typedefs, crossing into included files where needed, and
// returns the first type that is not itself a typedef. If thriftType is not a
// typedef, thriftType itself is returned.
func (s *State) rootType(thriftType *parser.Type) *parser.Type {
	// A qualified name belongs to an included file, so that file's State
	// continues the resolution using the unqualified name.
	includedState, unqualified, include := s.checkInclude(thriftType)
	if include != nil {
		return includedState.rootType(unqualified)
	}

	underlying, isTypedef := s.typedefs[thriftType.Name]
	if !isTypedef {
		return thriftType
	}
	return s.rootType(underlying)
}
// checkInclude checks whether thriftType refers to a type from an included
// file, i.e. whether its name has the form "<include>.<name>".
//
// If so, it returns the State of the included file, a copy of thriftType whose
// Name has the include qualifier stripped, and the matching Include.
//
// It returns (nil, nil, nil) when the name is unqualified, when the qualifier
// does not match any known include (which is also the case before the
// includes have been set), or when no State is available for the included
// file. Callers detect the "not included" case by checking the returned
// Include for nil.
func (s *State) checkInclude(thriftType *parser.Type) (*State, *parser.Type, *Include) {
	parts := strings.SplitN(thriftType.Name, ".", 2)
	if len(parts) < 2 {
		return nil, nil, nil
	}

	// A lookup miss yields a nil *Include; dereferencing it would panic, so
	// treat an unknown qualifier as "not an included type".
	include := s.includes[parts[0]]
	if include == nil {
		return nil, nil, nil
	}

	// Never hand back a non-nil include together with a nil State: callers
	// invoke methods on the returned State as soon as include is non-nil.
	state, ok := s.all[include.file]
	if !ok || state.global == nil {
		return nil, nil, nil
	}

	// Copy the type so the caller's value keeps its qualified name.
	newType := *thriftType
	newType.Name = parts[1]
	return state.global, &newType, include
}
// isResultPointer reports whether a result of the given type is a pointer.
// That is the case exactly when the type's root type (after following
// typedefs) has no entry in thriftToGo.
func (s *State) isResultPointer(thriftType *parser.Type) bool {
	root := s.rootType(thriftType)
	if _, isBasic := thriftToGo[root.Name]; isBasic {
		return false
	}
	return true
}
// goType returns the Go type name for the given thrift type, without any
// package prefix. It is shorthand for goTypePrefix with an empty prefix.
func (s *State) goType(thriftType *parser.Type) string {
	const noPrefix = ""
	return s.goTypePrefix(noPrefix, thriftType)
}
// goTypePrefix returns the Go type name for the given thrift type. prefix is
// prepended to the generated name of user-defined types; it is not applied to
// binary, container or thriftToGo types.
//
// The result is pointer-qualified ("*Name") only for a user-defined type that
// is not a typedef and whose root type is neither a thriftToGo type nor a
// container.
func (s *State) goTypePrefix(prefix string, thriftType *parser.Type) string {
	// Binary and container types map onto Go built-ins; element types are
	// resolved without a prefix.
	switch thriftType.Name {
	case "binary":
		return "[]byte"
	case "list":
		elem := s.goType(thriftType.ValueType)
		return "[]" + elem
	case "set":
		elem := s.goType(thriftType.ValueType)
		return "map[" + elem + "]bool"
	case "map":
		key := s.goType(thriftType.KeyType)
		value := s.goType(thriftType.ValueType)
		return "map[" + key + "]" + value
	}

	// A type from an included file is resolved by that file's State, using
	// the include's package name (plus ".") as the prefix in place of the one
	// passed in.
	includedState, unqualified, include := s.checkInclude(thriftType)
	if include != nil {
		return includedState.goTypePrefix(include.Package()+".", unqualified)
	}

	// Types with a direct Go equivalent are used as-is.
	if direct, ok := thriftToGo[thriftType.Name]; ok {
		return direct
	}

	qualified := prefix + goPublicFieldName(thriftType.Name)

	root := s.rootType(thriftType)
	_, rootIsDirect := thriftToGo[root.Name]
	switch {
	case rootIsDirect:
		// Typedef that resolves to a direct Go type.
		return qualified
	case root.Name == "list", root.Name == "set", root.Name == "map":
		// Typedef that resolves to a container.
		return qualified
	case root != thriftType:
		// Typedef to another struct: the typedef itself is defined as a
		// pointer, so no pointer is added here.
		return qualified
	}

	// Not a typedef and not a basic type: refer to it through a pointer.
	return "*" + qualified
}