thrift/thrift-gen/main.go (170 lines of code) (raw):
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
// thrift-gen generates code for Thrift services that can be used with the
// github.com/uber/tchannel-go/thrift package. The code generated by thrift-gen
// relies on the Apache Thrift generated code for serialization/deserialization,
// and should be placed in the same package as that generated code.
package main
import (
"flag"
"fmt"
"log"
"os"
"path/filepath"
"regexp"
"strings"
"text/template"
"github.com/samuel/go-thrift/parser"
)
// tchannelThriftImport is the import path of the TChannel Thrift package. It is
// handed to templates as Imports.TChannel when code is generated.
const tchannelThriftImport = "github.com/uber/tchannel-go/thrift"
// Command-line flags (registered on the default flag set) and a package-level
// regular expression.
var (
// generateThrift, when true, makes processFile invoke runThrift on the input
// file before the TChannel code is generated.
generateThrift = flag.Bool("generateThrift", false, "Whether to generate all Thrift go code")
// inputFile is the Thrift file to process; main exits if it is empty.
inputFile = flag.String("inputFile", "", "The .thrift file to generate a client for")
// outputDir is the root directory for generated code; processFile creates it
// and writes each file under <outputDir>/<package>/.
outputDir = flag.String("outputDir", "gen-go", "The output directory to generate go code to.")
// skipTChannel, when true, makes parseTemplates leave out the built-in
// "tchan" template.
skipTChannel = flag.Bool("skipTChannel", false, "Whether to skip the TChannel template")
// templateFiles collects the values of every -template occurrence; each one
// is parsed as an additional template by parseTemplates.
templateFiles = NewStringSliceFlag("template", "Template file to compile code from")
// nlSpaceNL matches a newline, one or more spaces or tabs, and another
// newline, i.e. a line containing only horizontal whitespace.
// NOTE(review): not referenced in this part of the file; presumably used to
// tidy generated output elsewhere — confirm.
nlSpaceNL = regexp.MustCompile(`\n[ \t]+\n`)
)
// TemplateData is the data passed to the template that generates code.
// generateCode builds one value per (parsed file, template) pair.
type TemplateData struct {
// Package is the Go package name the code is generated for.
Package string
// AST is the parsed Thrift file the code is generated from.
AST *parser.Thrift
// Services are the wrapped services of the parsed file.
Services []*Service
// Includes is taken from the includes recorded on the file's global State.
Includes map[string]*Include
// Imports holds the import paths made available to the template.
Imports imports
// global should not be directly exported to the template, but functions on
// global can be exposed to templates.
global *State
}
// imports holds the import paths exposed to templates through
// TemplateData.Imports.
type imports struct {
// Thrift is filled from the value of apacheThriftImport in generateCode.
Thrift string
// TChannel is filled from the tchannelThriftImport constant in generateCode.
TChannel string
}
// main parses the command-line flags, requires -inputFile to be set, and runs
// code generation with the options collected from the flags. Any failure is
// logged and terminates the process via the log package's Fatal functions.
func main() {
	flag.Parse()

	input := *inputFile
	if input == "" {
		log.Fatal("Please specify an inputFile")
	}

	err := processFile(processOptions{
		InputFile:      input,
		GenerateThrift: *generateThrift,
		OutputDir:      *outputDir,
		SkipTChannel:   *skipTChannel,
		TemplateFiles:  *templateFiles,
	})
	if err != nil {
		log.Fatal(err)
	}
}
// processOptions carries the settings for a single processFile run. main fills
// it from the command-line flags.
type processOptions struct {
// InputFile is the Thrift file to parse.
InputFile string
// GenerateThrift makes processFile call runThrift before generating code.
GenerateThrift bool
// OutputDir is the root directory generated files are written under.
OutputDir string
// SkipTChannel leaves out the built-in "tchan" template.
SkipTChannel bool
// TemplateFiles lists additional template files to parse and render.
TemplateFiles []string
}
// processFile runs the whole generation pipeline for opts.InputFile:
//
//  1. create opts.OutputDir (and parents) if needed,
//  2. optionally run the Thrift compiler (opts.GenerateThrift),
//  3. parse the input file,
//  4. parse the templates to render,
//  5. render every template for every parsed file into
//     <OutputDir>/<package>/<template output file>.
//
// It returns the first error encountered; errors from the first four steps are
// wrapped with a message naming the failing step.
func processFile(opts processOptions) error {
	mkdirErr := os.MkdirAll(opts.OutputDir, 0770)
	if mkdirErr != nil {
		return fmt.Errorf("failed to create output directory %q: %v", opts.OutputDir, mkdirErr)
	}

	if opts.GenerateThrift {
		if err := runThrift(opts.InputFile, opts.OutputDir); err != nil {
			return fmt.Errorf("failed to run thrift for file %q: %v", opts.InputFile, err)
		}
	}

	parsedFiles, parseErr := parseFile(opts.InputFile)
	if parseErr != nil {
		return fmt.Errorf("failed to parse file %q: %v", opts.InputFile, parseErr)
	}

	templates, tmplErr := parseTemplates(opts.SkipTChannel, opts.TemplateFiles)
	if tmplErr != nil {
		return fmt.Errorf("failed to parse templates: %v", tmplErr)
	}

	for path, parsed := range parsedFiles {
		pkg := getNamespace(path, parsed.ast)
		pkgDir := filepath.Join(opts.OutputDir, pkg)
		for _, tmpl := range templates {
			// Each template decides its own output file name for the package.
			dest := filepath.Join(pkgDir, tmpl.outputFile(pkg))
			genErr := generateCode(dest, tmpl, pkg, parsed)
			if genErr != nil {
				return genErr
			}
		}
	}
	return nil
}
// parseState is the per-file result of parseFile. parseFile constructs it with
// a positional literal, so the field order below must be kept in sync with it.
type parseState struct {
// ast is the parsed Thrift file.
ast *parser.Thrift
// namespace is the Go package name computed by getNamespace for the file.
namespace string
// global is the State created by newState for the file.
global *State
// services are the services of the file as wrapped by wrapServices.
services []*Service
}
// parseTemplates builds the list of templates to render. Unless skipTChannel
// is set, the built-in "tchan" template comes first (a failure to parse it
// panics via template.Must); it is followed by one template per entry of
// templateFiles, in order. The first template file that fails to parse aborts
// the call and its error is returned.
func parseTemplates(skipTChannel bool, templateFiles []string) ([]*Template, error) {
	var result []*Template

	if !skipTChannel {
		builtin := &Template{
			name:     "tchan",
			template: template.Must(parseTemplate(tchannelTmpl)),
		}
		result = append(result, builtin)
	}

	for i := range templateFiles {
		parsed, err := parseTemplateFile(templateFiles[i])
		if err != nil {
			return nil, err
		}
		result = append(result, parsed)
	}

	return result, nil
}
// parseFile parses inputFile with the go-thrift parser and builds a parseState
// for every file in the parse result, keyed by file name. Once all files are
// recorded, includes and then service inheritance ("extends") are resolved
// across the whole set.
//
// A parse or service-wrapping failure returns a nil map. An error from
// setExtends is returned together with the populated map.
func parseFile(inputFile string) (map[string]parseState, error) {
	p := &parser.Parser{}
	files, _, err := p.ParseFile(inputFile)
	if err != nil {
		return nil, err
	}

	result := make(map[string]parseState)
	for name, ast := range files {
		// newState receives the shared map while it is still being filled.
		global := newState(ast, result)
		svcs, wrapErr := wrapServices(ast, global)
		if wrapErr != nil {
			return nil, fmt.Errorf("wrap services failed: %v", wrapErr)
		}
		result[name] = parseState{
			ast:       ast,
			namespace: getNamespace(name, ast),
			global:    global,
			services:  svcs,
		}
	}

	setIncludes(result)
	extendsErr := setExtends(result)
	return result, extendsErr
}
// defaultPackageName derives a package name from a file path: the final path
// element with its extension removed, lower-cased. No other characters are
// altered.
func defaultPackageName(fullPath string) string {
	base := filepath.Base(fullPath)
	// filepath.Ext returns a suffix of base, so slicing it off is safe.
	ext := filepath.Ext(base)
	return strings.ToLower(base[:len(base)-len(ext)])
}
// getNamespace returns the Go package name for a parsed Thrift file: the
// file's "go" namespace when one is declared, otherwise a default derived from
// filename.
func getNamespace(filename string, v *parser.Thrift) string {
	ns, declared := v.Namespaces["go"]
	if !declared {
		// TODO(prashant): Remove any characters that are not valid in Go package names.
		return defaultPackageName(filename)
	}
	return ns
}
// generateCode renders template into outputFile for one parsed Thrift file.
//
// pkg is the Go package name exposed to the template, and state supplies the
// AST, services and global State of the parsed file. The template also gets
// the Thrift import path (the value of apacheThriftImport, declared outside
// this function) and the TChannel Thrift import path.
//
// It returns an error if outputFile is empty, or whatever error executing the
// template produces. A file that declares no services is skipped: nothing is
// rendered and nil is returned.
func generateCode(outputFile string, template *Template, pkg string, state parseState) error {
	if outputFile == "" {
		return fmt.Errorf("must specify an output file")
	}
	// Nothing to generate for files without services.
	if len(state.services) == 0 {
		return nil
	}

	td := TemplateData{
		Package:  pkg,
		AST:      state.ast,
		Includes: state.global.includes,
		Services: state.services,
		global:   state.global,
		Imports: imports{
			Thrift:   *apacheThriftImport,
			TChannel: tchannelThriftImport,
		},
	}
	return template.execute(outputFile, td)
}
// stringSliceFlag is a flag.Value that accumulates every occurrence of a
// repeated command-line flag into a string slice.
type stringSliceFlag []string

// String returns the collected values joined by ", ". The flag package
// documents that String may be called on a zero-valued receiver such as a nil
// pointer, so a nil receiver yields the empty string instead of panicking.
func (s *stringSliceFlag) String() string {
	if s == nil {
		return ""
	}
	return strings.Join(*s, ", ")
}

// Set appends in to the collected values. It never fails.
func (s *stringSliceFlag) Set(in string) error {
	*s = append(*s, in)
	return nil
}
// NewStringSliceFlag registers a repeatable string flag with the given name
// and usage on the default flag set and returns a pointer to the slice that
// collects its values. The slice starts out nil.
func NewStringSliceFlag(name string, usage string) *[]string {
	values := new(stringSliceFlag)
	flag.Var(values, name, usage)
	return (*[]string)(values)
}