internal/api/validate.go (276 lines of code) (raw):

// Copyright 2025 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package api import ( "crypto/x509" "fmt" "net/http" "reflect" "strings" "unicode" "unicode/utf8" azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" validator "github.com/go-playground/validator/v10" k8svalidation "k8s.io/apimachinery/pkg/util/validation" "github.com/Azure/ARO-HCP/internal/api/arm" ) // GetJSONTagName extracts the JSON field name from the "json" key in // a struct tag. Returns an empty string if no "json" key is present, // or if the value is "-". func GetJSONTagName(tag reflect.StructTag) string { tagValue := tag.Get("json") if tagValue == "-" { return "" } fieldName, _, _ := strings.Cut(tagValue, ",") return fieldName } // EnumValidateTag generates a string suitable for use with the "validate" // struct tag. The intent is to convert a set of valid values for a string // subtype into a "oneof=" expression for the purpose of static validation. func EnumValidateTag[S ~string](values ...S) string { s := make([]string, len(values)) for i, e := range values { s[i] = string(e) // Replace special characters with the UTF-8 hex representation. // https://pkg.go.dev/github.com/go-playground/validator/v10#hdr-Using_Validator_Tags s[i] = strings.ReplaceAll(s[i], ",", "0x2C") s[i] = strings.ReplaceAll(s[i], "|", "0x7C") } return fmt.Sprintf("oneof=%s", strings.Join(s, " ")) } func NewValidator() *validator.Validate { var err error validate := validator.New(validator.WithRequiredStructEnabled()) // Use "json" struct tags for alternate field names. // Alternate field names will be used in validation errors. validate.RegisterTagNameFunc(func(field reflect.StructField) string { return GetJSONTagName(field.Tag) }) // Register ARM-mandated enumeration types. validate.RegisterAlias("enum_managedserviceidentitytype", EnumValidateTag( arm.ManagedServiceIdentityTypeNone, arm.ManagedServiceIdentityTypeSystemAssigned, arm.ManagedServiceIdentityTypeSystemAssignedUserAssigned, arm.ManagedServiceIdentityTypeUserAssigned)) validate.RegisterAlias("enum_subscriptionstate", EnumValidateTag( arm.SubscriptionStateRegistered, arm.SubscriptionStateUnregistered, arm.SubscriptionStateWarned, arm.SubscriptionStateDeleted, arm.SubscriptionStateSuspended)) // Use this for string fields specifying an ARO-HCP API version. err = validate.RegisterValidation("api_version", func(fl validator.FieldLevel) bool { field := fl.Field() if field.Kind() != reflect.String { panic("String type required for api_version") } _, ok := Lookup(field.String()) return ok }) if err != nil { panic(err) } // Use this for string fields that must be a valid Kubernetes qualified name. err = validate.RegisterValidation("k8s_qualified_name", func(fl validator.FieldLevel) bool { field := fl.Field() if field.Kind() != reflect.String { panic("String type required for k8s_qualified_name") } return len(k8svalidation.IsQualifiedName(field.String())) == 0 }) if err != nil { panic(err) } // Use this for string fields that must be a valid Kubernetes label value. err = validate.RegisterValidation("k8s_label_value", func(fl validator.FieldLevel) bool { field := fl.Field() if field.Kind() != reflect.String { panic("String type required for k8s_label_value") } return len(k8svalidation.IsValidLabelValue(field.String())) == 0 }) if err != nil { panic(err) } // Use this for version ID fields that might begin with "openshift-v". err = validate.RegisterValidation("openshift_version", func(fl validator.FieldLevel) bool { field := fl.Field() if field.Kind() != reflect.String { panic("String type required for openshift_version") } _, err := NewOpenShiftVersion(field.String()) return err == nil }) if err != nil { panic(err) } // Use this for string fields providing PEM encoded certificates. err = validate.RegisterValidation("pem_certificates", func(fl validator.FieldLevel) bool { field := fl.Field() if field.Kind() != reflect.String { panic("String type required for pem_certificates") } return x509.NewCertPool().AppendCertsFromPEM([]byte(field.String())) }) if err != nil { panic(err) } // Use this for fields required in PUT requests. Do not apply to read-only fields. err = validate.RegisterValidation("required_for_put", func(fl validator.FieldLevel) bool { val := fl.Top().FieldByName("Method") if val.IsZero() { panic("Method field not found for required_for_put") } if val.String() != http.MethodPut { return true } // This is replicating the implementation of "required". // See https://github.com/go-playground/validator/issues/492 // Sounds like "hasValue" is unlikely to be exported and // "validate.Var" does not seem like a safe alternative. field := fl.Field() _, kind, nullable := fl.ExtractType(field) switch kind { case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func: return !field.IsNil() default: if nullable && field.Interface() != nil { return true } return field.IsValid() && !field.IsZero() } }) if err != nil { panic(err) } // Use this for string fields specifying an Azure resource ID. // The optional argument further enforces a specific resource type. err = validate.RegisterValidation("resource_id", func(fl validator.FieldLevel) bool { field := fl.Field() param := fl.Param() if field.Kind() != reflect.String { panic("String type required for resource_id") } resourceID, err := azcorearm.ParseResourceID(field.String()) if err != nil { return false } resourceType := resourceID.ResourceType.String() return param == "" || strings.EqualFold(resourceType, param) }) if err != nil { panic(err) } return validate } type validateContext struct { // Fields must be exported so valdator can access. Method string Resource any } // approximateJSONName approximates the JSON name for a struct field name by // lowercasing the first letter. This is not always accurate in general but // works for the small set of cases where we need it. func approximateJSONName(s string) string { r, size := utf8.DecodeRuneInString(s) if r == utf8.RuneError && size <= 1 { return s } lc := unicode.ToLower(r) if r == lc { return s } return string(lc) + s[size:] } func ValidateRequest(validate *validator.Validate, method string, resource any) []arm.CloudErrorBody { var errorDetails []arm.CloudErrorBody err := validate.Struct(validateContext{Method: method, Resource: resource}) if err == nil { return nil } // Convert validation errors to cloud error details. switch err := err.(type) { case validator.ValidationErrors: for _, fieldErr := range err { message := fmt.Sprintf("Invalid value '%v' for field '%s'", fieldErr.Value(), fieldErr.Field()) // Try to add a corrective suggestion to the message. tag := fieldErr.Tag() if strings.HasPrefix(tag, "enum_") { if len(strings.Split(fieldErr.Param(), " ")) == 1 { message += fmt.Sprintf(" (must be %s)", fieldErr.Param()) } else { message += fmt.Sprintf(" (must be one of: %s)", fieldErr.Param()) } } else { switch tag { case "api_version": // custom tag message = fmt.Sprintf("Unrecognized API version '%s'", fieldErr.Value()) case "openshift_version": // custom tag message = fmt.Sprintf("Invalid OpenShift version '%s'", fieldErr.Value()) case "pem_certificates": // custom tag message += " (must provide PEM encoded certificates)" case "k8s_label_value": // custom tag // Rerun the label value validation to obtain the error message. if value, ok := fieldErr.Value().(string); ok { errList := k8svalidation.IsValidLabelValue(value) message += fmt.Sprintf(" (%s)", strings.Join(errList, "; ")) } case "k8s_qualified_name": // custom tag // Rerun the qualified name validation to obtain the error message. if value, ok := fieldErr.Value().(string); ok { errList := k8svalidation.IsQualifiedName(value) message += fmt.Sprintf(" (%s)", strings.Join(errList, "; ")) } case "required", "required_for_put": // custom tag message = fmt.Sprintf("Missing required field '%s'", fieldErr.Field()) case "required_unless": // The parameter format is pairs of "fieldName fieldValue". // Multiple pairs are possible but we currently only use one. fields := strings.Fields(fieldErr.Param()) if len(fields) > 1 { // We want to print the JSON name for the field // referenced in the parameter, but FieldError does // not provide access to the parent reflect.Type from // which we could look it up. So approximate the JSON // name by lowercasing the first letter. message = fmt.Sprintf("Field '%s' is required when '%s' is not '%s'", fieldErr.Field(), approximateJSONName(fields[0]), fields[1]) } case "resource_id": // custom tag if fieldErr.Param() != "" { message += fmt.Sprintf(" (must be a valid '%s' resource ID)", fieldErr.Param()) } else { message += " (must be a valid Azure resource ID)" } case "cidrv4": message += " (must be a v4 CIDR range)" case "dns_rfc1035_label": message += " (must be a valid DNS RFC 1035 label)" case "excluded_with": // We want to print the JSON name for the field // referenced in the parameter, but FieldError does // not provide access to the parent reflect.Type from // which we could look it up. So approximate the JSON // name by lowercasing the first letter. zero := reflect.Zero(fieldErr.Type()).Interface() message = fmt.Sprintf("Field '%s' must be %v when '%s' is specified", fieldErr.Field(), zero, approximateJSONName(fieldErr.Param())) case "gtefield": // We want to print the JSON name for the field // referenced in the parameter, but FieldError does // not provide access to the parent reflect.Type from // which we could look it up. So approximate the JSON // name by lowercasing the first letter. message += fmt.Sprintf(" (must be at least the value of '%s')", approximateJSONName(fieldErr.Param())) case "ipv4": message += " (must be an IPv4 address)" case "max": switch fieldErr.Kind() { case reflect.String: message += fmt.Sprintf(" (maximum length is %s)", fieldErr.Param()) default: if fieldErr.Param() == "0" { message += " (must be non-positive)" } else { message += fmt.Sprintf(" (must be at most %s)", fieldErr.Param()) } } case "min": switch fieldErr.Kind() { case reflect.String: message += fmt.Sprintf(" (minimum length is %s)", fieldErr.Param()) default: if fieldErr.Param() == "0" { message += " (must be non-negative)" } else { message += fmt.Sprintf(" (must be at least %s)", fieldErr.Param()) } } case "startswith": message += fmt.Sprintf(" (must start with '%s')", fieldErr.Param()) case "url": message += " (must be a URL)" } } errorDetails = append(errorDetails, arm.CloudErrorBody{ Code: arm.CloudErrorCodeInvalidRequestContent, Message: message, // Split "validateContext.Resource.{REMAINING_FIELDS}" Target: strings.SplitN(fieldErr.Namespace(), ".", 3)[2], }) } default: errorDetails = append(errorDetails, arm.CloudErrorBody{ Code: arm.CloudErrorCodeInvalidRequestContent, Message: err.Error(), }) } return errorDetails } // ValidateSubscription validates a subscription request payload. func ValidateSubscription(subscription *arm.Subscription) *arm.CloudError { cloudError := arm.NewCloudError( http.StatusBadRequest, arm.CloudErrorCodeMultipleErrorsOccurred, "", "Content validation failed on multiple fields") cloudError.Details = make([]arm.CloudErrorBody, 0) validate := NewValidator() // There is no PATCH method for subscriptions, so assume PUT. errorDetails := ValidateRequest(validate, http.MethodPut, subscription) if errorDetails != nil { cloudError.Details = append(cloudError.Details, errorDetails...) } switch len(cloudError.Details) { case 0: cloudError = nil case 1: // Promote a single validation error out of details. cloudError.CloudErrorBody = &cloudError.Details[0] } return cloudError }