internal/apiquery/encoder.go (369 lines of code) (raw):

package apiquery import ( "encoding/json" "fmt" "reflect" "strconv" "strings" "sync" "time" internalparam "github.com/openai/openai-go/internal/param" "github.com/openai/openai-go/packages/param" ) var encoders sync.Map // map[reflect.Type]encoderFunc type encoder struct { dateFormat string root bool settings QuerySettings } type encoderFunc func(key string, value reflect.Value) ([]Pair, error) type encoderField struct { tag parsedStructTag fn encoderFunc idx []int } type encoderEntry struct { reflect.Type dateFormat string root bool settings QuerySettings } type Pair struct { key string value string } func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { entry := encoderEntry{ Type: t, dateFormat: e.dateFormat, root: e.root, settings: e.settings, } if fi, ok := encoders.Load(entry); ok { return fi.(encoderFunc) } // To deal with recursive types, populate the map with an // indirect func before we build it. This type waits on the // real func (f) to be ready and then calls it. This indirect // func is only used for recursive types. var ( wg sync.WaitGroup f encoderFunc ) wg.Add(1) fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) ([]Pair, error) { wg.Wait() return f(key, v) })) if loaded { return fi.(encoderFunc) } // Compute the real encoder and replace the indirect func with it. f = e.newTypeEncoder(t) wg.Done() encoders.Store(entry, f) return f } func marshalerEncoder(key string, value reflect.Value) ([]Pair, error) { s, err := value.Interface().(json.Marshaler).MarshalJSON() if err != nil { return nil, fmt.Errorf("apiquery: json fallback marshal error %s", err) } return []Pair{{key, string(s)}}, nil } func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { return e.newTimeTypeEncoder(t) } if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { return e.newRichFieldTypeEncoder(t) } if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { return marshalerEncoder } e.root = false switch t.Kind() { case reflect.Pointer: encoder := e.typeEncoder(t.Elem()) return func(key string, value reflect.Value) (pairs []Pair, err error) { if !value.IsValid() || value.IsNil() { return } return encoder(key, value.Elem()) } case reflect.Struct: return e.newStructTypeEncoder(t) case reflect.Array: fallthrough case reflect.Slice: return e.newArrayTypeEncoder(t) case reflect.Map: return e.newMapEncoder(t) case reflect.Interface: return e.newInterfaceEncoder() default: return e.newPrimitiveTypeEncoder(t) } } func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { if t.Implements(reflect.TypeOf((*internalparam.FieldLike)(nil)).Elem()) { return e.newFieldTypeEncoder(t) } if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { return e.newRichFieldTypeEncoder(t) } for i := 0; i < t.NumField(); i++ { if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous { return e.newStructUnionTypeEncoder(t) } } encoderFields := []encoderField{} // This helper allows us to recursively collect field encoders into a flat // array. The parameter `index` keeps track of the access patterns necessary // to get to some field. var collectEncoderFields func(r reflect.Type, index []int) collectEncoderFields = func(r reflect.Type, index []int) { for i := 0; i < r.NumField(); i++ { idx := append(index, i) field := t.FieldByIndex(idx) if !field.IsExported() { continue } // If this is an embedded struct, traverse one level deeper to extract // the field and get their encoders as well. if field.Anonymous { collectEncoderFields(field.Type, idx) continue } // If query tag is not present, then we skip, which is intentionally // different behavior from the stdlib. ptag, ok := parseQueryStructTag(field) if !ok { continue } if (ptag.name == "-" || ptag.name == "") && !ptag.inline { continue } dateFormat, ok := parseFormatStructTag(field) oldFormat := e.dateFormat if ok { switch dateFormat { case "date-time": e.dateFormat = time.RFC3339 case "date": e.dateFormat = "2006-01-02" } } var encoderFn encoderFunc if ptag.omitzero { typeEncoderFn := e.typeEncoder(field.Type) encoderFn = func(key string, value reflect.Value) ([]Pair, error) { if value.IsZero() { return nil, nil } return typeEncoderFn(key, value) } } else { encoderFn = e.typeEncoder(field.Type) } encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx}) e.dateFormat = oldFormat } } collectEncoderFields(t, []int{}) return func(key string, value reflect.Value) (pairs []Pair, err error) { for _, ef := range encoderFields { var subkey string = e.renderKeyPath(key, ef.tag.name) if ef.tag.inline { subkey = key } field := value.FieldByIndex(ef.idx) subpairs, suberr := ef.fn(subkey, field) if suberr != nil { err = suberr } pairs = append(pairs, subpairs...) } return } } var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem() func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc { var fieldEncoders []encoderFunc for i := 0; i < t.NumField(); i++ { field := t.Field(i) if field.Type == paramUnionType && field.Anonymous { fieldEncoders = append(fieldEncoders, nil) continue } fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type)) } return func(key string, value reflect.Value) (pairs []Pair, err error) { for i := 0; i < t.NumField(); i++ { if value.Field(i).Type() == paramUnionType { continue } if !value.Field(i).IsZero() { return fieldEncoders[i](key, value.Field(i)) } } return nil, fmt.Errorf("apiquery: union %s has no field set", t.String()) } } func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc { keyEncoder := e.typeEncoder(t.Key()) elementEncoder := e.typeEncoder(t.Elem()) return func(key string, value reflect.Value) (pairs []Pair, err error) { iter := value.MapRange() for iter.Next() { encodedKey, err := keyEncoder("", iter.Key()) if err != nil { return nil, err } if len(encodedKey) != 1 { return nil, fmt.Errorf("apiquery: unexpected number of parts for encoded map key, map may contain non-primitive") } subkey := encodedKey[0].value keyPath := e.renderKeyPath(key, subkey) subpairs, suberr := elementEncoder(keyPath, iter.Value()) if suberr != nil { err = suberr } pairs = append(pairs, subpairs...) } return } } func (e *encoder) renderKeyPath(key string, subkey string) string { if len(key) == 0 { return subkey } if e.settings.NestedFormat == NestedQueryFormatDots { return fmt.Sprintf("%s.%s", key, subkey) } return fmt.Sprintf("%s[%s]", key, subkey) } func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { switch e.settings.ArrayFormat { case ArrayQueryFormatComma: innerEncoder := e.typeEncoder(t.Elem()) return func(key string, v reflect.Value) ([]Pair, error) { elements := []string{} for i := 0; i < v.Len(); i++ { innerPairs, err := innerEncoder("", v.Index(i)) if err != nil { return nil, err } for _, pair := range innerPairs { elements = append(elements, pair.value) } } if len(elements) == 0 { return []Pair{}, nil } return []Pair{{key, strings.Join(elements, ",")}}, nil } case ArrayQueryFormatRepeat: innerEncoder := e.typeEncoder(t.Elem()) return func(key string, value reflect.Value) (pairs []Pair, err error) { for i := 0; i < value.Len(); i++ { subpairs, suberr := innerEncoder(key, value.Index(i)) if suberr != nil { err = suberr } pairs = append(pairs, subpairs...) } return } case ArrayQueryFormatIndices: panic("The array indices format is not supported yet") case ArrayQueryFormatBrackets: innerEncoder := e.typeEncoder(t.Elem()) return func(key string, value reflect.Value) (pairs []Pair, err error) { pairs = []Pair{} for i := 0; i < value.Len(); i++ { subpairs, suberr := innerEncoder(key+"[]", value.Index(i)) if suberr != nil { err = suberr } pairs = append(pairs, subpairs...) } return } default: panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat)) } } func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { switch t.Kind() { case reflect.Pointer: inner := t.Elem() innerEncoder := e.newPrimitiveTypeEncoder(inner) return func(key string, v reflect.Value) ([]Pair, error) { if !v.IsValid() || v.IsNil() { return nil, nil } return innerEncoder(key, v.Elem()) } case reflect.String: return func(key string, v reflect.Value) ([]Pair, error) { return []Pair{{key, v.String()}}, nil } case reflect.Bool: return func(key string, v reflect.Value) ([]Pair, error) { if v.Bool() { return []Pair{{key, "true"}}, nil } return []Pair{{key, "false"}}, nil } case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: return func(key string, v reflect.Value) ([]Pair, error) { return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}, nil } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: return func(key string, v reflect.Value) ([]Pair, error) { return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}, nil } case reflect.Float32, reflect.Float64: return func(key string, v reflect.Value) ([]Pair, error) { return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}}, nil } case reflect.Complex64, reflect.Complex128: bitSize := 64 if t.Kind() == reflect.Complex128 { bitSize = 128 } return func(key string, v reflect.Value) ([]Pair, error) { return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}, nil } default: return func(key string, v reflect.Value) ([]Pair, error) { return nil, nil } } } func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc { f, _ := t.FieldByName("Value") enc := e.typeEncoder(f.Type) return func(key string, value reflect.Value) ([]Pair, error) { present := value.FieldByName("Present") if !present.Bool() { return nil, nil } null := value.FieldByName("Null") if null.Bool() { return nil, fmt.Errorf("apiquery: field cannot be null") } raw := value.FieldByName("Raw") if !raw.IsNil() { return e.typeEncoder(raw.Type())(key, raw) } return enc(key, value.FieldByName("Value")) } } func (e *encoder) newTimeTypeEncoder(_ reflect.Type) encoderFunc { format := e.dateFormat return func(key string, value reflect.Value) ([]Pair, error) { return []Pair{{ key, value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format), }}, nil } } func (e encoder) newInterfaceEncoder() encoderFunc { return func(key string, value reflect.Value) ([]Pair, error) { value = value.Elem() if !value.IsValid() { return nil, nil } return e.typeEncoder(value.Type())(key, value) } }