v2/internal/reflecthelpers/reflect_helpers.go (282 lines of code) (raw):
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
*/
package reflecthelpers
import (
"fmt"
"reflect"
"strings"
"github.com/rotisserie/eris"
"sigs.k8s.io/controller-runtime/pkg/client"
"github.com/Azure/azure-service-operator/v2/internal/set"
"github.com/Azure/azure-service-operator/v2/pkg/genruntime"
"github.com/Azure/azure-service-operator/v2/pkg/genruntime/configmaps"
)
// ValueOfPtr returns the value that the supplied pointer points to.
// It panics if ptr is not a pointer, and (like the * operator) it also panics if ptr is a nil pointer,
// so use it with the same care.
// TODO: Can we delete this helper later when we have some better code generated functions?
func ValueOfPtr(ptr interface{}) interface{} {
	pointer := reflect.ValueOf(ptr)
	if pointer.Kind() == reflect.Ptr {
		// Elem yields the pointee; for a nil pointer it yields the zero Value and Interface panics
		return pointer.Elem().Interface()
	}

	panic(fmt.Sprintf("Can't get value of pointer for non-pointer type %T", ptr))
}
// DeepCopyInto uses reflection to call in.DeepCopyInto(out).
// The method is looked up by name at runtime, so the call panics if in has no DeepCopyInto method
// or if out is not an acceptable argument for it.
func DeepCopyInto(in client.Object, out client.Object) {
	args := []reflect.Value{reflect.ValueOf(out)}
	reflect.ValueOf(in).MethodByName("DeepCopyInto").Call(args)
}
// FindReferences visits obj and collects every struct value whose type is exactly t.
// The result is a set (a map with empty struct values) keyed by the values found.
// A struct of type t is recorded, provided it can be read via Interface, and is not handed on to
// IdentityVisitStruct; every other struct is.
func FindReferences(obj interface{}, t reflect.Type) (map[interface{}]struct{}, error) {
	found := make(map[interface{}]struct{})

	visitor := NewReflectVisitor()
	visitor.VisitStruct = func(v *ReflectVisitor, it reflect.Value, ctx interface{}) error {
		if it.Type() != t {
			// Not the type we're looking for, carry on with the default handling
			return IdentityVisitStruct(v, it, ctx)
		}

		if it.CanInterface() {
			found[it.Interface()] = struct{}{}
		}

		return nil
	}

	if err := visitor.Visit(obj, nil); err != nil {
		return nil, eris.Wrapf(err, "scanning for references of type %s", t.String())
	}

	return found, nil
}
// FindPropertiesWithTag visits obj and collects the value of every exported struct field that carries
// the given struct tag key.
// The result maps the dot-separated path of field names leading to each tagged field onto the list of
// values found at that path.
func FindPropertiesWithTag(obj interface{}, tag string) (map[string][]interface{}, error) {
	found := make(map[string][]interface{})

	visitor := NewReflectVisitor()
	visitor.VisitStruct = func(v *ReflectVisitor, it reflect.Value, ctx interface{}) error {
		// This was adapted from IdentityVisitStruct
		structType := it.Type()
		for i := 0; i < it.NumField(); i++ {
			field := it.Field(i)
			if !field.CanInterface() {
				// Bypass unexported fields
				continue
			}

			info := structType.Field(i)

			// The context carries the path of the enclosing struct; extend it with this field's name
			path := info.Name
			if parent := ctx.(string); parent != "" {
				path = parent + "." + info.Name
			}

			if _, tagged := info.Tag.Lookup(tag); tagged {
				found[path] = append(found[path], field.Interface())
			}

			// Keep walking, tagged or not, passing the extended path down as the new context
			if err := v.visit(field, path); err != nil {
				return err
			}
		}

		return nil
	}

	if err := visitor.Visit(obj, ""); err != nil {
		return nil, eris.Wrapf(err, "scanning for references to tag %s", tag)
	}

	return found, nil
}
// FindResourceReferences returns the set of genruntime.ResourceReference values found on the provided object.
func FindResourceReferences(obj interface{}) (set.Set[genruntime.ResourceReference], error) {
	refs, err := Find[genruntime.ResourceReference](obj)
	return refs, err
}
// FindSecretReferences returns the set of genruntime.SecretReference values found on the provided object.
func FindSecretReferences(obj interface{}) (set.Set[genruntime.SecretReference], error) {
	refs, err := Find[genruntime.SecretReference](obj)
	return refs, err
}
// FindSecretMaps returns the set of genruntime.SecretMapReference values found on the provided object.
func FindSecretMaps(obj interface{}) (set.Set[genruntime.SecretMapReference], error) {
	refs, err := Find[genruntime.SecretMapReference](obj)
	return refs, err
}
// FindConfigMapReferences returns the set of genruntime.ConfigMapReference values found on the provided object.
func FindConfigMapReferences(obj interface{}) (set.Set[genruntime.ConfigMapReference], error) {
	refs, err := Find[genruntime.ConfigMapReference](obj)
	return refs, err
}
// Find returns the set of values of type T found on the provided object.
// It delegates the search to FindReferences, using the reflected type of T, and converts the untyped
// results into a typed set.
func Find[T comparable](obj interface{}) (set.Set[T], error) {
	var zero T
	matches, err := FindReferences(obj, reflect.TypeOf(zero))
	if err != nil {
		return nil, err
	}

	typed := set.Make[T]()
	for match := range matches {
		typed.Add(match.(T))
	}

	return typed, nil
}
// FindOptionalConfigMapReferences collects the optional configmap reference pairs declared on the provided object.
// Properties tagged with optionalConfigMapPair are gathered by path: a property holding a *string supplies the
// Name and Value of a pair, and the property found at the same path plus the suffix FromConfig supplies the
// RefName and Ref of that pair.
// An error is returned if a tagged property does not have the expected type, or if the number of references
// found for a path differs from the number of values found for it.
// The pairs are gathered by iterating over maps, so their order in the returned slice is not specified.
func FindOptionalConfigMapReferences(obj interface{}) ([]*configmaps.OptionalReferencePair, error) {
	const (
		pairTag   = "optionalConfigMapPair" // TODO: This is astmodel.OptionalConfigMapPairTag
		refSuffix = "FromConfig"            // TODO This is astmodel.OptionalConfigMapReferenceSuffix
	)

	tagged, err := FindPropertiesWithTag(obj, pairTag)
	if err != nil {
		return nil, err
	}

	// This could probably be more efficient, but this avoids code duplication, and we're not dealing
	// with huge collections here.

	// First pass: create a pair for each value property
	pairsByPath := make(map[string][]*configmaps.OptionalReferencePair)
	for path, values := range tagged {
		if strings.HasSuffix(path, refSuffix) {
			continue
		}

		pairs := make([]*configmaps.OptionalReferencePair, 0, len(values))
		for _, raw := range values {
			str, isString := raw.(*string)
			if !isString {
				return nil, eris.Errorf("value of property %s was not a *string like expected", path)
			}

			pairs = append(pairs, &configmaps.OptionalReferencePair{
				Name:  path,
				Value: str,
			})
		}

		pairsByPath[path] = pairs
	}

	// Second pass: attach each reference property to the pair created for its value property
	for path, values := range tagged {
		if !strings.HasSuffix(path, refSuffix) {
			continue
		}

		valuePath := strings.TrimSuffix(path, refSuffix)
		pairs := pairsByPath[valuePath]
		if len(values) != len(pairs) {
			return nil, eris.Errorf("number of Ref's didn't match number of Values for %s", valuePath)
		}

		for i, raw := range values {
			ref, isRef := raw.(*genruntime.ConfigMapReference)
			if !isRef {
				return nil, eris.Errorf("value of property %s was not a genruntime.ConfigMapReference like expected", path)
			}

			pairs[i].RefName = path
			pairs[i].Ref = ref
		}
	}

	// Flatten the per-path pairs into a simple list
	var result []*configmaps.OptionalReferencePair
	for _, pairs := range pairsByPath {
		result = append(result, pairs...)
	}

	return result, nil
}
// GetObjectListItems returns the contents of the Items field of the provided list as client.Objects.
// Elements stored as structs are returned by address; any other element is used as-is.
// An error is returned if the Items field can't be found, if a struct element can't be addressed,
// or if an element doesn't implement client.Object.
func GetObjectListItems(listPtr client.ObjectList) ([]client.Object, error) {
	itemsField, err := getItemsField(listPtr)
	if err != nil {
		return nil, err
	}

	var objects []client.Object
	count := itemsField.Len()
	for i := 0; i < count; i++ {
		element := itemsField.Index(i)
		if element.Kind() == reflect.Struct {
			// The list holds struct values, so we need a pointer to the element instead
			if !element.CanAddr() {
				return nil, eris.Errorf("provided list elements were not pointers, but cannot be addressed")
			}

			element = element.Addr()
		}

		object, isObject := element.Interface().(client.Object)
		if !isObject {
			return nil, eris.Errorf("provided list elements did not implement client.Object interface")
		}

		objects = append(objects, object)
	}

	return objects, nil
}
// SetObjectListItems replaces the Items field of the provided list with the supplied items.
// listPtr is a pointer to the list struct to modify.
// items are the objects to store. When the Items slice holds values, each pointer item is dereferenced
// and the value it points to is stored; when the Items slice holds pointers, each item is stored as-is
// (mirroring the two element shapes accepted by GetObjectListItems).
// Returns an error if the Items field can't be found or can't be set, or if an item can't be stored in
// the slice (for example, a nil item or an item of the wrong type).
func SetObjectListItems(listPtr client.ObjectList, items []client.Object) (returnErr error) {
	itemsField, err := getItemsField(listPtr)
	if err != nil {
		return err
	}

	if !itemsField.CanSet() {
		return eris.Errorf("cannot set items field of %T", listPtr)
	}

	// reflect panics on mismatched types; turn any such panic into an error return
	defer func() {
		if recovered := recover(); recovered != nil {
			returnErr = eris.Errorf("failed to set items field of %T: %s", listPtr, recovered)
		}
	}()

	// Only dereference items when the slice stores values rather than pointers
	storesPointers := itemsField.Type().Elem().Kind() == reflect.Ptr

	slice := reflect.MakeSlice(itemsField.Type(), 0, len(items))
	for _, item := range items {
		val := reflect.ValueOf(item)
		if !storesPointers && val.Kind() == reflect.Ptr {
			val = val.Elem()
		}

		slice = reflect.Append(slice, val)
	}

	itemsField.Set(slice)
	return nil
}
// getItemsField returns the field named "Items" from the struct that listPtr points to.
// Returns an error if listPtr is not a pointer, if it doesn't point to a struct, if that struct has
// no field named "Items", or if the field is not a slice.
func getItemsField(listPtr client.ObjectList) (reflect.Value, error) {
	val := reflect.ValueOf(listPtr)
	if val.Kind() != reflect.Ptr {
		return reflect.Value{}, eris.Errorf("provided list was not a pointer, was %s", val.Kind())
	}

	list := val.Elem()
	if list.Kind() != reflect.Struct {
		// Report the kind of the thing pointed to; val itself is always a pointer by this point
		return reflect.Value{}, eris.Errorf("provided list was not a struct, was %s", list.Kind())
	}

	// FieldByName returns the zero Value when there is no such field
	itemsField := list.FieldByName("Items")
	if !itemsField.IsValid() {
		return reflect.Value{}, eris.Errorf("provided list has no field \"Items\"")
	}

	if itemsField.Kind() != reflect.Slice {
		return reflect.Value{}, eris.Errorf("provided list \"Items\" field was not of type slice")
	}

	return itemsField, nil
}
// SetProperty assigns a value to a (possibly nested) property of the provided object.
// obj is the object to modify.
// propertyPath is a dot-separated path naming the property to set.
// value is the value to assign to that property.
// Returns an error if obj is nil, if propertyPath is empty, if any property along the path does not exist,
// if the property is not settable, or if the value provided is incompatible.
func SetProperty(obj any, propertyPath string, value any) error {
	switch {
	case obj == nil:
		return eris.Errorf("provided object was nil")
	case propertyPath == "":
		return eris.Errorf("property path was empty")
	}

	return setPropertyCore(obj, strings.Split(propertyPath, "."), value)
}
// setPropertyCore sets a property on obj, consuming one step of propertyPath per call.
// obj is the struct (or pointer to a struct) holding the field named by the first step.
// propertyPath is the sequence of field names still to be followed.
// value is the value to assign to the final field; it is converted to the type of that field.
// Nil pointer fields encountered along the path are initialized with a new zero value before descending.
// Returns an error if obj is not a struct, if a field does not exist or is not settable, or if value
// can't be converted to the type of the field. Any panic raised while doing so is returned as an error.
func setPropertyCore(obj any, propertyPath []string, value any) (err error) {
	// Catch any panic that occurs when setting the field and turn it into an error return
	defer func() {
		if recovered := recover(); recovered != nil {
			err = eris.Errorf("failed to set property %s: %s", propertyPath[0], recovered)
		}
	}()

	// Get the underlying object we need to modify
	subject := reflect.ValueOf(obj)

	// Dereference pointers
	if subject.Kind() == reflect.Ptr {
		subject = subject.Elem()
	}

	// Check we have a struct
	if subject.Kind() != reflect.Struct {
		return eris.Errorf("provided object was not a struct, was %s", subject.Kind())
	}

	// Get the field we need to modify; FieldByName returns the zero Value if there is no such field
	field := subject.FieldByName(propertyPath[0])
	if !field.IsValid() {
		return eris.Errorf("provided object did not have a field named %s", propertyPath[0])
	}

	// If this is not the last property in the path, we need to recurse
	if len(propertyPath) > 1 {
		var next any
		if field.Kind() == reflect.Ptr {
			// Field is a pointer; initialize it if needed, then pass the pointer recursively
			if field.IsNil() {
				field.Set(reflect.New(field.Type().Elem()))
			}

			next = field.Interface()
		} else {
			// Field is not a pointer, so we need to pass the address of the field recursively
			next = field.Addr().Interface()
		}

		if err = setPropertyCore(next, propertyPath[1:], value); err != nil {
			return eris.Wrapf(err, "failed to set property %s",
				propertyPath[0])
		}

		return nil
	}

	// If this is the last property in the path, we need to set the value, if we can
	if !field.CanSet() {
		return eris.Errorf("field %s was not settable", propertyPath[0])
	}

	// Convert value to the type required by the field
	newValue := reflect.ValueOf(value)
	if !newValue.CanConvert(field.Type()) {
		// Report the kind of the value, not the value itself
		return eris.Errorf("value of kind %s was not compatible with field %s", newValue.Kind(), propertyPath[0])
	}

	field.Set(newValue.Convert(field.Type()))
	return nil
}
// GetJSONTags returns the set of JSON keys named by the `json` struct tags on the fields of t.
// t must be a struct type; NumField panics for any other kind of type.
// Fields contribute no key when they have no `json` tag, when they are excluded with `json:"-"`,
// or when their tag carries only options (such as `json:",omitempty"` or `json:",inline"`).
func GetJSONTags(t reflect.Type) set.Set[string] {
	tags := set.Make[string]()
	for i := 0; i < t.NumField(); i++ {
		tag := t.Field(i).Tag.Get("json")

		// Separate the key from omitempty and any other options
		name, _, hasOptions := strings.Cut(tag, ",")
		if name == "" {
			// No tag, or a tag that doesn't name a key
			continue
		}

		if name == "-" && !hasOptions {
			// `json:"-"` excludes the field entirely, whereas `json:"-,"` names the literal key "-"
			continue
		}

		tags.Add(name)
	}

	return tags
}