thrift/thrift-gen/wrap.go (183 lines of code) (raw):
package main
import (
"fmt"
"sort"
"strings"
"github.com/samuel/go-thrift/parser"
)
type byServiceName []*Service
func (l byServiceName) Len() int { return len(l) }
func (l byServiceName) Less(i, j int) bool { return l[i].Service.Name < l[j].Service.Name }
func (l byServiceName) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
func wrapServices(v *parser.Thrift, state *State) ([]*Service, error) {
var services []*Service
for _, s := range v.Services {
if err := Validate(s); err != nil {
return nil, err
}
services = append(services, &Service{
Service: s,
state: state,
})
}
// Have a stable ordering for services so code generation is consistent.
sort.Sort(byServiceName(services))
return services, nil
}
// Service is a wrapper for parser.Service.
type Service struct {
*parser.Service
state *State
// ExtendsService and ExtendsPrefix are set in `setExtends`.
ExtendsService *Service
ExtendsPrefix string
// methods is a cache of all methods.
methods []*Method
// inheritedMethods is a list of inherited method names.
inheritedMethods []string
}
// ThriftName returns the thrift identifier for this service.
func (s *Service) ThriftName() string {
return s.Service.Name
}
// Interface returns the name of the interface representing the service.
func (s *Service) Interface() string {
return "TChan" + goPublicName(s.Name)
}
// ClientStruct returns the name of the unexported struct that satisfies the interface as a client.
func (s *Service) ClientStruct() string {
return "tchan" + goPublicName(s.Name) + "Client"
}
// ClientConstructor returns the name of the constructor used to create a client.
func (s *Service) ClientConstructor() string {
return "NewTChan" + goPublicName(s.Name) + "Client"
}
// InheritedClientConstructor returns the name of the constructor used by the generated code
// for inherited services. This allows the parent service to set the service name that should
// be used.
func (s *Service) InheritedClientConstructor() string {
return "NewTChan" + goPublicName(s.Name) + "InheritedClient"
}
// ServerStruct returns the name of the unexported struct that satisfies TChanServer.
func (s *Service) ServerStruct() string {
return "tchan" + goPublicName(s.Name) + "Server"
}
// ServerConstructor returns the name of the constructor used to create the TChanServer interface.
func (s *Service) ServerConstructor() string {
return "NewTChan" + goPublicName(s.Name) + "Server"
}
// HasExtends returns whether this service extends another service.
func (s *Service) HasExtends() bool {
return s.ExtendsService != nil
}
// ExtendsServicePrefix returns a package selector (if any) for the extended service.
func (s *Service) ExtendsServicePrefix() string {
if dotIndex := strings.Index(s.Extends, "."); dotIndex > 0 {
return s.ExtendsPrefix
}
return ""
}
type byMethodName []*Method
func (l byMethodName) Len() int { return len(l) }
func (l byMethodName) Less(i, j int) bool { return l[i].Method.Name < l[j].Method.Name }
func (l byMethodName) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
// Methods returns the methods on this service, not including methods from inherited services.
func (s *Service) Methods() []*Method {
if s.methods != nil {
return s.methods
}
for _, m := range s.Service.Methods {
s.methods = append(s.methods, &Method{m, s, s.state})
}
sort.Sort(byMethodName(s.methods))
return s.methods
}
// InheritedMethods returns names for inherited methods on this service.
func (s *Service) InheritedMethods() []string {
if s.inheritedMethods != nil {
return s.inheritedMethods
}
for svc := s.ExtendsService; svc != nil; svc = svc.ExtendsService {
for m := range svc.Service.Methods {
s.inheritedMethods = append(s.inheritedMethods, m)
}
}
sort.Strings(s.inheritedMethods)
return s.inheritedMethods
}
// Method is a wrapper for parser.Method.
type Method struct {
*parser.Method
service *Service
state *State
}
// ThriftName returns the thrift identifier for this function.
func (m *Method) ThriftName() string {
return m.Method.Name
}
// Name returns the go method name.
func (m *Method) Name() string {
return goPublicName(m.Method.Name)
}
// HandleFunc is the go method name for the handle function which decodes the payload.
func (m *Method) HandleFunc() string {
return "handle" + goPublicName(m.Method.Name)
}
// Arguments returns the argument declarations for this method.
func (m *Method) Arguments() []*Field {
var args []*Field
for _, f := range m.Method.Arguments {
args = append(args, &Field{f, m.state})
}
return args
}
// Exceptions returns the exceptions that this method may return.
func (m *Method) Exceptions() []*Field {
var args []*Field
for _, f := range m.Method.Exceptions {
args = append(args, &Field{f, m.state})
}
return args
}
// HasReturn returns false if this method is declared as void in the Thrift file.
func (m *Method) HasReturn() bool {
return m.Method.ReturnType != nil
}
// HasExceptions returns true if this method has
func (m *Method) HasExceptions() bool {
return len(m.Method.Exceptions) > 0
}
func (m *Method) argResPrefix() string {
return goPublicName(m.service.Name) + m.Name()
}
// ArgsType returns the Go name for the struct used to encode the method's arguments.
func (m *Method) ArgsType() string {
return m.argResPrefix() + "Args"
}
// ResultType returns the Go name for the struct used to encode the method's result.
func (m *Method) ResultType() string {
return m.argResPrefix() + "Result"
}
// ArgList returns the argument list for the function.
func (m *Method) ArgList() string {
args := []string{"ctx " + contextType()}
for _, arg := range m.Arguments() {
args = append(args, arg.Declaration())
}
return strings.Join(args, ", ")
}
// CallList creates the call to a function satisfying Interface from an Args struct.
func (m *Method) CallList(reqStruct string) string {
args := []string{"ctx"}
for _, arg := range m.Arguments() {
args = append(args, reqStruct+"."+arg.ArgStructName())
}
return strings.Join(args, ", ")
}
// RetType returns the go return type of the method.
func (m *Method) RetType() string {
if !m.HasReturn() {
return "error"
}
return fmt.Sprintf("(%v, %v)", m.state.goType(m.Method.ReturnType), "error")
}
// WrapResult wraps the result variable before being used in the result struct.
func (m *Method) WrapResult(respVar string) string {
if !m.HasReturn() {
panic("cannot wrap a return when there is no return mode")
}
if m.state.isResultPointer(m.ReturnType) {
return respVar
}
return "&" + respVar
}
// ReturnWith takes the result name and the error name, and generates the return expression.
func (m *Method) ReturnWith(respName string, errName string) string {
if !m.HasReturn() {
return errName
}
return fmt.Sprintf("%v, %v", respName, errName)
}
// Field is a wrapper for parser.Field.
type Field struct {
*parser.Field
state *State
}
// Declaration returns the declaration for this field.
func (a *Field) Declaration() string {
return fmt.Sprintf("%s %s", a.Name(), a.ArgType())
}
// Name returns the field name.
func (a *Field) Name() string {
return goName(a.Field.Name)
}
// ArgType returns the Go type for the given field.
func (a *Field) ArgType() string {
return a.state.goType(a.Type)
}
// ArgStructName returns the name of this field in the Args struct generated by thrift.
func (a *Field) ArgStructName() string {
return goPublicFieldName(a.Field.Name)
}