codegen/method.go (1,279 lines of code) (raw):

// Copyright (c) 2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package codegen

import (
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/pkg/errors"
	"go.uber.org/thriftrw/compile"
)

// Annotation key templates for thrift methods, fields and exceptions. Each
// "%s" is filled in with the package helper's annotation prefix (NewMethod
// resolves them into MethodSpec.annotations).
const (
	antHTTPMethod     = "%s.http.method"
	antHTTPPath       = "%s.http.path"
	antHTTPStatus     = "%s.http.status"
	antHTTPReqHeaders = "%s.http.reqHeaders"
	antHTTPResHeaders = "%s.http.resHeaders"
	antHTTPRef        = "%s.http.ref"
	antMeta           = "%s.meta"
	antHandler        = "%s.handler"

	// AntHTTPReqDefBoxed annotates a method so that the generated method takes
	// the generated argument directly instead of a struct that wraps the argument.
	// The annotated method should have one and only one argument.
	AntHTTPReqDefBoxed = "%s.http.req.def"
	antHTTPResNoBody   = "%s.http.res.body.disallow"
)

// queryAnnotationPrefix is the prefix of an http.ref annotation value that
// binds a field to a query string parameter (e.g. "query.name").
const queryAnnotationPrefix = "query."

// headerAnnotationPrefix is the prefix of an http.ref annotation value that
// binds a field to an HTTP header (e.g. "headers.x-uuid").
const headerAnnotationPrefix = "headers."

// PathSegment represents a part of the http path: either a static text
// segment or a ":name" parameter bound to a request field.
type PathSegment struct {
	// Type is "static" for literal segments and "param" for ":name" segments.
	Type string
	// Text is the literal segment text (static segments only).
	Text string
	// BodyIdentifier is the Go selector path (e.g. ".Request.UUID") of the
	// request field bound to this parameter (param segments only).
	BodyIdentifier string
	// ParamName is the segment name without the leading ':'.
	ParamName string
	// Required mirrors the thrift "required" flag of the bound field.
	Required bool
}

// ExceptionSpec contains information about thrift exceptions
type ExceptionSpec struct {
	StructSpec

	// StatusCode is read from the exception's http.status annotation when the
	// method wants annotations; it is the zero value otherwise.
	StatusCode StatusCode
	// IsBodyDisallowed reports whether the exception response has no body.
	IsBodyDisallowed bool
}

// HeaderFieldInfo contains information about where to store
// the string from headers into the request/response body.
type HeaderFieldInfo struct {
	// FieldIdentifier is the Go selector path of the field.
	FieldIdentifier string
	// IsPointer is true when the field is optional (and therefore a pointer).
	IsPointer bool
}

// MethodSpec specifies all needed parts to generate code for a method in service.
type MethodSpec struct {
	Name       string
	HTTPMethod string
	// Used by edge gateway to generate endpoint.
	EndpointName string
	HTTPPath     string
	PathSegments []PathSegment
	// annotations holds the fully resolved (prefixed) annotation keys.
	annotations annotations
	IsEndpoint  bool

	// Statements for reading query parameters.
	ParseQueryParamGoStatements []string
	// Statements for writing query parameters
	WriteQueryParamGoStatements []string
	// Statements for reading request headers into the request body (server)
	ReqHeaderGoStatements []string
	// Statements for writing request headers from request fields (client)
	ReqClientHeaderGoStatements []string
	// ResHeaderFields is a map of header name to a golang
	// field accessor expression used to read fields out
	// of the response body and place them into response headers
	ResHeaderFields map[string]HeaderFieldInfo

	// ReqHeaders needed, parsed from the "<prefix>.http.reqHeaders" annotation
	ReqHeaders []string
	// ResHeaders needed, parsed from the "<prefix>.http.resHeaders" annotation
	ResHeaders []string

	RequestType       string
	ShortRequestType  string
	ResponseType      string
	ShortResponseType string
	OKStatusCode      StatusCode
	Exceptions        []ExceptionSpec
	// ExceptionsByStatusCode groups exceptions by their HTTP status code.
	ExceptionsByStatusCode map[int][]ExceptionSpec
	// ExceptionsIndex maps an exception's thrift name to its spec.
	ExceptionsIndex map[string]ExceptionSpec
	// ValidStatusCodes is the sorted list of the OK code plus every
	// exception status code.
	ValidStatusCodes []int
	// Fully qualified field type of the unboxed field
	BoxedRequestType string
	// Unboxed field name
	BoxedRequestName string
	// Whether the single method argument is used directly as the request
	// instead of the generated args struct.
	RequestBoxed bool
	// Thrift service name the method belongs to.
	ThriftService string
	// The thriftrw-generated go package name
	GenCodePkgName string
	// Whether the method needs annotation or not.
	WantAnnot bool
	// The thriftrw compiled spec, used to extract type information
	CompiledThriftSpec *compile.FunctionSpec
	// The downstream client module set by endpoint config
	Downstream *ModuleSpec
	// the downstream service name
	DownstreamService string
	// The downstream method spec for the endpoint
	DownstreamMethod *MethodSpec
	// Statements for converting request types
	ConvertRequestGoStatements []string
	// Statements for converting response types
	ConvertResponseGoStatements []string
	// Statements for converting Clientless request types
	ConvertClientlessRequestGoStatements []string
	// Statements for propagating headers to client requests
	PropagateHeadersGoStatements []string
	// Statements for reading data out of url params (server)
	RequestParamGoStatements []string
}

// annotations holds the annotation keys of a method after the configured
// prefix has been substituted into each ant* template.
type annotations struct {
	HTTPMethod      string
	HTTPPath        string
	HTTPStatus      string
	HTTPReqHeaders  string
	HTTPResHeaders  string
	HTTPRef         string
	Meta            string
	Handler         string
	HTTPReqDefBoxed string
	HTTPResNoBody   string
}

// StructSpec specifies a Go struct to be generated.
type StructSpec struct {
	Type        string
	Name        string
	Annotations map[string]string
}

// StatusCode is for http status code with exception message.
type StatusCode struct {
	Code int
	// Message is the exception name for exception status codes.
	Message string
}

// NewMethod creates a new method specification.
func NewMethod( thriftFile string, funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, wantAnnot bool, isEndpoint bool, thriftService string, ) (*MethodSpec, error) { var ( err error ok bool ant = packageHelper.annotationPrefix method = &MethodSpec{} ) method.CompiledThriftSpec = funcSpec method.Name = funcSpec.MethodName() method.IsEndpoint = isEndpoint method.WantAnnot = wantAnnot method.ThriftService = thriftService method.annotations = annotations{ HTTPMethod: fmt.Sprintf(antHTTPMethod, ant), HTTPPath: fmt.Sprintf(antHTTPPath, ant), HTTPStatus: fmt.Sprintf(antHTTPStatus, ant), HTTPReqHeaders: fmt.Sprintf(antHTTPReqHeaders, ant), HTTPResHeaders: fmt.Sprintf(antHTTPResHeaders, ant), HTTPRef: fmt.Sprintf(antHTTPRef, ant), Meta: fmt.Sprintf(antMeta, ant), Handler: fmt.Sprintf(antHandler, ant), HTTPReqDefBoxed: fmt.Sprintf(AntHTTPReqDefBoxed, ant), HTTPResNoBody: fmt.Sprintf(antHTTPResNoBody, ant), } method.GenCodePkgName, err = packageHelper.TypePackageName(thriftFile) if err != nil { return nil, err } err = method.setResponseType(thriftFile, funcSpec.ResultSpec, packageHelper) if err != nil { return nil, err } err = method.setRequestType(thriftFile, funcSpec, packageHelper) if err != nil { return nil, err } err = method.setExceptions(thriftFile, isEndpoint, funcSpec.ResultSpec, packageHelper) if err != nil { return nil, err } method.ReqHeaders = headers(funcSpec.Annotations[method.annotations.HTTPReqHeaders]) method.ResHeaders = headers(funcSpec.Annotations[method.annotations.HTTPResHeaders]) if !wantAnnot { return method, nil } if method.HTTPMethod, ok = funcSpec.Annotations[method.annotations.HTTPMethod]; !ok { return nil, errors.Errorf("missing annotation '%s' for HTTP method", method.annotations.HTTPMethod) } method.EndpointName = funcSpec.Annotations[method.annotations.Handler] err = method.setOKStatusCode(funcSpec.Annotations[method.annotations.HTTPStatus]) if err != nil { return nil, err } method.setValidStatusCodes() if method.RequestType != "" { 
hasNoBody := method.HTTPMethod == "GET" if method.IsEndpoint { err := method.setParseQueryParamStatements(funcSpec, packageHelper, hasNoBody) if err != nil { return nil, err } } else { err := method.setWriteQueryParamStatements(funcSpec, packageHelper, hasNoBody) if err != nil { return nil, err } } } var httpPath string if httpPath, ok = funcSpec.Annotations[method.annotations.HTTPPath]; !ok { return nil, errors.Errorf( "missing annotation '%s' for HTTP path", method.annotations.HTTPPath, ) } method.setHTTPPath(httpPath, funcSpec) err = method.setRequestParamFields(funcSpec, packageHelper) if err != nil { return nil, err } err = method.setEndpointRequestHeaderFields(funcSpec, packageHelper) if err != nil { return nil, err } err = method.setClientRequestHeaderFields(funcSpec, packageHelper) if err != nil { return nil, err } method.setResponseHeaderFields(funcSpec) return method, nil } // setRequestType sets the request type of the method specification. If the // "zanzibar.http.req.def.boxed" is true, then the first parameter will be used as // the request body; otherwise a new struct is generated to bundle the request // parameters as http body and the name of the struct will be returned. func (ms *MethodSpec) setRequestType(curThriftFile string, funcSpec *compile.FunctionSpec, packageHelper *PackageHelper) error { if len(funcSpec.ArgsSpec) == 0 { ms.RequestType = "" return nil } var err error if ms.isRequestBoxed(funcSpec) { ms.RequestBoxed = true ms.BoxedRequestType, err = packageHelper.TypeFullName(funcSpec.ArgsSpec[0].Type) ms.BoxedRequestName = PascalCase(funcSpec.ArgsSpec[0].Name) if err == nil && IsStructType(funcSpec.ArgsSpec[0].Type) { ms.BoxedRequestType = "*" + ms.BoxedRequestType } } goPackageName, err := packageHelper.TypePackageName(curThriftFile) if err == nil { ms.ShortRequestType = goPackageName + "." 
+ ms.ThriftService + "_" + strings.Title(ms.Name) + "_Args" ms.RequestType = "*" + ms.ShortRequestType } if err != nil { return errors.Wrap(err, "failed to set request type") } return nil } func (ms *MethodSpec) setResponseType(curThriftFile string, respSpec *compile.ResultSpec, packageHelper *PackageHelper) error { if respSpec == nil { ms.ResponseType = "" return nil } typeName, err := packageHelper.TypeFullName(respSpec.ReturnType) ms.ShortResponseType = typeName if IsStructType(respSpec.ReturnType) { typeName = "*" + typeName } if err != nil { return errors.Wrap(err, "failed to get response type") } ms.ResponseType = typeName return nil } // RefResponse prepends the response variable with '&' if it is not of reference type // It is used to construct the `Success` field of the `$service_$method_Result` struct // generated by thriftrw, which is always of reference type. func (ms *MethodSpec) RefResponse(respVar string) string { respSpec := ms.CompiledThriftSpec.ResultSpec if respSpec == nil || respSpec.ReturnType == nil { return respVar } switch compile.RootTypeSpec(respSpec.ReturnType).(type) { case *compile.BoolSpec, *compile.I8Spec, *compile.I16Spec, *compile.I32Spec, *compile.I64Spec, *compile.DoubleSpec, *compile.StringSpec, *compile.EnumSpec: return "&" + respVar default: return respVar } } func (ms *MethodSpec) setOKStatusCode(statusCode string) error { if statusCode == "" { return errors.Errorf("no http OK status code set by annotation '%s' ", ms.annotations.HTTPStatus) } code, err := strconv.Atoi(statusCode) if err != nil { return errors.Wrapf(err, "Could not parse status code annotation (%s) for ok response", statusCode, ) } ms.OKStatusCode = StatusCode{ Code: code, } return nil } func (ms *MethodSpec) setValidStatusCodes() { ms.ValidStatusCodes = []int{ ms.OKStatusCode.Code, } for code := range ms.ExceptionsByStatusCode { ms.ValidStatusCodes = append(ms.ValidStatusCodes, code) } // Prevents non-deterministic builds sort.Ints(ms.ValidStatusCodes) } func 
(ms *MethodSpec) setExceptions( curThriftFile string, isEndpoint bool, resultSpec *compile.ResultSpec, h *PackageHelper, ) error { ms.Exceptions = make([]ExceptionSpec, len(resultSpec.Exceptions)) ms.ExceptionsIndex = make( map[string]ExceptionSpec, len(resultSpec.Exceptions), ) ms.ExceptionsByStatusCode = map[int][]ExceptionSpec{} for i, e := range resultSpec.Exceptions { typeName, err := h.TypeFullName(e.Type) if err != nil { return errors.Wrapf( err, "cannot resolve type full name for %s for exception %s", e.Type, e.Name, ) } bodyDisallowed := ms.isBodyDisallowed(e) if !ms.WantAnnot { exception := ExceptionSpec{ StructSpec: StructSpec{ Type: typeName, Name: e.Name, }, IsBodyDisallowed: bodyDisallowed, } ms.Exceptions[i] = exception ms.ExceptionsIndex[e.Name] = exception if _, exists := ms.ExceptionsByStatusCode[exception.StatusCode.Code]; !exists { ms.ExceptionsByStatusCode[exception.StatusCode.Code] = []ExceptionSpec{} } ms.ExceptionsByStatusCode[exception.StatusCode.Code] = append( ms.ExceptionsByStatusCode[exception.StatusCode.Code], exception, ) continue } code, err := strconv.Atoi(e.Annotations[ms.annotations.HTTPStatus]) if err != nil { return errors.Wrapf( err, "cannot parse the annotation %s for exception %s", ms.annotations.HTTPStatus, e.Name, ) } exception := ExceptionSpec{ StructSpec: StructSpec{ Type: typeName, Name: e.Name, Annotations: e.Annotations, }, StatusCode: StatusCode{ Code: code, Message: e.Name, }, IsBodyDisallowed: bodyDisallowed, } ms.Exceptions[i] = exception ms.ExceptionsIndex[e.Name] = exception if _, exists := ms.ExceptionsByStatusCode[exception.StatusCode.Code]; !exists { ms.ExceptionsByStatusCode[exception.StatusCode.Code] = []ExceptionSpec{} } ms.ExceptionsByStatusCode[exception.StatusCode.Code] = append( ms.ExceptionsByStatusCode[exception.StatusCode.Code], exception, ) } return nil } func (ms *MethodSpec) findParamsAnnotation( fields compile.FieldGroup, paramName string, ) (string, bool, bool) { var identifier string var 
required bool visitor := func( goPrefix string, thriftPrefix string, field *compile.FieldSpec, ) bool { if param, ok := field.Annotations[ms.annotations.HTTPRef]; ok { if param == "params."+paramName[1:] { identifier = goPrefix + "." + PascalCase(field.Name) required = field.Required return true } } return false } walkFieldGroups(fields, visitor) if identifier == "" { return "", required, false } return identifier, required, true } func (ms *MethodSpec) setRequestParamFields( funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, ) error { statements := LineBuilder{} seenStructs, itrOrder, err := findStructs(funcSpec, packageHelper) if err != nil { return err } for _, segment := range ms.PathSegments { if segment.Type != "param" { continue } for _, seenStruct := range itrOrder { if strings.HasPrefix(segment.BodyIdentifier, seenStruct) { statements.appendf("if requestBody%s == nil {", seenStruct, ) statements.appendf("\trequestBody%s = &%s{}", seenStruct, seenStructs[seenStruct], ) statements.append("}") } } if segment.Required { statements.appendf("requestBody%s = req.Params.Get(%q)", segment.BodyIdentifier, segment.ParamName, ) } else { statements.appendf( "requestBody%s = ptr.String(req.Params.Get(%q))", segment.BodyIdentifier, segment.ParamName, ) } } ms.RequestParamGoStatements = statements.GetLines() return nil } func findStructs( funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, ) (map[string]string, []string, error) { fields := compile.FieldGroup(funcSpec.ArgsSpec) seenStructs := make(map[string]string) itrOrder := make([]string, 0) var finalError error visitor := func( goPrefix string, thriftPrefix string, field *compile.FieldSpec, ) bool { realType := compile.RootTypeSpec(field.Type) longFieldName := goPrefix + "." 
+ PascalCase(field.Name) if _, ok := realType.(*compile.StructSpec); ok { typeName, err := GoType(packageHelper, realType) if err != nil { finalError = err return true } seenStructs[longFieldName] = typeName itrOrder = append(itrOrder, longFieldName) } return false } walkFieldGroups(fields, visitor) if finalError != nil { return nil, nil, finalError } return seenStructs, itrOrder, nil } func (ms *MethodSpec) setEndpointRequestHeaderFields( funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, ) error { fields := compile.FieldGroup(funcSpec.ArgsSpec) // ms.ReqHeaderFields = map[string]HeaderFieldInfo{} statements := LineBuilder{} var finalError error var seenHeaders bool var headersMap = map[string]int{} var seenOptStructs = map[string]string{} // Scan for all annotations visitor := func( goPrefix string, thriftPrefix string, field *compile.FieldSpec, ) bool { realType := compile.RootTypeSpec(field.Type) longFieldName := goPrefix + "." + PascalCase(field.Name) // If the type is a struct then we cannot really do anything if _, ok := realType.(*compile.StructSpec); ok { // if a field is a struct then we must do a nil check typeName, err := GoType(packageHelper, realType) if err != nil { finalError = err return true } if field.Required { statements.appendf("if requestBody%s == nil {", longFieldName) statements.appendf("\trequestBody%s = &%s{}", longFieldName, typeName, ) statements.append("}") } else { seenOptStructs[longFieldName] = typeName } return false } if param, ok := field.Annotations[ms.annotations.HTTPRef]; ok { if strings.HasPrefix(param, headerAnnotationPrefix) { headerName := strings.TrimPrefix(param, headerAnnotationPrefix) camelHeaderName := CamelCase(headerName) fieldThriftType, err := GoType(packageHelper, field.Type) if err != nil { finalError = err return true } bodyIdentifier := goPrefix + "." 
+ PascalCase(field.Name) seenCount := headersMap[camelHeaderName] var variableName string if seenCount > 0 { variableName = camelHeaderName + "No" + strconv.Itoa(seenCount) + "Value" } else { variableName = camelHeaderName + "Value" } headersMap[camelHeaderName] = seenCount + 1 if field.Required { statements.appendf("%s, _ := req.Header.Get(%q)", variableName, headerName, ) for seenStruct, typeName := range seenOptStructs { if strings.HasPrefix(longFieldName, seenStruct) { statements.appendf("if requestBody%s == nil {", seenStruct, ) statements.appendf("\trequestBody%s = &%s{}", seenStruct, typeName, ) statements.append("}") } } statements.appendf("requestBody%s = %s(%s)", bodyIdentifier, fieldThriftType, variableName, ) } else { statements.appendf("%s, %sExists := req.Header.Get(%q)", variableName, variableName, headerName, ) statements.appendf("if %sExists {", variableName) for seenStruct, typeName := range seenOptStructs { if strings.HasPrefix(longFieldName, seenStruct) { statements.appendf("\tif requestBody%s == nil {", seenStruct, ) statements.appendf("\t\trequestBody%s = &%s{}", seenStruct, typeName, ) statements.append("\t}") } } switch fieldThriftType { case "string": statements.appendf("\trequestBody%s = ptr.String(%s)", bodyIdentifier, variableName, ) case "int64": statements.appendf("body, _ := strconv.ParseInt(%s, 10, 64)", variableName, ) statements.appendf("requestBody%s = &body", bodyIdentifier) case "bool": statements.appendf("body, _ := strconv.ParseBool(%s)", variableName, ) statements.appendf("requestBody%s = &body", bodyIdentifier) case "float64": case "float32": statements.appendf("body, _ := strconv.ParseFloat(%s, 64)", variableName, ) statements.appendf("requestBody%s = &body", bodyIdentifier) default: statements.appendf("body := %s(%s)", fieldThriftType, variableName, ) statements.appendf("requestBody%s = &body", bodyIdentifier) } statements.append("}") } seenHeaders = true } } return false } walkFieldGroups(fields, visitor) if finalError != 
nil { return finalError } if seenHeaders { ms.ReqHeaderGoStatements = statements.GetLines() } return nil } func (ms *MethodSpec) setResponseHeaderFields( funcSpec *compile.FunctionSpec, ) { structType, ok := funcSpec.ResultSpec.ReturnType.(*compile.StructSpec) // If the result is not a struct then there are zero response header // annotations. if !ok { return } fields := structType.Fields ms.ResHeaderFields = map[string]HeaderFieldInfo{} // Scan for all annotations visitor := func( goPrefix string, thriftPrefix string, field *compile.FieldSpec, ) bool { if param, ok := field.Annotations[ms.annotations.HTTPRef]; ok { if strings.HasPrefix(param, headerAnnotationPrefix) { headerName := strings.TrimPrefix(param, headerAnnotationPrefix) ms.ResHeaderFields[headerName] = HeaderFieldInfo{ FieldIdentifier: goPrefix + "." + PascalCase(field.Name), IsPointer: !field.Required, } } } return false } walkFieldGroups(fields, visitor) } func (ms *MethodSpec) setClientRequestHeaderFields( funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, ) error { fields := compile.FieldGroup(funcSpec.ArgsSpec) statements := LineBuilder{} var finalError error seenOptStructs := make(map[string]string) itrOrder := make([]string, 0) // Scan for all annotations visitor := func( goPrefix string, thriftPrefix string, field *compile.FieldSpec, ) bool { realType := compile.RootTypeSpec(field.Type) longFieldName := goPrefix + "." 
+ PascalCase(field.Name) // If the type is a struct then we cannot really do anything if _, ok := realType.(*compile.StructSpec); ok { // if a field is a struct then we must do a nil check typeName, err := GoType(packageHelper, realType) if err != nil { finalError = err return true } seenOptStructs[longFieldName] = typeName itrOrder = append(itrOrder, longFieldName) return false } if param, ok := field.Annotations[ms.annotations.HTTPRef]; ok { if strings.HasPrefix(param, headerAnnotationPrefix) { headerName := strings.TrimPrefix(param, headerAnnotationPrefix) bodyIdentifier := goPrefix + "." + PascalCase(field.Name) var headerNameValuePair string if field.Required { // Note header values are always string headerNameValuePair = "headers[%q]= string(r%s)" } else { headerNameValuePair = "headers[%q]= string(*r%s)" } if !field.Required { closeFunction := "" for _, seenStruct := range itrOrder { if strings.HasPrefix(longFieldName, seenStruct) { statements.appendf("if r%s != nil {", seenStruct) closeFunction = closeFunction + "}" } } statements.appendf("if r%s != nil {", bodyIdentifier) statements.appendf(headerNameValuePair, headerName, bodyIdentifier) statements.append("}") statements.append(closeFunction) } else { statements.appendf(headerNameValuePair, headerName, bodyIdentifier, ) } } } return false } walkFieldGroups(fields, visitor) if finalError != nil { return finalError } ms.ReqClientHeaderGoStatements = statements.GetLines() return nil } func (ms *MethodSpec) setHTTPPath(httpPath string, funcSpec *compile.FunctionSpec) { ms.HTTPPath = httpPath segments := strings.Split(httpPath[1:], "/") ms.PathSegments = make([]PathSegment, len(segments)) for i := 0; i < len(segments); i++ { segment := segments[i] if segment == "" || segment[0] != ':' { ms.PathSegments[i].Type = "static" ms.PathSegments[i].Text = segment } else { ms.PathSegments[i].Type = "param" var fieldSelect string var required bool var ok bool fieldSelect, required, ok = ms.findParamsAnnotation( 
compile.FieldGroup(funcSpec.ArgsSpec), segment, ) if !ok { panic(fmt.Sprintf("cannot find params: %s for http path %s", segment, httpPath)) } ms.PathSegments[i].BodyIdentifier = fieldSelect ms.PathSegments[i].ParamName = segment[1:] ms.PathSegments[i].Required = required } } } func (ms *MethodSpec) setDownstream( clientModule *ModuleSpec, clientThriftService, clientThriftMethod string, ) error { var downstreamService *ServiceSpec for _, service := range clientModule.Services { if service.Name == clientThriftService { downstreamService = service break } } if downstreamService == nil { return errors.Errorf( "Downstream service '%s' is not found in '%s'", clientThriftService, clientModule.ThriftFile, ) } var downstreamMethod *MethodSpec for _, method := range downstreamService.Methods { if method.Name == clientThriftMethod { downstreamMethod = method break } } if downstreamMethod == nil { return errors.Errorf( "\n Downstream method '%s' is not found in '%s'", clientThriftMethod, clientModule.ThriftFile, ) } // Remove irrelevant services and methods. 
ms.Downstream = clientModule ms.DownstreamService = clientThriftService ms.DownstreamMethod = downstreamMethod return nil } func (ms *MethodSpec) setHeaderPropagator( reqHeaders []string, downstreamSpec *compile.FunctionSpec, headersPropagate map[string]FieldMapperEntry, h *PackageHelper, downstreamMethod *MethodSpec, ) error { downstreamStructType := compile.FieldGroup(downstreamSpec.ArgsSpec) hp := NewHeaderPropagator(h) hp.append( "func propagateHeaders", PascalCase(ms.Name), "ClientRequests(in ", downstreamMethod.RequestType, ", headers zanzibar.Header) ", downstreamMethod.RequestType, "{", ) hp.append("if in == nil {") hp.append(fmt.Sprintf(`in = %s{}`, strings.Replace(downstreamMethod.RequestType, "*", "&", 1))) hp.append("}") err := hp.Propagate(reqHeaders, downstreamStructType, headersPropagate) if err != nil { return err } hp.append("return in") hp.append("}") ms.PropagateHeadersGoStatements = hp.GetLines() return nil } func (ms *MethodSpec) setTypeConverters( funcSpec *compile.FunctionSpec, downstreamSpec *compile.FunctionSpec, reqTransforms map[string]FieldMapperEntry, headersPropagate map[string]FieldMapperEntry, respTransforms map[string]FieldMapperEntry, h *PackageHelper, downstreamMethod *MethodSpec, ) error { // TODO(sindelar): Iterate over fields that are structs (for foo/bar examples). 
// Add type checking and conversion, custom mapping structType := compile.FieldGroup(funcSpec.ArgsSpec) downstreamStructType := compile.FieldGroup(downstreamSpec.ArgsSpec) typeConverter := NewTypeConverter(h, headersPropagate) typeConverter.append( "func convertTo", PascalCase(ms.Name), "ClientRequest(in ", ms.RequestType, ") ", downstreamMethod.RequestType, "{") typeConverter.append("out := &", downstreamMethod.ShortRequestType, "{}\n") err := typeConverter.GenStructConverter(structType, downstreamStructType, reqTransforms) if err != nil { return err } typeConverter.append("\nreturn out") typeConverter.append("}") ms.ConvertRequestGoStatements = typeConverter.GetLines() // TODO: support non-struct return types respType := funcSpec.ResultSpec.ReturnType downstreamRespType := downstreamMethod.CompiledThriftSpec.ResultSpec.ReturnType if respType == nil || downstreamRespType == nil { return nil } respConverter := NewTypeConverter(h, nil) respConverter.append( "func convert", PascalCase(ms.DownstreamService), PascalCase(ms.Name), "ClientResponse(in ", downstreamMethod.ResponseType, ") ", ms.ResponseType, "{") var respFields, downstreamRespFields []*compile.FieldSpec switch respType.(type) { case *compile.BoolSpec, *compile.I8Spec, *compile.I16Spec, *compile.I32Spec, *compile.EnumSpec, *compile.I64Spec, *compile.DoubleSpec, *compile.StringSpec: respConverter.append("out", " := in\t\n") default: // default as struct respFields = respType.(*compile.StructSpec).Fields downstreamRespFields = downstreamRespType.(*compile.StructSpec).Fields respConverter.append("out", " := ", "&", ms.ShortResponseType, "{}\t\n") err = respConverter.GenStructConverter(downstreamRespFields, respFields, respTransforms) if err != nil { return err } } respConverter.append("\nreturn out \t}") ms.ConvertResponseGoStatements = respConverter.GetLines() return nil } func (ms *MethodSpec) setClientlessTypeConverters( funcSpec *compile.FunctionSpec, reqTransforms map[string]FieldMapperEntry, 
headersPropagate map[string]FieldMapperEntry, respTransforms map[string]FieldMapperEntry, dummyReqTransforms map[string]FieldMapperEntry, h *PackageHelper, ) error { clientlessConverter := NewTypeConverter(h, nil) respType := funcSpec.ResultSpec.ReturnType clientlessConverter.append( "func convert", PascalCase(ms.Name), "DummyResponse(in ", ms.RequestType, ") ", ms.ResponseType, "{") structType := compile.FieldGroup(funcSpec.ArgsSpec) if respType == nil { return nil } switch respType.(type) { case *compile.BoolSpec, *compile.I8Spec, *compile.I16Spec, *compile.I32Spec, *compile.EnumSpec, *compile.I64Spec, *compile.DoubleSpec, *compile.StringSpec: // TODO: Add support for primitive type by mapping the first field from request to response return errors.Errorf( "clientless endpoints need a complex return type") default: // default as struct respFields := respType.(*compile.StructSpec).Fields clientlessConverter.append("out", " := ", "&", ms.ShortResponseType, "{}\t\n") err := clientlessConverter.GenStructConverter(structType, respFields, dummyReqTransforms) if err != nil { return err } } clientlessConverter.append("\nreturn out \t}") ms.ConvertClientlessRequestGoStatements = clientlessConverter.GetLines() return nil } func getQueryMethodForPrimitiveType(typeSpec compile.TypeSpec) string { var queryMethod string switch typeSpec.(type) { case *compile.BoolSpec: queryMethod = "GetQueryBool" case *compile.I8Spec: queryMethod = "GetQueryInt8" case *compile.I16Spec: queryMethod = "GetQueryInt16" case *compile.I32Spec: queryMethod = "GetQueryInt32" case *compile.I64Spec: queryMethod = "GetQueryInt64" case *compile.DoubleSpec: queryMethod = "GetQueryFloat64" case *compile.StringSpec: queryMethod = "GetQueryValue" case *compile.EnumSpec: queryMethod = "GetQueryValue" default: panic(fmt.Sprintf( "Unsupported type (%T) for %s as query string parameter", typeSpec, typeSpec.ThriftName(), )) } return queryMethod } func getQueryMethodForType(typeSpec compile.TypeSpec) string { var 
queryMethod string switch t := typeSpec.(type) { case *compile.ListSpec: queryMethod = getQueryMethodForPrimitiveType(compile.RootTypeSpec(t.ValueSpec)) + "List" case *compile.SetSpec: queryMethod = getQueryMethodForPrimitiveType(compile.RootTypeSpec(t.ValueSpec)) + "Set" default: queryMethod = getQueryMethodForPrimitiveType(typeSpec) } return queryMethod } func getQueryEncodeExprPrimitive(typeSpec compile.TypeSpec) string { var encodeExpression string _, isTypedef := typeSpec.(*compile.TypedefSpec) switch compile.RootTypeSpec(typeSpec).(type) { case *compile.BoolSpec: if isTypedef { encodeExpression = "strconv.FormatBool(bool(%s))" } else { encodeExpression = "strconv.FormatBool(%s)" } case *compile.I8Spec, *compile.I16Spec, *compile.I32Spec, *compile.I64Spec: encodeExpression = "strconv.FormatInt(int64(%s), 10)" case *compile.DoubleSpec: if isTypedef { encodeExpression = "strconv.FormatFloat(float64(%s), 'G', -1, 64)" } else { encodeExpression = "strconv.FormatFloat(%s, 'G', -1, 64)" } case *compile.StringSpec: if isTypedef { encodeExpression = "string(%s)" } else { encodeExpression = "%s" } case *compile.EnumSpec: encodeExpression = "(%s).String()" default: // This is intentional -- lets evaluate why we would want other types here before opening the flood gates panic(fmt.Sprintf( "Unsupported type (%T) for %s as query string parameter", typeSpec, typeSpec.ThriftName(), )) } return encodeExpression } func getQueryEncodeExpression(typeSpec compile.TypeSpec, valueName string) string { var encodeExpression string switch t := compile.RootTypeSpec(typeSpec).(type) { case *compile.ListSpec: encodeExpression = getQueryEncodeExprPrimitive(t.ValueSpec) case *compile.SetSpec: encodeExpression = getQueryEncodeExprPrimitive(t.ValueSpec) default: encodeExpression = getQueryEncodeExprPrimitive(typeSpec) } return fmt.Sprintf(encodeExpression, valueName) } // hasQueryParams - checks to see if either this field has a query-param annotation // or if this is a struct, some field in 
it has. // Caveat is that unannotated fields are considered Query Params IF the REST method // should not have a body (GET). This is an existing convenience afforded to callers func (ms *MethodSpec) hasQueryParams(field *compile.FieldSpec, defaultIsQuery bool) bool { httpRefAnnotation := field.Annotations[ms.annotations.HTTPRef] if strings.HasPrefix(httpRefAnnotation, queryAnnotationPrefix) { return true } // If it is a struct, recursively look to see if any of the fields are query params if container, ok := compile.RootTypeSpec(field.Type).(*compile.StructSpec); ok { visitor := func(goPrefix string, thriftPrefix string, field *compile.FieldSpec) bool { annotation := field.Annotations[ms.annotations.HTTPRef] if strings.HasPrefix(annotation, queryAnnotationPrefix) { return true } return annotation == "" && defaultIsQuery } return walkFieldGroups(container.Fields, visitor) } return httpRefAnnotation == "" && defaultIsQuery } // getContainedQueryParams - finds all query params of interest in this field // In the case of structs, it recursively drills down func (ms *MethodSpec) getContainedQueryParams( field *compile.FieldSpec, defaultIsQuery bool, defaultPrefix string) []string { rval := []string{} myDefaultParam := defaultPrefix + strings.ToLower(field.Name) annotation := field.Annotations[ms.annotations.HTTPRef] if strings.HasPrefix(annotation, queryAnnotationPrefix) { rval = append(rval, strings.TrimPrefix(annotation, queryAnnotationPrefix)) } else if defaultIsQuery && annotation == "" { rval = append(rval, myDefaultParam) } // If it is a struct, look to see if any of the fields are query params if container, ok := compile.RootTypeSpec(field.Type).(*compile.StructSpec); ok { for _, subField := range container.Fields { rval = append(rval, ms.getContainedQueryParams(subField, defaultIsQuery, myDefaultParam+".")...) 
} } return rval } func (ms *MethodSpec) setWriteQueryParamStatements( funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, hasNoBody bool, ) error { var statements LineBuilder var hasQueryFields bool var stack []string isVoidReturn := funcSpec.ResultSpec.ReturnType == nil visitor := func( goPrefix string, thriftPrefix string, field *compile.FieldSpec, ) bool { // Skip if there are no query params in the field or its components if !ms.hasQueryParams(field, hasNoBody) { return false } if !hasQueryFields { statements.append("queryValues := &url.Values{}") hasQueryFields = true } realType := compile.RootTypeSpec(field.Type) longFieldName := goPrefix + "." + PascalCase(field.Name) if len(stack) > 0 { if !strings.HasPrefix(longFieldName, stack[len(stack)-1]) { stack = stack[:len(stack)-1] statements.append("}") } } if _, ok := realType.(*compile.StructSpec); ok { // If a field is a struct we need to look inside if field.Required { statements.appendf("if r%s == nil {", longFieldName) // Generate correct number of nils... 
if isVoidReturn { statements.append("\treturn ctx, nil, errors.New(") } else { statements.append("\treturn ctx, nil, nil, errors.New(") } statements.appendf("\t\t\"The field %s is required\",", longFieldName, ) statements.append("\t)") statements.append("}") } else { stack = append(stack, longFieldName) statements.appendf("if r%s != nil {", longFieldName) } return false } longQueryName, shortQueryParam := ms.getQueryParamInfo(field, thriftPrefix) identifierName := CamelCase(longQueryName) + "Query" _, isList := realType.(*compile.ListSpec) _, isSet := realType.(*compile.SetSpec) if field.Required { if isList { encodeExpr := getQueryEncodeExpression(field.Type, "value") statements.appendf("for _, value := range %s {", "r"+longFieldName) statements.appendf("\tqueryValues.Add(\"%s\", %s)", shortQueryParam, encodeExpr) statements.append("}") } else if isSet { encodeExpr := getQueryEncodeExpression(field.Type, "value") statements.appendf("for value := range %s {", "r"+longFieldName) statements.appendf("\tqueryValues.Add(\"%s\", %s)", shortQueryParam, encodeExpr) statements.append("}") } else { encodeExpr := getQueryEncodeExpression(field.Type, "r"+longFieldName) statements.appendf("%s := %s", identifierName, encodeExpr) statements.appendf("queryValues.Set(\"%s\", %s)", shortQueryParam, identifierName) } } else { statements.appendf("if r%s != nil {", longFieldName) if isList { encodeExpr := getQueryEncodeExpression(field.Type, "value") statements.appendf("for _, value := range %s {", "r"+longFieldName) statements.appendf("\tqueryValues.Add(\"%s\", %s)", shortQueryParam, encodeExpr) statements.append("}") } else if isSet { encodeExpr := getQueryEncodeExpression(field.Type, "value") statements.appendf("for value := range %s {", "r"+longFieldName) statements.appendf("\tqueryValues.Add(\"%s\", %s)", shortQueryParam, encodeExpr) statements.append("}") } else { encodeExpr := getQueryEncodeExpression(field.Type, "*r"+longFieldName) statements.appendf("\t%s := %s", 
identifierName, encodeExpr) statements.appendf("\tqueryValues.Set(\"%s\", %s)", shortQueryParam, identifierName) } statements.append("}") } return false } walkFieldGroups(compile.FieldGroup(funcSpec.ArgsSpec), visitor) for i := 0; i < len(stack); i++ { statements.append("}") } if hasQueryFields { statements.append("fullURL += \"?\" + queryValues.Encode()") } ms.WriteQueryParamGoStatements = statements.GetLines() return nil } // makeUniqIdent appends an integer to the identifier name if there is duplication already // The reason for this is to disambiguate a query param "deviceID" from "device_ID" - yes people did do that func makeUniqIdent(identifier string, seen map[string]int) string { count := seen[identifier] seen[identifier] = count + 1 if count > 0 { return identifier + strconv.Itoa(count) } return identifier } func getCustomType(pkgHelper *PackageHelper, itemType compile.TypeSpec) (string, error) { switch itemType.(type) { case *compile.TypedefSpec, *compile.EnumSpec: return GoType(pkgHelper, itemType) } return "", nil } func (ms *MethodSpec) setParseQueryParamStatements( funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, hasNoBody bool, ) error { // If a thrift field has a http.ref annotation then we // should not read this field from query parameters. var statements LineBuilder var finalError error var stack = []string{} seenIdents := map[string]int{} visitor := func( goPrefix string, thriftPrefix string, field *compile.FieldSpec, ) bool { realType := compile.RootTypeSpec(field.Type) longFieldName := goPrefix + "." 
+ PascalCase(field.Name) longQueryName, shortQueryParam := ms.getQueryParamInfo(field, thriftPrefix) // Skip if there are no query params in the field or its components if !ms.hasQueryParams(field, hasNoBody) { return false } if len(stack) > 0 { if !strings.HasPrefix(longFieldName, stack[len(stack)-1]) { stack = stack[:len(stack)-1] statements.append("}") } } customType, err := getCustomType(packageHelper, field.Type) if err != nil { finalError = err return true } var isList, isSet bool var customElemType string var isEnumElem bool switch t := realType.(type) { // Before you ask -- yes duplicated code because ValueSpec is not defined in the generic interface case *compile.ListSpec: isList = true customElemType, err = getCustomType(packageHelper, t.ValueSpec) if err != nil { finalError = err return true } _, isEnumElem = t.ValueSpec.(*compile.EnumSpec) case *compile.SetSpec: isSet = true customElemType, err = getCustomType(packageHelper, t.ValueSpec) if err != nil { finalError = err return true } _, isEnumElem = t.ValueSpec.(*compile.EnumSpec) case *compile.StructSpec: typeName, err := GoType(packageHelper, realType) if err != nil { finalError = err return true } if !field.Required { stack = append(stack, longFieldName) applicableQueryParams := ms.getContainedQueryParams(field, hasNoBody, "") statements.append("var _queryNeeded bool") statements.appendf("for _, _pfx := range %#v {", applicableQueryParams) statements.append("if _queryNeeded = req.HasQueryPrefix(_pfx); _queryNeeded {") statements.append("break") statements.append("}") statements.append("}") statements.append("if _queryNeeded {") } statements.appendf("if requestBody%s == nil {", longFieldName) statements.appendf("requestBody%s = &%s{}", longFieldName, typeName) statements.append("}") return false } isAggregate := isList || isSet // we do not support maps // For disambiguation of similar names baseIdent := makeUniqIdent(CamelCase(longQueryName), seenIdents) identifierName := baseIdent + "Query" 
okIdentifierName := baseIdent + "Ok" // make sure value is present if field.Required { statements.appendf("%s := req.CheckQueryValue(%q)", okIdentifierName, shortQueryParam) statements.appendf("if !%s {", okIdentifierName) statements.append("return ctx") statements.append("}") } else { statements.appendf("%s := req.HasQueryValue(%q)", okIdentifierName, shortQueryParam) statements.appendf("if %s {", okIdentifierName) } queryRValue := fmt.Sprintf("req.%s(%q)", getQueryMethodForType(realType), shortQueryParam) // Transform if enum if _, isEnumType := field.Type.(*compile.EnumSpec); isEnumType { statements.appendf("var %s %s", identifierName, customType) tmpVar := "_tmp" + identifierName statements.appendf("%s, ok := %s", tmpVar, queryRValue) statements.append("if ok {") statements.appendf("if err := %s.UnmarshalText([]byte(%s)); err != nil {", identifierName, tmpVar) statements.appendf("req.LogAndSendQueryError(err, %q, %q, %s)", "enum", shortQueryParam, tmpVar) statements.append("ok = false") statements.append("}") statements.append("}") } else { statements.appendf("%s, ok := %s", identifierName, queryRValue) } statements.append("if !ok {") statements.append("return ctx") statements.append("}") target := identifierName // If field is an "aggregate" with custom element types, we need to convert them first // Note that enums and typedefs are what get in here if customElemType != "" { target += "Final" valVar := "v" if isList { statements.appendf( "%s := make([]%s, len(%s))", target, customElemType, identifierName, ) statements.appendf("for i, %s := range %s {", valVar, identifierName) if isEnumElem { tmpVar := "_tmp" + valVar statements.appendf("var %s %s", tmpVar, customElemType) statements.appendf("if err := %s.UnmarshalText([]byte(%s)); err != nil {", tmpVar, valVar) statements.appendf("req.LogAndSendQueryError(err, %q, %q, %s)", "enum", shortQueryParam, valVar) statements.append("return ctx") statements.append("}") valVar = tmpVar } statements.appendf("%s[i] = 
%s(%s)", target, customElemType, valVar) statements.append("}") } else if isSet { statements.appendf( "%s := make(map[%s]struct{}, len(%s))", target, customElemType, identifierName, ) statements.appendf("for %s := range %s {", valVar, identifierName) if isEnumElem { tmpVar := "_tmp" + valVar statements.appendf("var %s %s", tmpVar, customElemType) statements.appendf("if err := %s.UnmarshalText([]byte(%s)); err != nil {", tmpVar, valVar) statements.appendf("req.LogAndSendQueryError(err, %q, %q, %s)", "enum", shortQueryParam, valVar) statements.append("return ctx") statements.append("}") valVar = tmpVar } statements.appendf("%s[%s(%s)] = struct{}{}", target, customElemType, valVar) statements.append("}") } } var deref string if !field.Required && !isAggregate { deref = "*" targetName := identifierName if customType != "" { targetName = fmt.Sprintf("%s(%s)", strings.ToLower(pointerMethodType(realType)), targetName) } target = fmt.Sprintf("ptr.%s(%s)", pointerMethodType(realType), targetName) } if customType != "" { target = fmt.Sprintf("(%s%s)(%s)", deref, customType, target) } statements.appendf("requestBody%s = %s", longFieldName, target) if !field.Required { statements.append("}") } // new line after block. 
statements.append("") return false } walkFieldGroups(compile.FieldGroup(funcSpec.ArgsSpec), visitor) for i := 0; i < len(stack); i++ { statements.append("}") } if finalError != nil { return finalError } ms.ParseQueryParamGoStatements = statements.GetLines() return nil } // getQueryParamInfo -- returns the fully-qualified query name and the query param // The query param is what is specified in the annotation if present, otherwise it is the same as the long query name func (ms *MethodSpec) getQueryParamInfo(field *compile.FieldSpec, thriftPrefix string) (string, string) { var longQueryName, queryParam string queryName := field.Name queryAnnotation := field.Annotations[ms.annotations.HTTPRef] if strings.HasPrefix(queryAnnotation, queryAnnotationPrefix) { queryName = strings.TrimPrefix(queryAnnotation, queryAnnotationPrefix) queryParam = queryName } longQueryName = strings.TrimPrefix(thriftPrefix+"."+queryName, ".") // default the short query param to the fully qualified long path if queryParam == "" { queryParam = longQueryName } return longQueryName, queryParam } func (ms *MethodSpec) isRequestBoxed(f *compile.FunctionSpec) bool { boxed, ok := f.Annotations[ms.annotations.HTTPReqDefBoxed] return ok && boxed == "true" } func (ms *MethodSpec) isBodyDisallowed(f *compile.FieldSpec) bool { val, ok := f.Annotations[ms.annotations.HTTPResNoBody] return ok && val == "true" } func headers(annotation string) []string { if annotation == "" { return nil } return strings.Split(annotation, ",") }