internal/api/visibility.go (236 lines of code) (raw):

// Copyright 2025 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package api

import (
	"fmt"
	"reflect"
	"strings"

	"github.com/Azure/ARO-HCP/internal/api/arm"
)

// Property visibility meanings:
// https://azure.github.io/typespec-azure/docs/howtos/ARM/resource-type#property-visibility-and-other-constraints
//
// Field mutability guidelines:
// https://github.com/microsoft/api-guidelines/blob/vNext/azure/Guidelines.md#resource-schema--field-mutability

// VisibilityStructTagKey is the struct tag key whose value holds a field's
// space-separated visibility values (parsed by GetVisibilityFlags).
const VisibilityStructTagKey = "visibility"

// VisibilityFlags holds a visibility struct tag value as bit flags.
type VisibilityFlags uint8 const ( VisibilityRead VisibilityFlags = 1 << iota VisibilityCreate VisibilityUpdate // option flags VisibilityCaseInsensitive VisibilityDefault = VisibilityRead | VisibilityCreate | VisibilityUpdate ) func (f VisibilityFlags) ReadOnly() bool { return f&(VisibilityRead|VisibilityCreate|VisibilityUpdate) == VisibilityRead } func (f VisibilityFlags) CanUpdate() bool { return f&VisibilityUpdate != 0 } func (f VisibilityFlags) CaseInsensitive() bool { return f&VisibilityCaseInsensitive != 0 } func (f VisibilityFlags) String() string { s := []string{} if f&VisibilityRead != 0 { s = append(s, "read") } if f&VisibilityCreate != 0 { s = append(s, "create") } if f&VisibilityUpdate != 0 { s = append(s, "update") } if f&VisibilityCaseInsensitive != 0 { s = append(s, "nocase") } return strings.Join(s, " ") } func GetVisibilityFlags(tag reflect.StructTag) (VisibilityFlags, bool) { var flags VisibilityFlags tagValue, ok := tag.Lookup(VisibilityStructTagKey) if ok { for _, v := range strings.Fields(tagValue) { switch strings.ToLower(v) { case "read": flags |= VisibilityRead case "create": flags |= VisibilityCreate case "update": flags |= VisibilityUpdate case "nocase": flags |= VisibilityCaseInsensitive default: panic(fmt.Sprintf("Unknown visibility tag value '%s'", v)) } } } return flags, ok } func join(ns, name string) string { res := ns if res != "" { res += "." 
} res += name return res } type StructTagMap map[string]reflect.StructTag func buildStructTagMap(structTagMap StructTagMap, t reflect.Type, path string) { switch t.Kind() { case reflect.Map, reflect.Pointer, reflect.Slice: buildStructTagMap(structTagMap, t.Elem(), path) case reflect.Struct: for i := 0; i < t.NumField(); i++ { field := t.Field(i) subpath := join(path, field.Name) if len(field.Tag) > 0 { structTagMap[subpath] = field.Tag } buildStructTagMap(structTagMap, field.Type, subpath) } } } // NewStructTagMap returns a mapping of dot-separated struct field names // to struct tags for the given type. Each versioned API should create // its own visibiilty map for tracked resource types. // // Note: This assumes field names for internal and versioned structs are // identical where visibility is explicitly specified. If some divergence // emerges, one workaround could be to pass a field name override map. func NewStructTagMap[T any]() StructTagMap { structTagMap := StructTagMap{} buildStructTagMap(structTagMap, reflect.TypeFor[T](), "") return structTagMap } type validateVisibility struct { structTagMap StructTagMap updating bool errs []arm.CloudErrorBody } // ValidateVisibility compares the new value (newVal) to the current value // (curVal) and returns any violations of visibility restrictions as defined // by structTagMap. func ValidateVisibility(newVal, curVal interface{}, structTagMap StructTagMap, updating bool) []arm.CloudErrorBody { vv := validateVisibility{ structTagMap: structTagMap, updating: updating, } vv.recurse(reflect.ValueOf(newVal), reflect.ValueOf(curVal), "", "", "", VisibilityDefault) return vv.errs } // mapKey is a lookup key for the StructTagMap. It DOES NOT include subscripts // for arrays, maps or slices since all elements are the same type. // // namespace is the struct field path up to but not including the field being // evaluated, analogous to path.Dir. 
It DOES include subscripts for arrays, // maps and slices since its purpose is for error reporting. // // fieldname is the current field being evaluated, analgous to path.Base. It // also includes subscripts for arrays, maps and slices when evaluating their // immediate elements. func (vv *validateVisibility) recurse(newVal, curVal reflect.Value, mapKey, namespace, fieldname string, implicitVisibility VisibilityFlags) { flags, ok := GetVisibilityFlags(vv.structTagMap[mapKey]) if !ok { flags = implicitVisibility } if newVal.Type() != curVal.Type() { panic(fmt.Sprintf("%s: value types differ (%s vs %s)", join(namespace, fieldname), newVal.Type().Name(), curVal.Type().Name())) } // Generated API structs are all pointer fields. A nil pointer in // the incoming request (newVal) means the value is absent, which // is always acceptable for visibility validation. switch newVal.Kind() { case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: if newVal.IsNil() { return } } switch newVal.Kind() { case reflect.Bool: if newVal.Bool() != curVal.Bool() { vv.checkFlags(flags, namespace, fieldname) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if newVal.Int() != curVal.Int() { vv.checkFlags(flags, namespace, fieldname) } case reflect.Uint, reflect.Uintptr, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if newVal.Uint() != curVal.Uint() { vv.checkFlags(flags, namespace, fieldname) } case reflect.Float32, reflect.Float64: if newVal.Float() != curVal.Float() { vv.checkFlags(flags, namespace, fieldname) } case reflect.Complex64, reflect.Complex128: if newVal.Complex() != curVal.Complex() { vv.checkFlags(flags, namespace, fieldname) } case reflect.String: if flags.CaseInsensitive() { if !strings.EqualFold(newVal.String(), curVal.String()) { vv.checkFlags(flags, namespace, fieldname) } } else { if newVal.String() != curVal.String() { vv.checkFlags(flags, namespace, fieldname) } } case reflect.Slice: 
// We already know that newVal is not nil. if curVal.IsNil() { vv.checkFlags(flags, namespace, fieldname) return } fallthrough case reflect.Array: if newVal.Len() != curVal.Len() { vv.checkFlags(flags, namespace, fieldname) } else { for i := 0; i < min(newVal.Len(), curVal.Len()); i++ { subscript := fmt.Sprintf("[%d]", i) vv.recurse(newVal.Index(i), curVal.Index(i), mapKey, namespace, fieldname+subscript, flags) } } case reflect.Interface, reflect.Pointer: // We already know that newVal is not nil. if curVal.IsNil() { vv.checkFlags(flags, namespace, fieldname) } else { vv.recurse(newVal.Elem(), curVal.Elem(), mapKey, namespace, fieldname, flags) } case reflect.Map: // Determine if newVal and curVal share identical keys. var keysEqual = true // We already know that newVal is not nil. if curVal.IsNil() || newVal.Len() != curVal.Len() { keysEqual = false } else { iter := newVal.MapRange() for iter.Next() { if !curVal.MapIndex(iter.Key()).IsValid() { keysEqual = false break } } } // Skip recursion if visibility check on the map itself fails. if !keysEqual && !vv.checkFlags(flags, namespace, fieldname) { return } // Initialize a zero value for when curVal is missing a key in newVal. // If the map value type is a pointer, create a zero value for the type // being pointed to. var zeroVal reflect.Value mapValueType := newVal.Type().Elem() if mapValueType.Kind() == reflect.Ptr { // This returns a pointer to the new value. zeroVal = reflect.New(mapValueType.Elem()) } else { // Follow the pointer to the new value. 
zeroVal = reflect.New(mapValueType).Elem() } iter := newVal.MapRange() for iter.Next() { k := iter.Key() subscript := fmt.Sprintf("[%q]", k.Interface()) if curVal.IsNil() || !curVal.MapIndex(k).IsValid() { vv.recurse(newVal.MapIndex(k), zeroVal, mapKey, namespace, fieldname+subscript, flags) } else { vv.recurse(newVal.MapIndex(k), curVal.MapIndex(k), mapKey, namespace, fieldname+subscript, flags) } } case reflect.Struct: for i := 0; i < newVal.NumField(); i++ { structField := newVal.Type().Field(i) mapKeyNext := join(mapKey, structField.Name) namespaceNext := join(namespace, fieldname) fieldnameNext := GetJSONTagName(vv.structTagMap[mapKeyNext]) if fieldnameNext == "" { fieldnameNext = structField.Name } vv.recurse(newVal.Field(i), curVal.Field(i), mapKeyNext, namespaceNext, fieldnameNext, flags) } } } func (vv *validateVisibility) checkFlags(flags VisibilityFlags, namespace, fieldname string) bool { if flags.ReadOnly() { vv.errs = append(vv.errs, arm.CloudErrorBody{ Code: arm.CloudErrorCodeInvalidRequestContent, Message: fmt.Sprintf("Field '%s' is read-only", fieldname), Target: join(namespace, fieldname), }) return false } else if vv.updating && !flags.CanUpdate() { vv.errs = append(vv.errs, arm.CloudErrorBody{ Code: arm.CloudErrorCodeInvalidRequestContent, Message: fmt.Sprintf("Field '%s' cannot be updated", fieldname), Target: join(namespace, fieldname), }) return false } return true }