internal/apijson/enum.go (63 lines of code) (raw):
package apijson
import (
"fmt"
"reflect"
"sync"
)
/********************/
/* Validating Enums */
/********************/
// validationEntry describes the allowed values of a single struct field, as
// recorded by RegisterFieldValidator.
type validationEntry struct {
// field is the struct field the entry applies to, resolved by name on the
// registered parent type. Its Index is used to locate the field on values.
field reflect.StructField
// nullable is the flag supplied at registration time. Nothing in this part
// of the file reads it.
nullable bool
// legalValues holds the allowed values, each wrapped with reflect.ValueOf.
// validateEnum compares the field against them using reflect.Value.Equal.
legalValues []reflect.Value
}
// validatorFunc inspects a reflected value and reports an exactness for it.
type validatorFunc func(reflect.Value) exactness
// validators caches validator functions keyed by reflect.Type: typeValidator
// stores the validatorFunc it builds for a type here and reuses it afterwards.
var validators sync.Map
// validationRegistry maps a struct type to the field entries registered for it.
// It is written by RegisterFieldValidator and read by typeValidator. It is a
// plain map with no locking in this file.
// NOTE(review): presumably all registration happens during initialization,
// before any concurrent reads — confirm against callers.
var validationRegistry = map[reflect.Type][]validationEntry{}
// RegisterFieldValidator records the legal values of the field named fieldName
// on struct type T, so that a validator for T can be built lazily the first
// time one is requested.
//
// fieldName is the Go field name and is resolved with reflect's FieldByName.
// nullable is stored with the entry. values lists the values the field is
// allowed to hold; registering several fields for the same T appends one entry
// per call.
//
// It panics if T is not a struct type or has no field called fieldName. The
// registry is only modified once both checks have passed, so a failed
// registration leaves no partial state behind. The registry is a plain map
// without locking, so this function is meant to be called at initialization
// time.
func RegisterFieldValidator[T any, V string | bool | int](fieldName string, nullable bool, values ...V) {
	// Derive the type from a nil *T rather than from a zero T: for an interface
	// type the zero value is a nil interface, reflect.TypeOf would return nil
	// for it, and the Kind call below would crash with a nil dereference
	// instead of reporting the problem.
	parentType := reflect.TypeOf((*T)(nil)).Elem()

	// These checks are expected to run at initialization time, so a bad
	// registration fails immediately in any test run rather than in production.
	if parentType.Kind() != reflect.Struct {
		panic(fmt.Sprintf("apijson: cannot initialize validator for non-struct %s", parentType.String()))
	}
	field, found := parentType.FieldByName(fieldName)
	if !found {
		panic(fmt.Sprintf("apijson: cannot initialize validator for unknown field %q in %s", fieldName, parentType.String()))
	}

	legalValues := make([]reflect.Value, len(values))
	for i, value := range values {
		legalValues[i] = reflect.ValueOf(value)
	}

	// Store the information necessary to create a validator, so that the
	// validator function can be created lazily when it is needed. Appending to
	// the nil slice of a type seen for the first time works, so the map entry
	// needs no explicit initialization.
	validationRegistry[parentType] = append(validationRegistry[parentType], validationEntry{
		field:       field,
		nullable:    nullable,
		legalValues: legalValues,
	})
}
// typeValidator returns the validator for type t, or nil when no field
// validators have been registered for t. Enums are the only types which are
// validated.
//
// The validator is built on first use and cached in validators; when two
// callers build one at the same time, LoadOrStore makes both return whichever
// function was stored first.
func typeValidator(t reflect.Type) validatorFunc {
	entries, registered := validationRegistry[t]
	if !registered {
		return nil
	}

	cached, hit := validators.Load(t)
	if !hit {
		var fresh validatorFunc = func(v reflect.Value) exactness {
			return validateEnum(v, entries)
		}
		cached, _ = validators.LoadOrStore(t, fresh)
	}
	return cached.(validatorFunc)
}
// validateEnum checks the registered fields of struct value v against their
// legal values.
//
// It returns exact as soon as any registered field holds one of its legal
// values (compared with reflect.Value.Equal), visiting entries in slice order.
// It returns loose when v is not a struct, when a looked-up field is not a
// valid reflect.Value, or when no field matches.
func validateEnum(v reflect.Value, entry []validationEntry) exactness {
	if v.Kind() != reflect.Struct {
		return loose
	}

	for i := range entry {
		candidate := v.FieldByIndex(entry[i].field.Index)
		if !candidate.IsValid() {
			return loose
		}

		allowed := entry[i].legalValues
		for j := range allowed {
			if candidate.Equal(allowed[j]) {
				return exact
			}
		}
	}

	return loose
}