codegen/header_propagate.go (149 lines of code) (raw):
// Copyright (c) 2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package codegen
import (
"fmt"
"sort"
"strings"
"github.com/pkg/errors"
"go.uber.org/thriftrw/compile"
)
// HeaderPropagator generates function propagates endpoint request
// headers to client request body
type HeaderPropagator struct {
LineBuilder
Helper PackageNameResolver
}
// NewHeaderPropagator returns an instance of HeaderPropagator
func NewHeaderPropagator(h PackageNameResolver) *HeaderPropagator {
return &HeaderPropagator{
LineBuilder: LineBuilder{},
Helper: h,
}
}
// Propagate assigns header value to downstream client request fields
// based on fieldMap
func (hp *HeaderPropagator) Propagate(
headers []string,
toFields []*compile.FieldSpec,
fieldMap map[string]FieldMapperEntry,
) error {
sortedKeys := make([]string, len(fieldMap))
i := 0
for key := range fieldMap {
sortedKeys[i] = key
i++
}
sort.Strings(sortedKeys)
for _, key := range sortedKeys {
val := fieldMap[key]
field, err := findField(key, toFields)
if err != nil {
return err
}
gotype, err := GoType(hp.Helper, field.Type)
if err != nil {
return errors.Errorf("invalid: trying to assign header %s to non-string field in %s",
val.QualifiedName, field.Name)
}
hp.appendf(`if key, ok := headers.Get("%s"); ok {`, val.QualifiedName)
// patch optional params along the path
if err := hp.initNilOpt(key, toFields); err != nil {
return err
}
arrs := typeSwitch(key, gotype, field)
hp.append(arrs...)
hp.append("}")
}
return nil
}
// typeSwitch supports primary type parsing for headers
func typeSwitch(key, gotype string, field *compile.FieldSpec) []string {
var (
ret = []string{}
typeParse string
typeCast string
assignVal = "key"
)
switch gotype {
case "int8":
panic(fmt.Sprintf("type byte is note supported for field %q", field.Name))
case "bool":
typeParse = "strconv.ParseBool(key)"
assignVal = "v"
case "int16":
typeParse = "strconv.ParseInt(key, 10, 16)"
assignVal = "val"
typeCast = "val := int16(v)\n"
case "int32":
typeParse = "strconv.ParseInt(key, 10, 32)"
assignVal = "val"
typeCast = "val := int32(v)\n"
case "int64":
typeParse = "strconv.ParseInt(key, 10, 64)"
assignVal = "v"
case "float64":
typeParse = "strconv.ParseFloat(key, 64)"
assignVal = "v"
case "string":
default:
typeCast = "val := " + gotype + "(key)\n"
assignVal = "val"
}
if len(typeParse) > 0 {
ret = append(ret, fmt.Sprintf("if v, err := %s; err == nil {\n", typeParse))
}
if len(typeCast) > 0 {
ret = append(ret, typeCast)
}
if !field.Required {
ret = append(ret, fmt.Sprintf("in.%s = &%s\n", key, assignVal))
} else {
ret = append(ret, fmt.Sprintf("in.%s = %s\n", key, assignVal))
}
if len(typeParse) > 0 {
ret = append(ret, "}\n")
}
return ret
}
// init optional field that could be nil on field assign path
func (hp *HeaderPropagator) initNilOpt(path string, toFields []*compile.FieldSpec) error {
initChecks := getMiddleIdentifiers(path)
if len(initChecks) < 2 {
return nil
}
initChecks = initChecks[:len(initChecks)-1]
for _, p := range initChecks {
f, err := findField(p, toFields)
if err != nil {
return err
}
ftype := f.Type
t, err := GoCustomType(hp.Helper, ftype)
if err != nil {
return errors.Wrapf(
err,
"could not lookup fieldType when building converter for %s",
ftype.ThriftName(),
)
}
hp.appendf("if in.%s == nil {", p)
hp.appendf("in.%s = &%s{}", p, t)
hp.append("}")
}
return nil
}
func findField(fieldPath string, toFields []*compile.FieldSpec) (*compile.FieldSpec, error) {
currPath := strings.Split(fieldPath, ".")
currFields := toFields
missErr := errors.Errorf("could not find field path in client request %s", fieldPath)
for len(currFields) > 0 && len(currPath) > 0 {
prevPath := []string(currPath)
currPos := currPath[0]
for _, v := range currFields {
if strings.ToLower(v.Name) == strings.ToLower(currPos) {
if len(currPath) == 1 {
return v, nil
}
currPath = currPath[1:]
t := v.Type.(*compile.StructSpec)
currFields = t.Fields
break
}
}
if len(prevPath) == len(currPath) {
return nil, missErr
}
}
return nil, missErr
}