input/elasticapm/internal/modeldecoder/generator/validation.go (150 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package generator import ( "fmt" "io" "reflect" "sort" "strings" ) const ( tagEnum = "enum" tagInputTypes = "inputTypes" tagInputTypesVals = "inputTypesVals" tagMax = "max" tagMaxLength = "maxLength" tagMaxLengthVals = "maxLengthVals" tagMin = "min" tagMinLength = "minLength" tagMinVals = "minVals" tagPattern = "pattern" tagPatternKeys = "patternKeys" tagRequired = "required" tagRequiredAnyOf = "requiredAnyOf" tagRequiredIfAny = "requiredIfAny" tagTargetType = "targetType" ) type validationRule struct { name string value string } func errUnhandledTagRule(rule validationRule) error { return fmt.Errorf("unhandled tag rule '%s'", rule.name) } func validationTag(structTag reflect.StructTag) (map[string]string, error) { parts := parseTag(structTag, "validate") m := make(map[string]string, len(parts)) errPrefix := "parse validation tag:" for _, rule := range parts { parts := strings.Split(rule, "=") switch len(parts) { case 1: // valueless rule e.g. required if rule != parts[0] { return nil, fmt.Errorf("%s malformed tag '%s'", errPrefix, rule) } switch rule { case tagRequired: m[rule] = "" default: return nil, fmt.Errorf("%s unhandled tag rule '%s'", errPrefix, rule) } case 2: // rule=value m[parts[0]] = parts[1] default: return nil, fmt.Errorf("%s malformed tag '%s'", errPrefix, rule) } } return m, nil } func validationRules(structTag reflect.StructTag) ([]validationRule, error) { tag, err := validationTag(structTag) if err != nil { return nil, err } var rules = make([]validationRule, 0, len(tag)) for k, v := range tag { rules = append(rules, validationRule{name: k, value: v}) } sort.Slice(rules, func(i, j int) bool { return rules[i].name < rules[j].name }) return rules, nil } func ruleMinMaxOperator(ruleName string) string { switch ruleName { case tagMin, tagMinLength, tagMinVals: return "<" case tagMax, tagMaxLength: return ">" default: panic("unexpected rule: " + ruleName) } } // // common validation rules independend of type // func ruleNullableRequired(w io.Writer, f structField) { fmt.Fprintf(w, ` if !val.%s.IsSet() { return fmt.Errorf("'%s' required") } `[1:], f.Name(), jsonName(f)) } func ruleRequiredOneOf(w io.Writer, fields []structField, tagValue string) error { oneOf, err := filteredFields(fields, strings.Split(tagValue, ";")) if err != nil { return err } if len(oneOf) <= 1 { return fmt.Errorf("invalid usage of rule 'requiredOneOf' - try 'required' instead") } fmt.Fprintf(w, `if `) for i, oneOfField := range oneOf { if i > 0 { fmt.Fprintf(w, " && ") } fmt.Fprint(w, "!") if err := generateIsSet(w, oneOfField, "val."); err != nil { return err } } fmt.Fprintf(w, ` { return fmt.Errorf("requires at least one of the fields '%v'") } `[1:], tagValue) return nil } func ruleRequiredIfAny(w io.Writer, fields []structField, field structField, tagValue string) error { ifAny, err := filteredFields(fields, strings.Split(tagValue, ";")) if err != nil { return err } // Only check ifAny fields if the field itself is not set fmt.Fprint(w, "if !") if err := generateIsSet(w, field, "val."); err != nil { return err } fmt.Fprintln(w, " {") // Check if any of the fields is set. We create a separate "if" block // for each field so we can include its name in the error. for _, ifAnyField := range ifAny { fmt.Fprint(w, "if ") if err := generateIsSet(w, ifAnyField, "val."); err != nil { return err } fmt.Fprintf(w, ` { return fmt.Errorf("'%s' required when '%s' is set") } `, jsonName(field), jsonName(ifAnyField)) } fmt.Fprintln(w, "}") return nil } func filteredFields(fields []structField, jsonNames []string) ([]structField, error) { mapped := make(map[string]structField) for _, field := range fields { mapped[jsonName(field)] = field } filtered := make([]structField, len(jsonNames)) for i, jsonName := range jsonNames { field, ok := mapped[jsonName] if !ok { return nil, fmt.Errorf("unknown field name %q", jsonName) } filtered[i] = field } return filtered, nil }