schema.go (1,156 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package iceberg import ( "encoding/json" "fmt" "maps" "slices" "strings" "sync" "sync/atomic" "unicode" ) // Schema is an Iceberg table schema, represented as a struct with // multiple fields. The fields are only exported via accessor methods // rather than exposing the slice directly in order to ensure a schema // as immutable. type Schema struct { ID int `json:"schema-id"` IdentifierFieldIDs []int `json:"identifier-field-ids"` fields []NestedField // the following maps are lazily populated as needed. // rather than have lock contention with a mutex, we can use // atomic pointers to Store/Load the values. idToName atomic.Pointer[map[int]string] idToField atomic.Pointer[map[int]NestedField] nameToID atomic.Pointer[map[string]int] nameToIDLower atomic.Pointer[map[string]int] idToAccessor atomic.Pointer[map[int]accessor] lazyIDToParent func() (map[int]int, error) lazyNameMapping func() NameMapping } // NewSchema constructs a new schema with the provided ID // and list of fields. func NewSchema(id int, fields ...NestedField) *Schema { return NewSchemaWithIdentifiers(id, []int{}, fields...) } // NewSchemaWithIdentifiers constructs a new schema with the provided ID // and fields, along with a slice of field IDs to be listed as identifier // fields. func NewSchemaWithIdentifiers(id int, identifierIDs []int, fields ...NestedField) *Schema { s := &Schema{ID: id, fields: fields, IdentifierFieldIDs: identifierIDs} s.init() return s } func (s *Schema) init() { s.lazyIDToParent = sync.OnceValues(func() (map[int]int, error) { return IndexParents(s) }) s.lazyNameMapping = sync.OnceValue(func() NameMapping { return createMappingFromSchema(s) }) } func (s *Schema) String() string { var b strings.Builder b.WriteString("table {") for _, f := range s.fields { b.WriteString("\n\t") b.WriteString(f.String()) } b.WriteString("\n}") return b.String() } func (s *Schema) lazyNameToID() (map[string]int, error) { index := s.nameToID.Load() if index != nil { return *index, nil } idx, err := IndexByName(s) if err != nil { return nil, err } s.nameToID.Store(&idx) return idx, nil } func (s *Schema) lazyIDToField() (map[int]NestedField, error) { index := s.idToField.Load() if index != nil { return *index, nil } idx, err := IndexByID(s) if err != nil { return nil, err } s.idToField.Store(&idx) return idx, nil } func (s *Schema) lazyIDToName() (map[int]string, error) { index := s.idToName.Load() if index != nil { return *index, nil } idx, err := IndexNameByID(s) if err != nil { return nil, err } s.idToName.Store(&idx) return idx, nil } func (s *Schema) lazyNameToIDLower() (map[string]int, error) { index := s.nameToIDLower.Load() if index != nil { return *index, nil } idx, err := s.lazyNameToID() if err != nil { return nil, err } out := make(map[string]int) for k, v := range idx { out[strings.ToLower(k)] = v } s.nameToIDLower.Store(&out) return out, nil } func (s *Schema) lazyIdToAccessor() (map[int]accessor, error) { index := s.idToAccessor.Load() if index != nil { return *index, nil } idx, err := buildAccessors(s) if err != nil { return nil, err } s.idToAccessor.Store(&idx) return idx, nil } func (s *Schema) NameMapping() NameMapping { return s.lazyNameMapping() } func (s *Schema) Type() string { return "struct" } // AsStruct returns a Struct with the same fields as the schema which can // then be used as a Type. func (s *Schema) AsStruct() StructType { return StructType{FieldList: s.fields} } func (s *Schema) NumFields() int { return len(s.fields) } func (s *Schema) Field(i int) NestedField { return s.fields[i] } func (s *Schema) Fields() []NestedField { return slices.Clone(s.fields) } func (s *Schema) FieldIDs() []int { idx, _ := s.lazyNameToID() return slices.Collect(maps.Values(idx)) } func (s *Schema) UnmarshalJSON(b []byte) error { type Alias Schema aux := struct { Fields []NestedField `json:"fields"` *Alias }{Alias: (*Alias)(s)} if err := json.Unmarshal(b, &aux); err != nil { return err } s.init() s.fields = aux.Fields if s.IdentifierFieldIDs == nil { s.IdentifierFieldIDs = []int{} } return nil } func (s *Schema) MarshalJSON() ([]byte, error) { if s.IdentifierFieldIDs == nil { s.IdentifierFieldIDs = []int{} } type Alias Schema return json.Marshal(struct { Type string `json:"type"` Fields []NestedField `json:"fields"` *Alias }{Type: "struct", Fields: s.fields, Alias: (*Alias)(s)}) } // FindColumnName returns the name of the column identified by the // passed in field id. The second return value reports whether or // not the field id was found in the schema. func (s *Schema) FindColumnName(fieldID int) (string, bool) { idx, _ := s.lazyIDToName() col, ok := idx[fieldID] return col, ok } // FindFieldByName returns the field identified by the name given, // the second return value will be false if no field by this name // is found. // // Note: This search is done in a case sensitive manner. To perform // a case insensitive search, use [*Schema.FindFieldByNameCaseInsensitive]. func (s *Schema) FindFieldByName(name string) (NestedField, bool) { idx, _ := s.lazyNameToID() id, ok := idx[name] if !ok { return NestedField{}, false } return s.FindFieldByID(id) } // FindFieldByNameCaseInsensitive is like [*Schema.FindFieldByName], // but performs a case insensitive search. func (s *Schema) FindFieldByNameCaseInsensitive(name string) (NestedField, bool) { idx, _ := s.lazyNameToIDLower() id, ok := idx[strings.ToLower(name)] if !ok { return NestedField{}, false } return s.FindFieldByID(id) } // FindFieldByID is like [*Schema.FindColumnName], but returns the whole // field rather than just the field name. func (s *Schema) FindFieldByID(id int) (NestedField, bool) { idx, _ := s.lazyIDToField() f, ok := idx[id] return f, ok } // FindTypeByID is like [*Schema.FindFieldByID], but returns only the data // type of the field. func (s *Schema) FindTypeByID(id int) (Type, bool) { f, ok := s.FindFieldByID(id) if !ok { return nil, false } return f.Type, true } // FindTypeByName is a convenience function for calling [*Schema.FindFieldByName], // and then returning just the type. func (s *Schema) FindTypeByName(name string) (Type, bool) { f, ok := s.FindFieldByName(name) if !ok { return nil, false } return f.Type, true } // FindTypeByNameCaseInsensitive is like [*Schema.FindTypeByName] but // performs a case insensitive search. func (s *Schema) FindTypeByNameCaseInsensitive(name string) (Type, bool) { f, ok := s.FindFieldByNameCaseInsensitive(name) if !ok { return nil, false } return f.Type, true } func (s *Schema) accessorForField(id int) (accessor, bool) { idx, err := s.lazyIdToAccessor() if err != nil { return accessor{}, false } acc, ok := idx[id] return acc, ok } // Equals compares the fields and identifierIDs, but does not compare // the schema ID itself. func (s *Schema) Equals(other *Schema) bool { if other == nil { return false } if s == other { return true } if len(s.fields) != len(other.fields) { return false } if !slices.Equal(s.IdentifierFieldIDs, other.IdentifierFieldIDs) { return false } return slices.EqualFunc(s.fields, other.fields, func(a, b NestedField) bool { return a.Equals(b) }) } // HighestFieldID returns the value of the numerically highest field ID // in this schema. func (s *Schema) HighestFieldID() int { id, _ := Visit(s, findLastFieldID{}) return id } type Void = struct{} var void = Void{} // Select creates a new schema with just the fields identified by name // passed in the order they are provided. If caseSensitive is false, // then fields will be identified by case insensitive search. // // An error is returned if a requested name cannot be found. func (s *Schema) Select(caseSensitive bool, names ...string) (*Schema, error) { ids := make(map[int]Void) if caseSensitive { nameMap, _ := s.lazyNameToID() for _, n := range names { id, ok := nameMap[n] if !ok { return nil, fmt.Errorf("%w: could not find column %s", ErrInvalidSchema, n) } ids[id] = void } } else { nameMap, _ := s.lazyNameToIDLower() for _, n := range names { id, ok := nameMap[strings.ToLower(n)] if !ok { return nil, fmt.Errorf("%w: could not find column %s", ErrInvalidSchema, n) } ids[id] = void } } return PruneColumns(s, ids, true) } func (s *Schema) FieldHasOptionalParent(id int) bool { idToParent, _ := s.lazyIDToParent() idToField, _ := s.lazyIDToField() f, ok := idToField[id] if !ok { return false } for { parent, ok := idToParent[f.ID] if !ok { return false } if f = idToField[parent]; !f.Required { return true } } } // SchemaVisitor is an interface that can be implemented to allow for // easy traversal and processing of a schema. // // A SchemaVisitor can also optionally implement the Before/After Field, // ListElement, MapKey, or MapValue interfaces to allow them to get called // at the appropriate points within schema traversal. type SchemaVisitor[T any] interface { Schema(schema *Schema, structResult T) T Struct(st StructType, fieldResults []T) T Field(field NestedField, fieldResult T) T List(list ListType, elemResult T) T Map(mapType MapType, keyResult, valueResult T) T Primitive(p PrimitiveType) T } type BeforeFieldVisitor interface { BeforeField(field NestedField) } type AfterFieldVisitor interface { AfterField(field NestedField) } type BeforeListElementVisitor interface { BeforeListElement(elem NestedField) } type AfterListElementVisitor interface { AfterListElement(elem NestedField) } type BeforeMapKeyVisitor interface { BeforeMapKey(key NestedField) } type AfterMapKeyVisitor interface { AfterMapKey(key NestedField) } type BeforeMapValueVisitor interface { BeforeMapValue(value NestedField) } type AfterMapValueVisitor interface { AfterMapValue(value NestedField) } type SchemaVisitorPerPrimitiveType[T any] interface { SchemaVisitor[T] VisitFixed(FixedType) T VisitDecimal(DecimalType) T VisitBoolean() T VisitInt32() T VisitInt64() T VisitFloat32() T VisitFloat64() T VisitDate() T VisitTime() T VisitTimestamp() T VisitTimestampTz() T VisitString() T VisitBinary() T VisitUUID() T } // Visit accepts a visitor and performs a post-order traversal of the given schema. func Visit[T any](sc *Schema, visitor SchemaVisitor[T]) (res T, err error) { if sc == nil { err = fmt.Errorf("%w: cannot visit nil schema", ErrInvalidArgument) return } defer func() { if r := recover(); r != nil { switch e := r.(type) { case string: err = fmt.Errorf("error encountered during schema visitor: %s", e) case error: err = fmt.Errorf("error encountered during schema visitor: %w", e) } } }() return visitor.Schema(sc, visitStruct(sc.AsStruct(), visitor)), nil } func visitStruct[T any](obj StructType, visitor SchemaVisitor[T]) T { results := make([]T, len(obj.FieldList)) bf, _ := visitor.(BeforeFieldVisitor) af, _ := visitor.(AfterFieldVisitor) for i, f := range obj.FieldList { if bf != nil { bf.BeforeField(f) } res := visitField(f, visitor) if af != nil { af.AfterField(f) } results[i] = visitor.Field(f, res) } return visitor.Struct(obj, results) } func visitList[T any](obj ListType, visitor SchemaVisitor[T]) T { elemField := obj.ElementField() if bl, ok := visitor.(BeforeListElementVisitor); ok { bl.BeforeListElement(elemField) } else if bf, ok := visitor.(BeforeFieldVisitor); ok { bf.BeforeField(elemField) } res := visitField(elemField, visitor) if al, ok := visitor.(AfterListElementVisitor); ok { al.AfterListElement(elemField) } else if af, ok := visitor.(AfterFieldVisitor); ok { af.AfterField(elemField) } return visitor.List(obj, res) } func visitMap[T any](obj MapType, visitor SchemaVisitor[T]) T { keyField, valueField := obj.KeyField(), obj.ValueField() if bmk, ok := visitor.(BeforeMapKeyVisitor); ok { bmk.BeforeMapKey(keyField) } else if bf, ok := visitor.(BeforeFieldVisitor); ok { bf.BeforeField(keyField) } keyRes := visitField(keyField, visitor) if amk, ok := visitor.(AfterMapKeyVisitor); ok { amk.AfterMapKey(keyField) } else if af, ok := visitor.(AfterFieldVisitor); ok { af.AfterField(keyField) } if bmk, ok := visitor.(BeforeMapValueVisitor); ok { bmk.BeforeMapValue(valueField) } else if bf, ok := visitor.(BeforeFieldVisitor); ok { bf.BeforeField(valueField) } valueRes := visitField(valueField, visitor) if amk, ok := visitor.(AfterMapValueVisitor); ok { amk.AfterMapValue(valueField) } else if af, ok := visitor.(AfterFieldVisitor); ok { af.AfterField(valueField) } return visitor.Map(obj, keyRes, valueRes) } func visitField[T any](f NestedField, visitor SchemaVisitor[T]) T { switch typ := f.Type.(type) { case *StructType: return visitStruct(*typ, visitor) case *ListType: return visitList(*typ, visitor) case *MapType: return visitMap(*typ, visitor) default: // primitive if perPrimitive, ok := visitor.(SchemaVisitorPerPrimitiveType[T]); ok { switch t := typ.(type) { case BooleanType: return perPrimitive.VisitBoolean() case Int32Type: return perPrimitive.VisitInt32() case Int64Type: return perPrimitive.VisitInt64() case Float32Type: return perPrimitive.VisitFloat32() case Float64Type: return perPrimitive.VisitFloat64() case DateType: return perPrimitive.VisitDate() case TimeType: return perPrimitive.VisitTime() case TimestampType: return perPrimitive.VisitTimestamp() case TimestampTzType: return perPrimitive.VisitTimestampTz() case StringType: return perPrimitive.VisitString() case BinaryType: return perPrimitive.VisitBinary() case UUIDType: return perPrimitive.VisitUUID() case DecimalType: return perPrimitive.VisitDecimal(t) case FixedType: return perPrimitive.VisitFixed(t) } } return visitor.Primitive(typ.(PrimitiveType)) } } type PreOrderSchemaVisitor[T any] interface { Schema(*Schema, func() T) T Struct(StructType, []func() T) T Field(NestedField, func() T) T List(ListType, func() T) T Map(MapType, func() T, func() T) T Primitive(PrimitiveType) T } func PreOrderVisit[T any](sc *Schema, visitor PreOrderSchemaVisitor[T]) (res T, err error) { if sc == nil { err = fmt.Errorf("%w: cannot visit nil schema", ErrInvalidArgument) return } defer func() { if r := recover(); r != nil { switch e := r.(type) { case string: err = fmt.Errorf("error encountered during schema visitor: %s", e) case error: err = fmt.Errorf("error encountered during schema visitor: %w", e) } } }() return visitor.Schema(sc, func() T { return visitStructPreOrder(sc.AsStruct(), visitor) }), nil } func visitStructPreOrder[T any](obj StructType, visitor PreOrderSchemaVisitor[T]) T { results := make([]func() T, len(obj.FieldList)) for i, f := range obj.FieldList { results[i] = func() T { return visitFieldPreOrder(f, visitor) } } return visitor.Struct(obj, results) } func visitListPreOrder[T any](obj ListType, visitor PreOrderSchemaVisitor[T]) T { return visitor.List(obj, func() T { return visitFieldPreOrder(obj.ElementField(), visitor) }) } func visitMapPreOrder[T any](obj MapType, visitor PreOrderSchemaVisitor[T]) T { return visitor.Map(obj, func() T { return visitFieldPreOrder(obj.KeyField(), visitor) }, func() T { return visitFieldPreOrder(obj.ValueField(), visitor) }) } func visitFieldPreOrder[T any](f NestedField, visitor PreOrderSchemaVisitor[T]) T { var fn func() T switch typ := f.Type.(type) { case *StructType: fn = func() T { return visitStructPreOrder(*typ, visitor) } case *ListType: fn = func() T { return visitListPreOrder(*typ, visitor) } case *MapType: fn = func() T { return visitMapPreOrder(*typ, visitor) } default: fn = func() T { return visitor.Primitive(typ.(PrimitiveType)) } } return visitor.Field(f, fn) } // IndexByID performs a post-order traversal of the given schema and // returns a mapping from field ID to field. func IndexByID(schema *Schema) (map[int]NestedField, error) { return Visit(schema, &indexByID{index: make(map[int]NestedField)}) } type indexByID struct { index map[int]NestedField } func (i *indexByID) Schema(*Schema, map[int]NestedField) map[int]NestedField { return i.index } func (i *indexByID) Struct(StructType, []map[int]NestedField) map[int]NestedField { return i.index } func (i *indexByID) Field(field NestedField, _ map[int]NestedField) map[int]NestedField { i.index[field.ID] = field return i.index } func (i *indexByID) List(list ListType, _ map[int]NestedField) map[int]NestedField { i.index[list.ElementID] = list.ElementField() return i.index } func (i *indexByID) Map(mapType MapType, _, _ map[int]NestedField) map[int]NestedField { i.index[mapType.KeyID] = mapType.KeyField() i.index[mapType.ValueID] = mapType.ValueField() return i.index } func (i *indexByID) Primitive(PrimitiveType) map[int]NestedField { return i.index } // IndexByName performs a post-order traversal of the schema and returns // a mapping from field name to field ID. func IndexByName(schema *Schema) (map[string]int, error) { if schema == nil { return nil, fmt.Errorf("%w: cannot index nil schema", ErrInvalidArgument) } if len(schema.fields) > 0 { indexer := &indexByName{ index: make(map[string]int), shortNameId: make(map[string]int), fieldNames: make([]string, 0), shortFieldNames: make([]string, 0), } if _, err := Visit(schema, indexer); err != nil { return nil, err } return indexer.ByName(), nil } return map[string]int{}, nil } // IndexNameByID performs a post-order traversal of the schema and returns // a mapping from field ID to field name. func IndexNameByID(schema *Schema) (map[int]string, error) { indexer := &indexByName{ index: make(map[string]int), shortNameId: make(map[string]int), fieldNames: make([]string, 0), shortFieldNames: make([]string, 0), } if _, err := Visit(schema, indexer); err != nil { return nil, err } return indexer.ByID(), nil } type indexByName struct { index map[string]int shortNameId map[string]int combinedIndex map[string]int fieldNames []string shortFieldNames []string } func (i *indexByName) ByID() map[int]string { idToName := make(map[int]string) for k, v := range i.index { idToName[v] = k } return idToName } func (i *indexByName) ByName() map[string]int { i.combinedIndex = maps.Clone(i.shortNameId) maps.Copy(i.combinedIndex, i.index) return i.combinedIndex } func (i *indexByName) Primitive(PrimitiveType) map[string]int { return i.index } func (i *indexByName) addField(name string, fieldID int) { fullName := name if len(i.fieldNames) > 0 { fullName = strings.Join(i.fieldNames, ".") + "." + name } if _, ok := i.index[fullName]; ok { panic(fmt.Errorf("%w: multiple fields for name %s: %d and %d", ErrInvalidSchema, fullName, i.index[fullName], fieldID)) } i.index[fullName] = fieldID if len(i.shortFieldNames) > 0 { shortName := strings.Join(i.shortFieldNames, ".") + "." + name i.shortNameId[shortName] = fieldID } } func (i *indexByName) Schema(*Schema, map[string]int) map[string]int { return i.index } func (i *indexByName) Struct(StructType, []map[string]int) map[string]int { return i.index } func (i *indexByName) Field(field NestedField, _ map[string]int) map[string]int { i.addField(field.Name, field.ID) return i.index } func (i *indexByName) List(list ListType, _ map[string]int) map[string]int { i.addField(list.ElementField().Name, list.ElementID) return i.index } func (i *indexByName) Map(mapType MapType, _, _ map[string]int) map[string]int { i.addField(mapType.KeyField().Name, mapType.KeyID) i.addField(mapType.ValueField().Name, mapType.ValueID) return i.index } func (i *indexByName) BeforeListElement(elem NestedField) { if _, ok := elem.Type.(*StructType); !ok { i.shortFieldNames = append(i.shortFieldNames, elem.Name) } i.fieldNames = append(i.fieldNames, elem.Name) } func (i *indexByName) AfterListElement(elem NestedField) { if _, ok := elem.Type.(*StructType); !ok { i.shortFieldNames = i.shortFieldNames[:len(i.shortFieldNames)-1] } i.fieldNames = i.fieldNames[:len(i.fieldNames)-1] } func (i *indexByName) BeforeField(field NestedField) { i.fieldNames = append(i.fieldNames, field.Name) i.shortFieldNames = append(i.shortFieldNames, field.Name) } func (i *indexByName) AfterField(field NestedField) { i.fieldNames = i.fieldNames[:len(i.fieldNames)-1] i.shortFieldNames = i.shortFieldNames[:len(i.shortFieldNames)-1] } // PruneColumns visits a schema pruning any columns which do not exist in the // provided selected set. Parent fields of a selected child will be retained. func PruneColumns(schema *Schema, selected map[int]Void, selectFullTypes bool) (*Schema, error) { result, err := Visit(schema, &pruneColVisitor{ selected: selected, fullTypes: selectFullTypes, }) if err != nil { return nil, err } n, ok := result.(NestedType) if !ok { n = &StructType{} } newIdentifierIDs := make([]int, 0, len(schema.IdentifierFieldIDs)) for _, id := range schema.IdentifierFieldIDs { if _, ok := selected[id]; ok { newIdentifierIDs = append(newIdentifierIDs, id) } } return &Schema{ fields: n.Fields(), ID: schema.ID, IdentifierFieldIDs: newIdentifierIDs, }, nil } type pruneColVisitor struct { selected map[int]Void fullTypes bool } func (p *pruneColVisitor) Schema(_ *Schema, structResult Type) Type { return structResult } func (p *pruneColVisitor) Struct(st StructType, fieldResults []Type) Type { selected, fields := []NestedField{}, st.FieldList sameType := true for i, t := range fieldResults { field := fields[i] if field.Type == t { selected = append(selected, field) } else if t != nil { sameType = false // type has changed, create a new field with the projected type selected = append(selected, NestedField{ ID: field.ID, Name: field.Name, Type: t, Doc: field.Doc, Required: field.Required, }) } } if len(selected) > 0 { if len(selected) == len(fields) && sameType { // nothing changed, return the original return &st } else { return &StructType{FieldList: selected} } } return nil } func (p *pruneColVisitor) Field(field NestedField, fieldResult Type) Type { _, ok := p.selected[field.ID] if !ok { if fieldResult != nil { return fieldResult } return nil } if p.fullTypes { return field.Type } if _, ok := field.Type.(*StructType); ok { return p.projectSelectedStruct(fieldResult) } typ, ok := field.Type.(PrimitiveType) if !ok { panic(fmt.Errorf("%w: cannot explicitly project List or Map types, %d:%s of type %s was selected", ErrInvalidSchema, field.ID, field.Name, field.Type)) } return typ } func (p *pruneColVisitor) List(list ListType, elemResult Type) Type { _, ok := p.selected[list.ElementID] if !ok { if elemResult != nil { return p.projectList(&list, elemResult) } return nil } if p.fullTypes { return &list } _, ok = list.Element.(*StructType) if list.Element != nil && ok { projected := p.projectSelectedStruct(elemResult) return p.projectList(&list, projected) } if _, ok = list.Element.(PrimitiveType); !ok { panic(fmt.Errorf("%w: cannot explicitly project List or Map types, %d of type %s was selected", ErrInvalidSchema, list.ElementID, list.Element)) } return &list } func (p *pruneColVisitor) Map(mapType MapType, keyResult, valueResult Type) Type { _, ok := p.selected[mapType.ValueID] if !ok { if valueResult != nil { return p.projectMap(&mapType, valueResult) } if _, ok = p.selected[mapType.KeyID]; ok { return &mapType } return nil } if p.fullTypes { return &mapType } _, ok = mapType.ValueType.(*StructType) if mapType.ValueType != nil && ok { projected := p.projectSelectedStruct(valueResult) return p.projectMap(&mapType, projected) } if _, ok = mapType.ValueType.(PrimitiveType); !ok { panic(fmt.Errorf("%w: cannot explicitly project List or Map types, Map value %d of type %s was selected", ErrInvalidSchema, mapType.ValueID, mapType.ValueType)) } return &mapType } func (p *pruneColVisitor) Primitive(_ PrimitiveType) Type { return nil } func (*pruneColVisitor) projectSelectedStruct(projected Type) *StructType { if projected == nil { return &StructType{} } if ty, ok := projected.(*StructType); ok { return ty } panic("expected a struct") } func (*pruneColVisitor) projectList(listType *ListType, elementResult Type) *ListType { if listType.Element.Equals(elementResult) { return listType } return &ListType{ ElementID: listType.ElementID, Element: elementResult, ElementRequired: listType.ElementRequired, } } func (*pruneColVisitor) projectMap(mapType *MapType, valueResult Type) *MapType { if mapType.ValueType.Equals(valueResult) { return mapType } return &MapType{ KeyID: mapType.KeyID, ValueID: mapType.ValueID, KeyType: mapType.KeyType, ValueType: valueResult, ValueRequired: mapType.ValueRequired, } } type findLastFieldID struct{} func (findLastFieldID) Schema(_ *Schema, result int) int { return result } func (findLastFieldID) Struct(_ StructType, fieldResults []int) int { return max(fieldResults...) } func (findLastFieldID) Field(field NestedField, fieldResult int) int { return max(field.ID, fieldResult) } func (findLastFieldID) List(_ ListType, elemResult int) int { return elemResult } func (findLastFieldID) Map(_ MapType, keyResult, valueResult int) int { return max(keyResult, valueResult) } func (findLastFieldID) Primitive(PrimitiveType) int { return 0 } // IndexParents generates an index of field IDs to their parent field // IDs. Root fields are not indexed func IndexParents(schema *Schema) (map[int]int, error) { indexer := &indexParents{ idToParent: make(map[int]int), idStack: make([]int, 0), } return Visit(schema, indexer) } type indexParents struct { idToParent map[int]int idStack []int } func (i *indexParents) BeforeField(field NestedField) { i.idStack = append(i.idStack, field.ID) } func (i *indexParents) AfterField(field NestedField) { i.idStack = i.idStack[:len(i.idStack)-1] } func (i *indexParents) Schema(schema *Schema, _ map[int]int) map[int]int { return i.idToParent } func (i *indexParents) Struct(st StructType, _ []map[int]int) map[int]int { var parent int stackLen := len(i.idStack) if stackLen > 0 { parent = i.idStack[stackLen-1] for _, f := range st.FieldList { i.idToParent[f.ID] = parent } } return i.idToParent } func (i *indexParents) Field(NestedField, map[int]int) map[int]int { return i.idToParent } func (i *indexParents) List(list ListType, _ map[int]int) map[int]int { i.idToParent[list.ElementID] = i.idStack[len(i.idStack)-1] return i.idToParent } func (i *indexParents) Map(mapType MapType, _, _ map[int]int) map[int]int { parent := i.idStack[len(i.idStack)-1] i.idToParent[mapType.KeyID] = parent i.idToParent[mapType.ValueID] = parent return i.idToParent } func (i *indexParents) Primitive(PrimitiveType) map[int]int { return i.idToParent } type buildPosAccessors struct{} func (buildPosAccessors) Schema(_ *Schema, structResult map[int]accessor) map[int]accessor { return structResult } func (buildPosAccessors) Struct(st StructType, fieldResults []map[int]accessor) map[int]accessor { result := map[int]accessor{} for pos, f := range st.FieldList { if innerMap := fieldResults[pos]; len(innerMap) != 0 { for inner, acc := range innerMap { acc := acc result[inner] = accessor{pos: pos, inner: &acc} } } else { result[f.ID] = accessor{pos: pos} } } return result } func (buildPosAccessors) Field(_ NestedField, fieldResult map[int]accessor) map[int]accessor { return fieldResult } func (buildPosAccessors) List(ListType, map[int]accessor) map[int]accessor { return map[int]accessor{} } func (buildPosAccessors) Map(_ MapType, _, _ map[int]accessor) map[int]accessor { return map[int]accessor{} } func (buildPosAccessors) Primitive(PrimitiveType) map[int]accessor { return map[int]accessor{} } func buildAccessors(schema *Schema) (map[int]accessor, error) { return Visit(schema, buildPosAccessors{}) } type setFreshIDs struct { oldIdToNew map[int]int nextIDFunc func() int } func (s *setFreshIDs) getAndInc(currentID int) int { next := s.nextIDFunc() s.oldIdToNew[currentID] = next return next } func (s *setFreshIDs) Schema(_ *Schema, structResult func() Type) Type { return structResult() } func (s *setFreshIDs) Struct(st StructType, fieldResults []func() Type) Type { newFields := make([]NestedField, len(st.FieldList)) for idx, f := range st.FieldList { newFields[idx] = NestedField{ ID: s.getAndInc(f.ID), Name: f.Name, Type: fieldResults[idx](), Doc: f.Doc, Required: f.Required, } } return &StructType{FieldList: newFields} } func (s *setFreshIDs) Field(_ NestedField, fieldResult func() Type) Type { return fieldResult() } func (s *setFreshIDs) List(list ListType, elemResult func() Type) Type { elemID := s.getAndInc(list.ElementID) return &ListType{ ElementID: elemID, Element: elemResult(), ElementRequired: list.ElementRequired, } } func (s *setFreshIDs) Map(mapType MapType, keyResult, valueResult func() Type) Type { keyID := s.getAndInc(mapType.KeyID) valueID := s.getAndInc(mapType.ValueID) return &MapType{ KeyID: keyID, ValueID: valueID, KeyType: keyResult(), ValueType: valueResult(), ValueRequired: mapType.ValueRequired, } } func (s *setFreshIDs) Primitive(p PrimitiveType) Type { return p } // AssignFreshSchemaIDs creates a new schema with fresh field IDs for all of the // fields in it. The nextID function is used to iteratively generate the ids, if // it is nil then a simple incrementing counter is used starting at 1. func AssignFreshSchemaIDs(sc *Schema, nextID func() int) (*Schema, error) { if nextID == nil { var id int = 0 nextID = func() int { id++ return id } } visitor := &setFreshIDs{oldIdToNew: make(map[int]int), nextIDFunc: nextID} outType, err := PreOrderVisit(sc, visitor) if err != nil { return nil, err } fields := outType.(*StructType).FieldList var newIdentifierIDs []int if len(sc.IdentifierFieldIDs) != 0 { newIdentifierIDs = make([]int, len(sc.IdentifierFieldIDs)) for i, id := range sc.IdentifierFieldIDs { newIdentifierIDs[i] = visitor.oldIdToNew[id] } } return NewSchemaWithIdentifiers(0, newIdentifierIDs, fields...), nil } type SchemaWithPartnerVisitor[T, P any] interface { Schema(sc *Schema, schemaPartner P, structResult T) T Struct(st StructType, structPartner P, fieldResults []T) T Field(field NestedField, fieldPartner P, fieldResult T) T List(l ListType, listPartner P, elemResult T) T Map(m MapType, mapPartner P, keyResult, valResult T) T Primitive(p PrimitiveType, primitivePartner P) T } type PartnerAccessor[P any] interface { SchemaPartner(P) P FieldPartner(partnerStruct P, fieldID int, fieldName string) P ListElementPartner(P) P MapKeyPartner(P) P MapValuePartner(P) P } func VisitSchemaWithPartner[T, P any](sc *Schema, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) (res T, err error) { if sc == nil { err = fmt.Errorf("%w: cannot visit nil schema", ErrInvalidArgument) return } if visitor == nil || accessor == nil { err = fmt.Errorf("%w: cannot visit with nil visitor or accessor", ErrInvalidArgument) return } defer func() { if r := recover(); r != nil { switch e := r.(type) { case string: err = fmt.Errorf("error encountered during schema visitor: %s", e) case error: err = fmt.Errorf("error encountered during schema visitor: %w", e) } } }() structPartner := accessor.SchemaPartner(partner) return visitor.Schema(sc, partner, visitStructWithPartner(sc.AsStruct(), structPartner, visitor, accessor)), nil } func visitStructWithPartner[T, P any](st StructType, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { type ( beforeField interface { BeforeField(NestedField, P) } afterField interface { AfterField(NestedField, P) } ) bf, _ := visitor.(beforeField) af, _ := visitor.(afterField) fieldResults := make([]T, len(st.FieldList)) for i, f := range st.FieldList { fieldPartner := accessor.FieldPartner(partner, f.ID, f.Name) if bf != nil { bf.BeforeField(f, fieldPartner) } fieldResult := visitTypeWithPartner(f.Type, fieldPartner, visitor, accessor) fieldResults[i] = visitor.Field(f, fieldPartner, fieldResult) if af != nil { af.AfterField(f, fieldPartner) } } return visitor.Struct(st, partner, fieldResults) } func visitListWithPartner[T, P any](listType ListType, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { type ( beforeListElem interface { BeforeListElement(NestedField, P) } afterListElem interface { AfterListElement(NestedField, P) } ) elemPartner := accessor.ListElementPartner(partner) if ble, ok := visitor.(beforeListElem); ok { ble.BeforeListElement(listType.ElementField(), elemPartner) } elemResult := visitTypeWithPartner(listType.Element, elemPartner, visitor, accessor) if ale, ok := visitor.(afterListElem); ok { ale.AfterListElement(listType.ElementField(), elemPartner) } return visitor.List(listType, partner, elemResult) } func visitMapWithPartner[T, P any](m MapType, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { type ( beforeMapKey interface { BeforeMapKey(NestedField, P) } afterMapKey interface { AfterMapKey(NestedField, P) } beforeMapValue interface { BeforeMapValue(NestedField, P) } afterMapValue interface { AfterMapValue(NestedField, P) } ) keyPartner := accessor.MapKeyPartner(partner) if bmk, ok := visitor.(beforeMapKey); ok { bmk.BeforeMapKey(m.KeyField(), keyPartner) } keyResult := visitTypeWithPartner(m.KeyType, keyPartner, visitor, accessor) if amk, ok := visitor.(afterMapKey); ok { amk.AfterMapKey(m.KeyField(), keyPartner) } valPartner := accessor.MapValuePartner(partner) if bmv, ok := visitor.(beforeMapValue); ok { bmv.BeforeMapValue(m.ValueField(), valPartner) } valResult := visitTypeWithPartner(m.ValueType, valPartner, visitor, accessor) if amv, ok := visitor.(afterMapValue); ok { amv.AfterMapValue(m.ValueField(), valPartner) } return visitor.Map(m, partner, keyResult, valResult) } func visitTypeWithPartner[T, P any](t Type, fieldPartner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { switch t := t.(type) { case *ListType: return visitListWithPartner(*t, fieldPartner, visitor, accessor) case *StructType: return visitStructWithPartner(*t, fieldPartner, visitor, accessor) case *MapType: return visitMapWithPartner(*t, fieldPartner, visitor, accessor) default: return visitor.Primitive(t.(PrimitiveType), fieldPartner) } } func makeCompatibleName(n string) string { if !validAvroName(n) { return sanitizeName(n) } return n } func validAvroName(n string) bool { if len(n) == 0 { panic("cannot validate empty name") } if !unicode.IsLetter(rune(n[0])) && n[0] != '_' { return false } for _, r := range n[1:] { if !unicode.In(r, unicode.Number, unicode.Letter) && r != '_' { return false } } return true } func sanitize(r rune) string { if unicode.IsDigit(r) { return "_" + string(r) } return fmt.Sprintf("_x%X", r) } func sanitizeName(n string) string { var b strings.Builder b.Grow(len(n)) first := n[0] if !(unicode.IsLetter(rune(first)) || first == '_') { b.WriteString(sanitize(rune(first))) } else { b.WriteByte(first) } for _, r := range n[1:] { if !unicode.In(r, unicode.Number, unicode.Letter) && r != '_' { b.WriteString(sanitize(r)) } else { b.WriteRune(r) } } return b.String() } func SanitizeColumnNames(sc *Schema) (*Schema, error) { result, err := Visit(sc, sanitizeColumnNameVisitor{}) if err != nil { return nil, err } return NewSchemaWithIdentifiers(sc.ID, sc.IdentifierFieldIDs, result.Type.(*StructType).FieldList...), nil } type sanitizeColumnNameVisitor struct{} func (sanitizeColumnNameVisitor) Schema(_ *Schema, structResult NestedField) NestedField { return structResult } func (sanitizeColumnNameVisitor) Field(field NestedField, fieldResult NestedField) NestedField { field.Type = fieldResult.Type field.Name = makeCompatibleName(field.Name) return field } func (sanitizeColumnNameVisitor) Struct(_ StructType, fieldResults []NestedField) NestedField { return NestedField{Type: &StructType{FieldList: fieldResults}} } func (sanitizeColumnNameVisitor) List(list ListType, elemResult NestedField) NestedField { list.Element = elemResult.Type return NestedField{Type: &list} } func (sanitizeColumnNameVisitor) Map(mapType MapType, keyResult, valueResult NestedField) NestedField { mapType.KeyType = keyResult.Type mapType.ValueType = valueResult.Type return NestedField{Type: &mapType} } func (sanitizeColumnNameVisitor) Primitive(p PrimitiveType) NestedField { return NestedField{Type: p} }