internal/apijson/union.go (86 lines of code) (raw):
package apijson
import (
	"errors"
	"reflect"
	"sort"

	"github.com/openai/openai-go/packages/param"
	"github.com/tidwall/gjson"
)
// isEmbeddedUnion reports whether t is a struct type with an anonymous
// (embedded) field whose type is exactly param.APIUnion.
//
// Non-struct types are never embedded unions, so they yield false instead of
// making reflect panic inside NumField.
func isEmbeddedUnion(t reflect.Type) bool {
	if t.Kind() != reflect.Struct {
		return false
	}
	var marker param.APIUnion
	markerType := reflect.TypeOf(marker) // loop-invariant; compute once
	for i, n := 0, t.NumField(); i < n; i++ {
		if field := t.Field(i); field.Anonymous && field.Type == markerType {
			return true
		}
	}
	return false
}
// RegisterDiscriminatedUnion stores a unionEntry for T in unionRegistry.
//
// key is the JSON property name used as the discriminator. mappings maps each
// discriminator value to the reflect.Type of the corresponding variant.
//
// Variants are appended in sorted discriminator order, so the stored entry is
// the same on every run. Ranging over mappings directly would make the order
// depend on Go's randomized map iteration.
//
// The registry key is derived from *T rather than from a zero T, so it is
// correct even when T is an interface type. For an interface, reflect.TypeOf
// of a zero value is nil.
//
// NOTE(review): unionRegistry is written without any locking here; this
// presumably only runs during package initialization — confirm.
func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.Type) {
	discriminators := make([]string, 0, len(mappings))
	for value := range mappings {
		discriminators = append(discriminators, value)
	}
	sort.Strings(discriminators)

	entry := unionEntry{
		discriminatorKey: key,
		variants:         make([]UnionVariant, 0, len(mappings)),
	}
	for _, value := range discriminators {
		entry.variants = append(entry.variants, UnionVariant{
			DiscriminatorValue: value,
			Type:               mappings[value],
		})
	}
	unionRegistry[reflect.TypeOf((*T)(nil)).Elem()] = entry
}
// newEmbeddedUnionDecoder builds a decoder for a struct type t that embeds
// param.APIUnion (see isEmbeddedUnion). Every exported, non-embedded field of
// t is treated as one variant of the union.
//
// The decoder tries each variant against the JSON node using a fresh
// decoderState that starts at exact exactness:
//   - The first variant that decodes exactly wins immediately.
//   - Otherwise the variant with the highest exactness wins. Ties go to the
//     earlier field.
//
// The winner is stored in its own field of a freshly zeroed t, and that
// value is assigned to v. As a result, at most one variant field is populated.
// Assumes v holds a settable value of type t, as for other decoders built
// for t — confirm against callers.
//
// It returns:
//   - "apijson: was not able to coerce type as union" when no variant decodes
//     at least loosely (v is left untouched).
//   - "apijson: was not able to coerce type as union strictly" when the best
//     match is inexact and guardStrict reports failure for state. In this case
//     v has already been set to the best match.
func (d *decoderBuilder) newEmbeddedUnionDecoder(t reflect.Type) decoderFunc {
	// unionField is one union variant: a field of t and the decoder for its type.
	type unionField struct {
		index   int
		typ     reflect.Type
		decoder decoderFunc
	}

	fields := make([]unionField, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		// The embedded APIUnion marker is not a variant. Unexported fields
		// cannot be assigned through reflect, so they are skipped as well.
		if field.Anonymous || !field.IsExported() {
			continue
		}
		fields = append(fields, unionField{
			index:   i,
			typ:     field.Type,
			decoder: d.typeDecoder(field.Type),
		})
	}

	return func(n gjson.Result, v reflect.Value, state *decoderState) error {
		// Start worse than loose so that any successful decode is an improvement.
		bestExactness := loose - 1
		var best reflect.Value
		for _, f := range fields {
			sub := decoderState{strict: state.strict, exactness: exact}
			inner := reflect.New(f.typ).Elem()
			if err := f.decoder(n, inner, &sub); err != nil {
				continue
			}
			if sub.exactness <= bestExactness {
				continue
			}
			out := reflect.New(t).Elem()
			out.Field(f.index).Set(inner)
			best = out
			bestExactness = sub.exactness
			if bestExactness == exact {
				break // nothing can beat an exact match
			}
		}

		if bestExactness < loose {
			return errors.New("apijson: was not able to coerce type as union")
		}
		v.Set(best)
		if guardStrict(state, bestExactness != exact) {
			return errors.New("apijson: was not able to coerce type as union strictly")
		}
		return nil
	}
}