validator.go (371 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package ucfg

import (
	"fmt"
	"reflect"
	"regexp"
	"strconv"
	"strings"
	"time"
)

// Validator interface provides additional validation support to Unpack. The
// Validate method will be executed for any type passed directly or indirectly
// to Unpack.
//
// If Validate fails with an error, Unpack will add some context - such as the
// setting being accessed and the file the setting was read from - to the error
// message before returning the actual error.
type Validator interface {
	// Validate reports whether the receiver holds an acceptable value.
	// A nil return value means the value is valid.
	Validate() error
}

// ValidatorCallback is the type of optional validator tags to be registered via
// RegisterValidator.
type ValidatorCallback func(interface{}, string) error type validatorTag struct { name string cb ValidatorCallback param string } var ( validators = map[string]ValidatorCallback{} ) func init() { initRegisterValidator("nonzero", validateNonZero) initRegisterValidator("positive", validatePositive) initRegisterValidator("min", validateMin) initRegisterValidator("max", validateMax) initRegisterValidator("required", validateRequired) } func initRegisterValidator(name string, cb ValidatorCallback) { if err := RegisterValidator(name, cb); err != nil { panic("Duplicate validator: " + name) } } // RegisterValidator adds a new validator option to the "validate" struct tag. // The callback will be executed when unpacking into a struct field. func RegisterValidator(name string, cb ValidatorCallback) error { if _, exists := validators[name]; exists { return ErrDuplicateValidator } validators[name] = cb return nil } func parseValidatorTags(tag string) ([]validatorTag, error) { if tag == "" { return nil, nil } lst := strings.Split(tag, ",") if len(lst) == 0 { return nil, nil } tags := make([]validatorTag, 0, len(lst)) for _, cfg := range lst { v := strings.SplitN(cfg, "=", 2) name := strings.Trim(v[0], " \t\r\n") cb := validators[name] if cb == nil { return nil, fmt.Errorf("unknown validator '%v'", name) } param := "" if len(v) == 2 { param = strings.Trim(v[1], " \t\r\n") } tags = append(tags, validatorTag{name: name, cb: cb, param: param}) } return tags, nil } func tryValidate(val reflect.Value) error { t := val.Type() var validator Validator if (t.Kind() == reflect.Ptr || t.Kind() == reflect.Interface) && val.IsNil() { return nil } if t.Implements(tValidator) { validator = val.Interface().(Validator) } else if reflect.PtrTo(t).Implements(tValidator) { val = pointerize(reflect.PtrTo(t), t, val) validator = val.Interface().(Validator) } if validator == nil { return nil } return validator.Validate() } func runValidators(val interface{}, validators []validatorTag) error { if 
validators == nil { return nil } for _, tag := range validators { if err := tag.cb(val, tag.param); err != nil { return err } } return nil } func tryRecursiveValidate(val reflect.Value, opts *options, validators []validatorTag) error { var curr interface{} if val.IsValid() { curr = val.Interface() } if err := runValidators(curr, validators); err != nil { return err } if !val.IsValid() { return nil } t := val.Type() if (t.Kind() == reflect.Ptr || t.Kind() == reflect.Interface) && val.IsNil() { return nil } var err error switch chaseValue(val).Kind() { case reflect.Struct: err = validateStruct(val, opts) case reflect.Map: err = validateMap(val, opts) case reflect.Array, reflect.Slice: err = validateArray(val, opts) } if err != nil { return err } return tryValidate(val) } func validateStruct(val reflect.Value, opts *options) error { val = chaseValue(val) numField := val.NumField() for i := 0; i < numField; i++ { fInfo, skip, err := accessField(val, i, opts) if err != nil { return err } if skip { continue } if err := tryRecursiveValidate(fInfo.value, fInfo.options, fInfo.validatorTags); err != nil { return err } } return nil } func validateMap(val reflect.Value, opts *options) error { for _, key := range val.MapKeys() { if err := tryRecursiveValidate(val.MapIndex(key), opts, nil); err != nil { return err } } return nil } func validateArray(val reflect.Value, opts *options) error { for i := 0; i < val.Len(); i++ { if err := tryRecursiveValidate(val.Index(i), opts, nil); err != nil { return err } } return nil } // validateNonZero implements the `nonzero` validation tag. // If nonzero is set, the validator is only run if field is present in config. // It checks for numbers and durations to be != 0, and for strings/arrays/slices // not being empty. 
func validateNonZero(v interface{}, name string) error { if v == nil { return nil } if d, ok := v.(time.Duration); ok { if d == 0 { return ErrZeroValue } return nil } val := chaseValue(reflect.ValueOf(v)) switch val.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if val.Int() != 0 { return nil } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if val.Uint() != 0 { return nil } case reflect.Float32, reflect.Float64: if val.Float() != 0 { return nil } default: return validateNonEmpty(v, name) } return ErrZeroValue } func validatePositive(v interface{}, _ string) error { if v == nil { return nil } if d, ok := v.(time.Duration); ok { if d < 0 { return ErrNegative } return nil } val := reflect.ValueOf(v) switch val.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if val.Int() >= 0 { return nil } case reflect.Float32, reflect.Float64: if val.Float() >= 0 { return nil } default: return nil } return ErrNegative } func validateMin(v interface{}, param string) error { if v == nil { return nil } if d, ok := v.(time.Duration); ok { min, err := param2Duration(param) if err != nil { return err } if min > d { return fmt.Errorf("requires duration >= %v", param) } return nil } val := reflect.ValueOf(v) switch val.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: min, err := strconv.ParseInt(param, 0, 64) if err != nil { return err } if val.Int() >= min { return nil } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: min, err := strconv.ParseUint(param, 0, 64) if err != nil { return err } if val.Uint() >= min { return nil } case reflect.Float32, reflect.Float64: min, err := strconv.ParseFloat(param, 64) if err != nil { return err } if val.Float() >= min { return nil } default: return nil } return fmt.Errorf("requires value >= %v", param) } func validateMax(v interface{}, param string) error { if v == nil { 
return nil } if d, ok := v.(time.Duration); ok { max, err := param2Duration(param) if err != nil { return err } if max < d { return fmt.Errorf("requires duration <= %v", param) } return nil } val := reflect.ValueOf(v) switch val.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: max, err := strconv.ParseInt(param, 0, 64) if err != nil { return err } if val.Int() <= max { return nil } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: max, err := strconv.ParseUint(param, 0, 64) if err != nil { return err } if val.Uint() <= max { return nil } case reflect.Float32, reflect.Float64: max, err := strconv.ParseFloat(param, 64) if err != nil { return err } if val.Float() <= max { return nil } default: return nil } return fmt.Errorf("requires value <= %v", param) } // validateRequired implements the `required` validation tag. // If a field is required, it must be present in the config. // If field is a string, regex or slice its length must be > 0. 
func validateRequired(v interface{}, name string) error { if v == nil { return ErrRequired } val := reflect.ValueOf(v) if val.Kind() == reflect.Ptr && val.IsNil() { return ErrRequired } if isInt(val.Kind()) || isUint(val.Kind()) || isFloat(val.Kind()) { if err := validateNonZero(v, name); err != nil { return ErrRequired } return nil } if err := validateNonEmptyWithAllowNil(v, name, false); err != nil { return err } return nil } func validateNonEmpty(v interface{}, name string) error { return validateNonEmptyWithAllowNil(v, name, true) } func validateNonEmptyWithAllowNil(v interface{}, _ string, allowNil bool) error { if s, ok := v.(string); ok { if s == "" { return ErrStringEmpty } return nil } if r, ok := v.(regexp.Regexp); ok { if r.String() == "" { return ErrRegexEmpty } return nil } val := reflect.ValueOf(v) if val.Kind() == reflect.Array || val.Kind() == reflect.Slice { if val.IsNil() { if allowNil { return nil } return ErrRequired } if val.Len() == 0 { return ErrArrayEmpty } return nil } if val.Kind() == reflect.Map { if val.IsNil() { if allowNil { return nil } return ErrRequired } if val.Len() == 0 { return ErrMapEmpty } return nil } return nil } func param2Duration(param string) (time.Duration, error) { d, err := time.ParseDuration(param) if err == nil { return d, err } tmp, floatErr := strconv.ParseFloat(param, 64) if floatErr != nil { return 0, err } return time.Duration(tmp * float64(time.Second)), nil }