codegen/service.go (330 lines of code) (raw):
// Copyright (c) 2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package codegen
import (
"fmt"
"os"
"sort"
"strings"
"github.com/emicklei/proto"
"github.com/pkg/errors"
"go.uber.org/thriftrw/compile"
)
// ModuleSpec collects everything code generation needs from a single IDL
// file: the services it declares, the Go package name to generate into, and
// the Go imports the generated code requires. It is built from a thrift file
// by NewModuleSpec or from a proto file by NewProtoModuleSpec.
type ModuleSpec struct {
	// CompiledModule is the resolved module from thrift file
	// that will contain modules and typedefs not directly mounted on AST.
	// Left nil for proto modules (NewProtoModuleSpec does not set it).
	// NOTE(review): the tag `json:"omitempty"` makes encoding/json use
	// "omitempty" as the key name; it does not enable the omitempty option.
	// `json:",omitempty"` or `json:"-"` was probably intended — confirm
	// before changing, since it alters serialized output.
	CompiledModule *compile.Module `json:"omitempty"`
	// Source IDL file to generate the code from. Despite the name, proto
	// modules store their proto file path here as well.
	ThriftFile string
	// Whether the ThriftFile should have annotations or not
	WantAnnot bool
	// Whether the module is for an endpoint vs downstream client
	IsEndpoint bool
	// Go package name, generated based on module name.
	PackageName string
	// Go client types file path, generated from thrift file.
	GoThriftTypesFilePath string
	// Go imports needed by the generated code. The appenders in this file
	// check isPackageIncluded first, so import paths are not repeated.
	IncludedPackages []GoPackageImport
	// Thrift services of the module; AddServices appends them sorted by name.
	Services ServiceSpecs
	// Proto services of the module; NewProtoModuleSpec sorts them before
	// storing them here.
	ProtoServices []*ProtoService
}
// GoPackageImport describes one import line of generated Go code: the
// import path (PackageName) and the name to import it under (AliasName).
// Some callers leave AliasName empty — presumably meaning "import without
// an alias"; confirm against the code templates.
type GoPackageImport struct {
	PackageName, AliasName string
}
// ServiceSpecs is a list of *ServiceSpec that implements sort.Interface,
// ordering services by ascending Name.
type ServiceSpecs []*ServiceSpec

// Len reports how many services the list holds.
func (specs ServiceSpecs) Len() int { return len(specs) }

// Swap exchanges the services stored at positions i and j.
func (specs ServiceSpecs) Swap(i, j int) { specs[i], specs[j] = specs[j], specs[i] }

// Less reports whether the service at i sorts before the one at j by name.
func (specs ServiceSpecs) Less(i, j int) bool {
	left, right := specs[i].Name, specs[j].Name
	return left < right
}
// ServiceSpec specifies a single thrift service and the method
// specifications generated for it (see NewServiceSpec).
type ServiceSpec struct {
	// Service name
	Name string
	// Source thrift file to generate the code.
	ThriftFile string
	// Whether the service should have annotations or not
	WantAnnot bool
	// Whether the service is for an endpoint vs downstream client
	IsEndpoint bool
	// List of methods/endpoints of the service; NewServiceSpec fills it in
	// ascending order of thrift function name.
	Methods []*MethodSpec
	// thriftrw compile spec the service was built from.
	CompileSpec *compile.ServiceSpec
}
// NewProtoModuleSpec parses protoFile and returns a ModuleSpec describing the
// proto services it declares.
//
// The spec stores the proto file path in ThriftFile, never requests
// annotations, and imports the Go types generated for protoFile under the
// alias "gen". It returns an error if the file cannot be opened or parsed,
// or if the Go import path of its generated types cannot be resolved.
func NewProtoModuleSpec(protoFile string, isEndpoint bool, h *PackageHelper) (*ModuleSpec, error) {
	reader, err := os.Open(protoFile)
	if err != nil {
		return nil, errors.Wrap(err, "failed reading proto file")
	}
	// The file is only read, so a failed Close cannot lose data.
	defer func() { _ = reader.Close() }()

	protoModules, err := proto.NewParser(reader).Parse()
	if err != nil {
		return nil, errors.Wrap(err, "failed parsing proto file")
	}

	pModule := newVisitor().Visit(protoModules)
	// Sort services so generated output does not depend on declaration order.
	sort.Sort(&pModule.Services)

	// Surface the failure instead of emitting an import whose path could
	// not be resolved, which would only break later when compiling output.
	typesPkg, err := h.TypeImportPath(protoFile)
	if err != nil {
		return nil, errors.Wrapf(err, "failed resolving type import path for %s", protoFile)
	}

	return &ModuleSpec{
		ProtoServices: pModule.Services,
		ThriftFile:    protoFile,
		WantAnnot:     false,
		IsEndpoint:    isEndpoint,
		PackageName:   pModule.PackageName,
		IncludedPackages: []GoPackageImport{{
			PackageName: typesPkg,
			AliasName:   "gen",
		}},
	}, nil
}
// NewModuleSpec compiles the thrift file at path thrift and builds a
// ModuleSpec holding its services and the Go imports they need.
//
// When fileExists reports that thrift is not a usable file, it returns an
// *ErrorSkipCodeGen so callers can skip the module instead of failing.
// Compilation errors are wrapped. Errors from AddServices and AddImports are
// returned as-is.
func NewModuleSpec(
	thrift string,
	wantAnnot bool,
	isEndpoint bool,
	packageHelper *PackageHelper,
) (*ModuleSpec, error) {
	if exists := fileExists(thrift); !exists {
		return nil, &ErrorSkipCodeGen{IDLFile: thrift}
	}

	compiled, compileErr := compile.Compile(thrift)
	if compileErr != nil {
		return nil, errors.Wrap(compileErr, "failed parse thrift file")
	}

	spec := new(ModuleSpec)
	spec.CompiledModule = compiled
	spec.WantAnnot = wantAnnot
	spec.IsEndpoint = isEndpoint
	spec.ThriftFile = compiled.ThriftPath
	spec.PackageName = packageName(compiled.GetName())

	if addErr := spec.AddServices(compiled, packageHelper); addErr != nil {
		return nil, addErr
	}
	if addErr := spec.AddImports(compiled, packageHelper); addErr != nil {
		return nil, addErr
	}
	return spec, nil
}
// ErrorSkipCodeGen signals that the module backed by IDLFile can be skipped
// without failing code generation. NewModuleSpec returns it when the thrift
// file is not present.
type ErrorSkipCodeGen struct {
	IDLFile string
}

// Error implements the error interface, naming the IDL file that was skipped.
func (e *ErrorSkipCodeGen) Error() string {
	const skipFormat = "code gen skip for idlFile: %v"
	return fmt.Sprintf(skipFormat, e.IDLFile)
}
// IgnorePopulateSpecStageErr wraps an error raised while populating a module
// spec. It marks the module as one that may be skipped rather than failing
// the whole build.
type IgnorePopulateSpecStageErr struct {
	Err error
}

// Error returns the wrapped error's message. A nil receiver or a nil Err
// yields a fixed placeholder instead of panicking.
func (e *IgnorePopulateSpecStageErr) Error() string {
	if e == nil || e.Err == nil {
		return "ignorable populate spec stage error"
	}
	return e.Err.Error()
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can
// inspect the underlying cause through this wrapper.
func (e *IgnorePopulateSpecStageErr) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.Err
}
// fileExists reports whether filename names an existing file that is not a
// directory. It is checked before use to avoid further errors.
//
// Any os.Stat failure yields false, including permission errors and a path
// component that is not a directory, not only "does not exist". In those
// cases os.Stat returns a nil FileInfo, which must not be dereferenced.
func fileExists(filename string) bool {
	info, err := os.Stat(filename)
	if err != nil {
		return false
	}
	return !info.IsDir()
}
// AddImports records in IncludedPackages the Go types package for every
// thrift module visited by module.Walk, then for ms.ThriftFile itself.
// Import paths already present are not added again. This function does not
// sort; any ordering comes from module.Walk. The first failure is returned
// wrapped with the offending thrift path.
func (ms *ModuleSpec) AddImports(module *compile.Module, packageHelper *PackageHelper) error {
	addImport := func(thriftPath string) error {
		if err := ms.addTypeImport(thriftPath, packageHelper); err != nil {
			return errors.Wrapf(err, "can't add import %s", thriftPath)
		}
		return nil
	}

	walkErr := module.Walk(func(dep *compile.Module) error {
		return addImport(dep.ThriftPath)
	})
	if walkErr != nil {
		return walkErr
	}
	return addImport(ms.ThriftFile)
}
// AddServices builds a ServiceSpec for every service in module and appends
// them to ms.Services in ascending order of service name, so the result
// does not depend on map iteration order. It stops at the first service that
// fails to build; services appended before the failure remain in ms.Services.
func (ms *ModuleSpec) AddServices(module *compile.Module, packageHelper *PackageHelper) error {
	serviceNames := make([]string, 0, len(module.Services))
	for serviceName := range module.Services {
		serviceNames = append(serviceNames, serviceName)
	}
	sort.Strings(serviceNames)

	for i := 0; i < len(serviceNames); i++ {
		spec, err := NewServiceSpec(module.Services[serviceNames[i]], ms.WantAnnot, ms.IsEndpoint, packageHelper)
		if err != nil {
			return err
		}
		ms.Services = append(ms.Services, spec)
	}
	return nil
}
// NewServiceSpec converts a compiled thrift service into a ServiceSpec. It
// creates one MethodSpec per thrift function, in ascending order of function
// name. The first method that fails to build aborts the conversion with an
// error naming both the service and the method.
func NewServiceSpec(
	spec *compile.ServiceSpec,
	wantAnnot bool,
	isEndpoint bool,
	packageHelper *PackageHelper,
) (*ServiceSpec, error) {
	names := make([]string, 0, len(spec.Functions))
	for fn := range spec.Functions {
		names = append(names, fn)
	}
	sort.Strings(names)

	result := &ServiceSpec{
		Name:        spec.Name,
		ThriftFile:  spec.File,
		WantAnnot:   wantAnnot,
		IsEndpoint:  isEndpoint,
		CompileSpec: spec,
	}
	for _, fn := range names {
		m, err := result.NewMethod(spec.Functions[fn], packageHelper)
		if err != nil {
			return nil, errors.Wrapf(err, "service %s method %s", spec.Name, fn)
		}
		result.Methods = append(result.Methods, m)
	}
	return result, nil
}
// SetDownstream wires the endpoint method described by e to its downstream
// client method and records the Go imports the generated code then needs.
// For clientless endpoints it installs dummy type converters instead.
//
// The endpoint method is found by e.ThriftServiceName/e.ThriftMethodName
// among ms.Services. The client side is resolved through
// e.ClientSpec.ExposedMethods[e.ClientMethod], whose value must have the
// form "Service::Method". Every exception declared by the downstream method
// must also be declared by the endpoint method.
//
// It returns an error if the service, method, or exposed client method
// cannot be found, if the exposed method is not in "Service::Method" form,
// if an exception is missing from the endpoint schema, or if any converter
// or header-propagator setup fails.
func (ms *ModuleSpec) SetDownstream(
	e *EndpointSpec,
	h *PackageHelper,
) error {
	var (
		service      *ServiceSpec
		method       *MethodSpec
		serviceName  = e.ThriftServiceName
		methodName   = e.ThriftMethodName
		clientSpec   = e.ClientSpec
		clientMethod = e.ClientMethod
		// TODO: move generated middlewares out of zanzibar
		headersPropagate   = e.HeadersPropagate
		reqTransforms      = e.ReqTransforms
		respTransforms     = e.RespTransforms
		dummyReqTransforms = e.DummyReqTransforms
	)

	for _, v := range ms.Services {
		if v.Name == serviceName {
			service = v
			break
		}
	}
	if service == nil {
		return errors.Errorf(
			"Module does not have service %q\n", serviceName,
		)
	}

	for _, v := range service.Methods {
		if v.Name == methodName {
			method = v
			break
		}
	}
	if method == nil {
		return errors.Errorf(
			"Service %q does not have method %q\n", serviceName, methodName,
		)
	}

	if e.IsClientlessEndpoint {
		funcSpec := method.CompiledThriftSpec
		err := method.setClientlessTypeConverters(funcSpec, reqTransforms, headersPropagate, respTransforms, dummyReqTransforms, h)
		if err != nil {
			// Keep the cause so the failing converter is visible to the user.
			return errors.Wrap(err, "unable to set dummy type convertors for dummy endpoint")
		}
		return nil
	}

	serviceMethod, ok := clientSpec.ExposedMethods[clientMethod]
	if !ok {
		return errors.Errorf("Client %q does not expose method %q", clientSpec.ClientName, clientMethod)
	}
	// Exposed methods are written as "Service::Method". Reject malformed
	// values explicitly; indexing sm[1] would otherwise panic.
	sm := strings.Split(serviceMethod, "::")
	if len(sm) < 2 {
		return errors.Errorf(
			"Client %q exposed method %q maps to %q, want \"Service::Method\"",
			clientSpec.ClientName, clientMethod, serviceMethod,
		)
	}
	if err := method.setDownstream(clientSpec.ModuleSpec, sm[0], sm[1]); err != nil {
		return err
	}

	// Exception validation: every exception declared by the downstream
	// method must also appear in the endpoint's schema.
	for en := range method.DownstreamMethod.ExceptionsIndex {
		if _, ok := method.ExceptionsIndex[en]; !ok {
			return fmt.Errorf("Missing exception %s in Endpoint schema", en)
		}
	}

	// If this is an endpoint then a downstream will be defined.
	// If it is a client it will not be.
	if method.Downstream != nil {
		downstreamMethod := method.DownstreamMethod
		err := method.setTypeConverters(
			method.CompiledThriftSpec, downstreamMethod.CompiledThriftSpec,
			reqTransforms, headersPropagate, respTransforms, h, downstreamMethod,
		)
		if err != nil {
			return err
		}
	}

	if method.Downstream != nil && len(headersPropagate) > 0 {
		// NOTE(review): this looks the downstream method up by the
		// endpoint's own method name, while the converters above use
		// method.DownstreamMethod. When the endpoint and client method names
		// differ, this may fail or pick another method; confirm intent.
		downstreamMethod, err := findMethodByName(method.Name, method.Downstream.Services)
		if err != nil {
			return err
		}
		err = method.setHeaderPropagator(
			sortedHeaders(e.ReqHeaders, false), downstreamMethod.CompiledThriftSpec,
			headersPropagate, h, downstreamMethod,
		)
		if err != nil {
			return err
		}
	}

	// Adds imports for downstream services.
	if !ms.isPackageIncluded(clientSpec.ImportPackagePath) {
		ms.IncludedPackages = append(ms.IncludedPackages, GoPackageImport{
			PackageName: clientSpec.ImportPackagePath,
			AliasName:   clientSpec.ImportPackageAlias,
		})
	}

	// Adds imports for thrift types used by the downstreams of every method
	// in the module, not only the method configured above.
	for _, svc := range ms.Services {
		for _, m := range svc.Methods {
			d := m.Downstream
			// The thrift types file is optional; skip downstreams without one.
			if d == nil || d.GoThriftTypesFilePath == "" {
				continue
			}
			if !ms.isPackageIncluded(d.GoThriftTypesFilePath) {
				ms.IncludedPackages = append(ms.IncludedPackages, GoPackageImport{
					PackageName: d.GoThriftTypesFilePath,
					AliasName:   "",
				})
			}
		}
	}
	return nil
}
// findMethodByName returns the first method called name, scanning
// serviceSpecs and their methods in slice order. If nothing matches, the
// error lists every available method as "Service::Method" to aid debugging.
func findMethodByName(name string, serviceSpecs []*ServiceSpec) (*MethodSpec, error) {
	for _, svc := range serviceSpecs {
		for _, m := range svc.Methods {
			if m.Name == name {
				return m, nil
			}
		}
	}

	var known []string
	for _, svc := range serviceSpecs {
		for _, m := range svc.Methods {
			known = append(known, svc.Name+"::"+m.Name)
		}
	}
	return nil, errors.Errorf("failed to map downstream method %q to methods %q defined in thrift file", name, known)
}
// NewMethod builds the MethodSpec for funcSpec, a function of this service.
// It passes the service's thrift file, annotation and endpoint flags, and
// name through to the package-level NewMethod.
func (s *ServiceSpec) NewMethod(
	funcSpec *compile.FunctionSpec, packageHelper *PackageHelper,
) (*MethodSpec, error) {
	var (
		file       = s.ThriftFile
		annotated  = s.WantAnnot
		isEndpoint = s.IsEndpoint
		service    = s.Name
	)
	return NewMethod(file, funcSpec, packageHelper, annotated, isEndpoint, service)
}
// addTypeImport resolves the Go import path and package alias of the types
// generated for thriftPath. It appends them to IncludedPackages unless that
// import path is already listed. Resolution errors are returned unchanged.
func (ms *ModuleSpec) addTypeImport(thriftPath string, packageHelper *PackageHelper) error {
	importPath, pathErr := packageHelper.TypeImportPath(thriftPath)
	if pathErr != nil {
		return pathErr
	}
	alias, aliasErr := packageHelper.TypePackageName(thriftPath)
	if aliasErr != nil {
		return aliasErr
	}
	if ms.isPackageIncluded(importPath) {
		return nil
	}
	ms.IncludedPackages = append(ms.IncludedPackages, GoPackageImport{
		PackageName: importPath,
		AliasName:   alias,
	})
	return nil
}
// isPackageIncluded reports whether IncludedPackages already holds an import
// whose path equals pkg. Aliases are not compared.
func (ms *ModuleSpec) isPackageIncluded(pkg string) bool {
	for i := range ms.IncludedPackages {
		if ms.IncludedPackages[i].PackageName == pkg {
			return true
		}
	}
	return false
}