arrow/scalar/nested.go (656 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package scalar
import (
"bytes"
"errors"
"fmt"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/internal/debug"
"github.com/apache/arrow-go/v18/arrow/memory"
"golang.org/x/xerrors"
)
// ListScalar is implemented by scalars that expose their value as an
// arrow.Array. In this file it is satisfied by *List and by the types that
// embed *List.
type ListScalar interface {
	Scalar

	// GetList returns the array held by the scalar.
	GetList() arrow.Array

	// Retain and Release are forwarded to the held array by *List.
	Retain()
	Release()
}
// List is a scalar whose value is an arrow.Array. It is embedded by
// LargeList, Map and FixedSizeList.
//
// Constructors in this package build List with positional composite
// literals, so the field order must stay as declared.
type List struct {
	scalar
	Value arrow.Array
}

// Release calls Release on the held array; it is a no-op when Value is nil.
func (l *List) Release() {
	if l.Value == nil {
		return
	}
	l.Value.Release()
}

// Retain calls Retain on the held array; it is a no-op when Value is nil.
func (l *List) Retain() {
	if l.Value == nil {
		return
	}
	l.Value.Retain()
}

// value returns the held array as an untyped value.
func (l *List) value() interface{} {
	return l.Value
}

// GetList returns the held array.
func (l *List) GetList() arrow.Array {
	return l.Value
}

// equals reports whether rhs holds an array equal to this scalar's array
// according to array.Equal. rhs must implement ListScalar; the unchecked type
// assertion panics otherwise.
func (l *List) equals(rhs Scalar) bool {
	other := rhs.(ListScalar)
	return array.Equal(l.Value, other.GetList())
}
// Validate runs the embedded scalar's validation and validateOptional on the
// held array, then, for a valid scalar, checks that the array's data type
// equals the element type of the list type.
//
// NOTE(review): validateOptional is defined elsewhere; presumably it checks
// that Value is present exactly when the scalar is valid — confirm.
func (l *List) Validate() (err error) {
	if err = l.scalar.Validate(); err != nil {
		return err
	}
	if err = validateOptional(&l.scalar, l.Value, "value"); err != nil {
		return err
	}
	if !l.Valid {
		return nil
	}

	elemType := l.Type.(arrow.ListLikeType).Elem()
	if got := l.Value.DataType(); !arrow.TypeEqual(got, elemType) {
		return fmt.Errorf("%s scalar should have a value of type %s, got %s",
			l.Type, elemType, got)
	}
	return nil
}

// ValidateFull performs the same checks as Validate.
func (l *List) ValidateFull() error {
	return l.Validate()
}
// CastTo converts the list scalar to the requested type.
//
//   - An invalid (null) scalar becomes MakeNullScalar(to).
//   - A cast to the scalar's own type returns the receiver unchanged.
//   - A cast to a type whose ID is arrow.STRING yields a string scalar holding
//     the fmt.Fprint rendering of the held array.
//
// Any other target type is an error.
func (l *List) CastTo(to arrow.DataType) (Scalar, error) {
	switch {
	case !l.Valid:
		return MakeNullScalar(to), nil
	case arrow.TypeEqual(l.Type, to):
		return l, nil
	case to.ID() != arrow.STRING:
		return nil, fmt.Errorf("cannot convert non-nil list scalar to type %s", to)
	}

	var text bytes.Buffer
	fmt.Fprint(&text, l.Value)

	buf := memory.NewBufferBytes(text.Bytes())
	// NOTE(review): releasing here relies on NewStringScalarFromBuffer taking
	// its own reference to buf — confirm against its implementation.
	defer buf.Release()
	return NewStringScalarFromBuffer(buf), nil
}
// String returns "null" for an invalid scalar, otherwise the text produced by
// casting the scalar to a string. If that cast fails it returns "...".
func (l *List) String() string {
	if !l.Valid {
		return "null"
	}
	if str, err := l.CastTo(arrow.BinaryTypes.String); err == nil {
		return string(str.(*String).Value.Bytes())
	}
	return "..."
}
// NewListScalar returns a valid List scalar of type arrow.ListOf(val.DataType())
// whose value is a new array created from val's data via array.MakeFromData.
func NewListScalar(val arrow.Array) *List {
	return &List{
		scalar: scalar{Type: arrow.ListOf(val.DataType()), Valid: true},
		Value:  array.MakeFromData(val.Data()),
	}
}

// NewListScalarData returns a valid List scalar of type
// arrow.ListOf(val.DataType()) whose value is the array created from val via
// array.MakeFromData.
func NewListScalarData(val arrow.ArrayData) *List {
	return &List{
		scalar: scalar{Type: arrow.ListOf(val.DataType()), Valid: true},
		Value:  array.MakeFromData(val),
	}
}
// LargeList is a List scalar whose type is built with arrow.LargeListOf. All
// behaviour comes from the embedded *List.
type LargeList struct {
	*List
}

// NewLargeListScalar returns a valid LargeList scalar of type
// arrow.LargeListOf(val.DataType()) whose value is a new array created from
// val's data via array.MakeFromData.
func NewLargeListScalar(val arrow.Array) *LargeList {
	inner := &List{
		scalar: scalar{Type: arrow.LargeListOf(val.DataType()), Valid: true},
		Value:  array.MakeFromData(val.Data()),
	}
	return &LargeList{inner}
}

// NewLargeListScalarData returns a valid LargeList scalar of type
// arrow.LargeListOf(val.DataType()) whose value is the array created from val
// via array.MakeFromData.
func NewLargeListScalarData(val arrow.ArrayData) *LargeList {
	inner := &List{
		scalar: scalar{Type: arrow.LargeListOf(val.DataType()), Valid: true},
		Value:  array.MakeFromData(val),
	}
	return &LargeList{inner}
}
// makeMapType builds a map type from a struct type, using the type of field 0
// as the key type and the type of field 1 as the item type. It asserts (via
// debug.Assert) that the struct has exactly two fields.
func makeMapType(typ *arrow.StructType) *arrow.MapType {
	debug.Assert(typ.NumFields() == 2, "must pass struct with only 2 fields for MapScalar")
	keyField, itemField := typ.Field(0), typ.Field(1)
	return arrow.MapOf(keyField.Type, itemField.Type)
}

// Map is a List scalar whose type is an arrow map type. All behaviour comes
// from the embedded *List.
type Map struct {
	*List
}

// NewMapScalar returns a valid Map scalar whose value is a new array created
// from val's data. val's data type must be a *arrow.StructType; the unchecked
// type assertion panics otherwise.
func NewMapScalar(val arrow.Array) *Map {
	structType := val.DataType().(*arrow.StructType)
	inner := &List{
		scalar: scalar{Type: makeMapType(structType), Valid: true},
		Value:  array.MakeFromData(val.Data()),
	}
	return &Map{inner}
}
// FixedSizeList is a List scalar whose type is a *arrow.FixedSizeListType.
type FixedSizeList struct {
	*List
}

// Validate runs List.Validate and, for a valid scalar, additionally checks
// that the held array's length equals the length declared by the
// fixed-size-list type.
func (f *FixedSizeList) Validate() (err error) {
	if err = f.List.Validate(); err != nil {
		return err
	}
	if !f.Valid {
		return nil
	}

	want := f.Type.(*arrow.FixedSizeListType).Len()
	if got := f.Value.Len(); got != int(want) {
		return fmt.Errorf("%s scalar should have a child value of length %d, got %d",
			f.Type, want, got)
	}
	return nil
}

// ValidateFull performs the same checks as Validate.
func (f *FixedSizeList) ValidateFull() error {
	return f.Validate()
}
// NewFixedSizeListScalar returns a FixedSizeList scalar whose type is a
// fixed-size list of val.Len() elements of val's data type.
func NewFixedSizeListScalar(val arrow.Array) *FixedSizeList {
	listType := arrow.FixedSizeListOf(int32(val.Len()), val.DataType())
	return NewFixedSizeListScalarWithType(val, listType)
}

// NewFixedSizeListScalarWithType returns a valid FixedSizeList scalar of the
// given type whose value is a new array created from val's data. typ must be a
// *arrow.FixedSizeListType (the type assertion panics otherwise), and
// debug.Assert is used to check that val's length matches the type's length.
func NewFixedSizeListScalarWithType(val arrow.Array, typ arrow.DataType) *FixedSizeList {
	debug.Assert(val.Len() == int(typ.(*arrow.FixedSizeListType).Len()), "length of value for fixed size list scalar must match type")
	inner := &List{
		scalar: scalar{Type: typ, Valid: true},
		Value:  array.MakeFromData(val.Data()),
	}
	return &FixedSizeList{inner}
}
// Vector is a slice of scalars; Struct uses it to hold its child values.
type Vector []Scalar

// Struct is a scalar holding one child scalar per field of its struct type.
//
// NewStructScalar builds Struct with a positional composite literal, so the
// field order must stay as declared.
type Struct struct {
	scalar
	Value Vector
}

// Release calls Release on every child value that implements Releasable.
func (s *Struct) Release() {
	for _, child := range s.Value {
		r, ok := child.(Releasable)
		if !ok {
			continue
		}
		r.Release()
	}
}
// Field returns the child scalar for the struct field called name.
//
// It returns an error when the struct type has no field with that name, or
// when the scalar holds no child value at that field's index (for example a
// struct scalar built with fewer children than its type has fields). The
// receiver's type must be a *arrow.StructType; the unchecked type assertion
// panics otherwise.
func (s *Struct) Field(name string) (Scalar, error) {
	idx, ok := s.Type.(*arrow.StructType).FieldIdx(name)
	if !ok {
		return nil, fmt.Errorf("no field named %s found in struct scalar %s", name, s.Type)
	}
	// Guard the index so a scalar with missing children yields an error
	// instead of an index-out-of-range panic.
	if idx < 0 || idx >= len(s.Value) {
		return nil, fmt.Errorf("struct scalar %s has no child value for field %s at index %d", s.Type, name, idx)
	}
	return s.Value[idx], nil
}
// value returns the child values as an untyped value.
func (s *Struct) value() interface{} {
	return s.Value
}

// String returns "null" for an invalid scalar, otherwise the text produced by
// casting the scalar to a string. If that cast fails it returns "...".
func (s *Struct) String() string {
	if !s.Valid {
		return "null"
	}
	if str, err := s.CastTo(arrow.BinaryTypes.String); err == nil {
		return string(str.(*String).Value.Bytes())
	}
	return "..."
}
// CastTo converts the struct scalar to the requested type.
//
// An invalid (null) scalar becomes MakeNullScalar(to). A valid scalar can only
// be cast to a type whose ID is arrow.STRING; the result has the form
// "{name:type = value, ...}" with one entry per child value.
func (s *Struct) CastTo(to arrow.DataType) (Scalar, error) {
	if !s.Valid {
		return MakeNullScalar(to), nil
	}
	if to.ID() != arrow.STRING {
		return nil, fmt.Errorf("cannot cast non-null struct scalar to type %s", to)
	}

	structType := s.Type.(*arrow.StructType)

	var out bytes.Buffer
	out.WriteByte('{')
	for i, child := range s.Value {
		if i != 0 {
			out.WriteString(", ")
		}
		field := structType.Field(i)
		fmt.Fprintf(&out, "%s:%s = %s", field.Name, field.Type, child.String())
	}
	out.WriteByte('}')

	buf := memory.NewBufferBytes(out.Bytes())
	// NOTE(review): releasing here relies on NewStringScalarFromBuffer taking
	// its own reference to buf — confirm against its implementation.
	defer buf.Release()
	return NewStringScalarFromBuffer(buf), nil
}
// equals reports whether rhs has the same number of children as s and every
// pair of children at the same index compares equal via Equals. rhs must be a
// *Struct; the unchecked type assertion panics otherwise.
func (s *Struct) equals(rhs Scalar) bool {
	other := rhs.(*Struct)
	if len(other.Value) != len(s.Value) {
		return false
	}
	for i, child := range s.Value {
		if !Equals(child, other.Value[i]) {
			return false
		}
	}
	return true
}
// Validate checks the struct scalar for internal consistency.
//
// After the embedded scalar's own validation:
//   - a null struct must not contain any valid child value (nil children are
//     ignored);
//   - a non-null struct must have exactly one non-nil child per field of its
//     struct type, each child must pass its own Validate, and each child's
//     data type must equal the corresponding field type.
//
// The first violation found is returned as an error.
func (s *Struct) Validate() (err error) {
	if err = s.scalar.Validate(); err != nil {
		return err
	}

	if !s.Valid {
		for _, v := range s.Value {
			// A nil entry holds no value; calling IsValid on it would panic.
			if v != nil && v.IsValid() {
				return fmt.Errorf("%s scalar is marked null but has child values", s.Type)
			}
		}
		return nil
	}

	st := s.Type.(*arrow.StructType)
	if num := st.NumFields(); len(s.Value) != num {
		return fmt.Errorf("non-null %s scalar should have %d child values, got %d", s.Type, num, len(s.Value))
	}

	for i, f := range st.Fields() {
		child := s.Value[i]
		if child == nil {
			return fmt.Errorf("non-null %s scalar has missing child value at index %d", s.Type, i)
		}
		if err = child.Validate(); err != nil {
			return fmt.Errorf("%s scalar fails validation for child at index %d: %w", s.Type, i, err)
		}
		if !arrow.TypeEqual(child.DataType(), f.Type) {
			return fmt.Errorf("%s scalar should have a child value of type %s at index %d, got %s", s.Type, f.Type, i, child.DataType())
		}
	}
	return nil
}
// ValidateFull performs the same structural checks as Validate, but calls
// ValidateFull on the embedded scalar and on every child value.
//
//   - A null struct must not contain any valid child value (nil children are
//     ignored).
//   - A non-null struct must have exactly one non-nil child per field of its
//     struct type, each child must pass its own ValidateFull, and each child's
//     data type must equal the corresponding field type.
//
// The first violation found is returned as an error.
func (s *Struct) ValidateFull() (err error) {
	if err = s.scalar.ValidateFull(); err != nil {
		return err
	}

	if !s.Valid {
		for _, v := range s.Value {
			// A nil entry holds no value; calling IsValid on it would panic.
			if v != nil && v.IsValid() {
				return fmt.Errorf("%s scalar is marked null but has child values", s.Type)
			}
		}
		return nil
	}

	st := s.Type.(*arrow.StructType)
	if num := st.NumFields(); len(s.Value) != num {
		return fmt.Errorf("non-null %s scalar should have %d child values, got %d", s.Type, num, len(s.Value))
	}

	for i, f := range st.Fields() {
		child := s.Value[i]
		if child == nil {
			return fmt.Errorf("non-null %s scalar has missing child value at index %d", s.Type, i)
		}
		if err = child.ValidateFull(); err != nil {
			return fmt.Errorf("%s scalar fails validation for child at index %d: %w", s.Type, i, err)
		}
		if !arrow.TypeEqual(child.DataType(), f.Type) {
			return fmt.Errorf("%s scalar should have a child value of type %s at index %d, got %s", s.Type, f.Type, i, child.DataType())
		}
	}
	return nil
}
// NewStructScalar returns a valid Struct scalar of type typ holding val as its
// child values. The slice is stored as given, not copied.
func NewStructScalar(val []Scalar, typ arrow.DataType) *Struct {
	return &Struct{
		scalar: scalar{Type: typ, Valid: true},
		Value:  val,
	}
}

// NewStructScalarWithNames builds a struct type from names and the data types
// of val, marking every field nullable, and returns a Struct scalar of that
// type. It returns an error when val and names differ in length.
func NewStructScalarWithNames(val []Scalar, names []string) (*Struct, error) {
	if len(names) != len(val) {
		return nil, xerrors.New("mismatching number of field names and child scalars")
	}

	fields := make([]arrow.Field, 0, len(names))
	for i, name := range names {
		fields = append(fields, arrow.Field{Name: name, Type: val[i].DataType(), Nullable: true})
	}
	return NewStructScalar(val, arrow.StructOf(fields...)), nil
}
// Dictionary is a scalar made of an index scalar and the dictionary array the
// index refers into.
type Dictionary struct {
	scalar
	Value struct {
		Index Scalar
		Dict  arrow.Array
	}
}

// NewNullDictScalar returns an invalid (null) Dictionary scalar of type dt.
// Its index is MakeNullScalar of the dictionary's index type and its
// dictionary array is nil. dt must be a *arrow.DictionaryType; the unchecked
// type assertion panics otherwise.
func NewNullDictScalar(dt arrow.DataType) *Dictionary {
	null := &Dictionary{scalar: scalar{Type: dt, Valid: false}}
	null.Value.Index = MakeNullScalar(dt.(*arrow.DictionaryType).IndexType)
	return null
}

// NewDictScalar returns a Dictionary scalar for the given index and
// dictionary. The dictionary type is derived from the data types of index and
// dict, the scalar's validity mirrors index.IsValid(), and the new scalar's
// Retain method is called before it is returned.
func NewDictScalar(index Scalar, dict arrow.Array) *Dictionary {
	dictType := &arrow.DictionaryType{IndexType: index.DataType(), ValueType: dict.DataType()}
	sc := &Dictionary{scalar: scalar{Type: dictType, Valid: index.IsValid()}}
	sc.Value.Index, sc.Value.Dict = index, dict
	sc.Retain()
	return sc
}
// Data returns the Data() of the index scalar. The index must implement
// PrimitiveScalar; the unchecked type assertion panics otherwise.
func (s *Dictionary) Data() []byte {
	index := s.Value.Index.(PrimitiveScalar)
	return index.Data()
}

// Retain calls Retain on the index scalar when it implements Releasable, and
// on the dictionary array when it is non-nil.
func (s *Dictionary) Retain() {
	if index, ok := s.Value.Index.(Releasable); ok {
		index.Retain()
	}
	if dict := s.Value.Dict; dict != nil {
		dict.Retain()
	}
}

// Release calls Release on the index scalar when it implements Releasable,
// and on the dictionary array when it is non-nil.
func (s *Dictionary) Release() {
	if index, ok := s.Value.Index.(Releasable); ok {
		index.Release()
	}
	if dict := s.Value.Dict; dict != nil {
		dict.Release()
	}
}
// Validate checks the dictionary scalar for internal consistency:
//
//   - the scalar's type is a *arrow.DictionaryType;
//   - an index scalar is present, passes its own Validate, and has the
//     dictionary type's index type;
//   - the scalar and its index agree on validity;
//   - for a valid scalar, a dictionary array is present and its data type
//     equals the dictionary type's value type.
//
// The first violation found is returned as an error.
func (s *Dictionary) Validate() (err error) {
	dt, ok := s.Type.(*arrow.DictionaryType)
	if !ok {
		return errors.New("arrow/scalar: dictionary scalar should have type Dictionary")
	}

	index := s.Value.Index
	if index == nil {
		return fmt.Errorf("%s scalar doesn't have an index value", dt)
	}
	if err = index.Validate(); err != nil {
		return fmt.Errorf("%s scalar fails validation for index value: %w", dt, err)
	}
	if got := index.DataType(); !arrow.TypeEqual(got, dt.IndexType) {
		return fmt.Errorf("%s scalar should have an index value of type %s, got %s",
			dt, dt.IndexType, got)
	}

	switch valid, indexValid := s.IsValid(), index.IsValid(); {
	case valid && !indexValid:
		return fmt.Errorf("non-null %s scalar has null index value", dt)
	case !valid && indexValid:
		return fmt.Errorf("null %s scalar has non-null index value", dt)
	case !valid:
		// Null scalar with a null index: nothing further to check.
		return nil
	}

	dict := s.Value.Dict
	if dict == nil {
		return fmt.Errorf("%s scalar doesn't have a dictionary value", dt)
	}
	if got := dict.DataType(); !arrow.TypeEqual(got, dt.ValueType) {
		return fmt.Errorf("%s scalar's value type doesn't match dict type: got %s", dt, got)
	}
	return nil
}
// ValidateFull runs Validate and then, for a valid index, checks that the
// index value lies within [0, Dict.Len()).
//
// The comparison is done in uint64 so that wide index values (uint64, or
// uint32/int64 on platforms where int is 32 bits) cannot wrap around during
// conversion and slip past the bounds check. Index values whose dynamic type
// is not one of the eight fixed-width integer types are not bounds-checked.
func (s *Dictionary) ValidateFull() (err error) {
	if err = s.Validate(); err != nil {
		return err
	}
	if !s.Value.Index.IsValid() {
		return nil
	}

	// Validate has already established that Dict is non-nil for a valid scalar.
	dictLen := uint64(s.Value.Dict.Len())
	raw := s.Value.Index.value()

	var (
		negative bool
		pos      uint64
	)
	switch idx := raw.(type) {
	case int8:
		negative, pos = idx < 0, uint64(idx)
	case uint8:
		pos = uint64(idx)
	case int16:
		negative, pos = idx < 0, uint64(idx)
	case uint16:
		pos = uint64(idx)
	case int32:
		negative, pos = idx < 0, uint64(idx)
	case uint32:
		pos = uint64(idx)
	case int64:
		negative, pos = idx < 0, uint64(idx)
	case uint64:
		pos = idx
	default:
		return nil
	}

	if negative || pos >= dictLen {
		return fmt.Errorf("%s scalar index value out of bounds: %d", s.DataType(), raw)
	}
	return nil
}
// String returns "null" for an invalid scalar, otherwise the dictionary
// array's string form followed by the index's string form in square brackets.
func (s *Dictionary) String() string {
	if !s.Valid {
		return "null"
	}
	dictText := s.Value.Dict.String()
	indexText := s.Value.Index.String()
	return dictText + "[" + indexText + "]"
}

// equals reports whether rhs has an equal index (per the index's own equals)
// and an equal dictionary array (per array.Equal). rhs must be a *Dictionary;
// the unchecked type assertion panics otherwise.
func (s *Dictionary) equals(rhs Scalar) bool {
	other := rhs.(*Dictionary)
	if !s.Value.Index.equals(other.Value.Index) {
		return false
	}
	return array.Equal(s.Value.Dict, other.Value.Dict)
}

// CastTo is not implemented for dictionary scalars and always returns an error.
func (s *Dictionary) CastTo(arrow.DataType) (Scalar, error) {
	return nil, fmt.Errorf("cast from scalar %s not implemented", s.DataType())
}
// GetEncodedValue returns the scalar stored in the dictionary array at the
// position given by the index.
//
// An invalid (null) scalar yields MakeNullScalar of the dictionary's value
// type. For a valid scalar an error is returned when the dictionary array is
// missing, when the index type is not one of the eight fixed-width integer
// types, or when the index value lies outside [0, Dict.Len()). The bounds
// comparison is done in uint64 so that wide index values cannot wrap around
// when narrowed to int.
//
// The receiver's type must be a *arrow.DictionaryType, and the index value's
// dynamic type must match the declared index type; the unchecked type
// assertions panic otherwise.
func (s *Dictionary) GetEncodedValue() (Scalar, error) {
	dt := s.Type.(*arrow.DictionaryType)
	if !s.IsValid() {
		return MakeNullScalar(dt.ValueType), nil
	}
	if s.Value.Dict == nil {
		return nil, fmt.Errorf("%s scalar doesn't have a dictionary value", dt)
	}

	var (
		negative bool
		pos      uint64
	)
	switch dt.IndexType.ID() {
	case arrow.INT8:
		v := s.Value.Index.value().(int8)
		negative, pos = v < 0, uint64(v)
	case arrow.UINT8:
		pos = uint64(s.Value.Index.value().(uint8))
	case arrow.INT16:
		v := s.Value.Index.value().(int16)
		negative, pos = v < 0, uint64(v)
	case arrow.UINT16:
		pos = uint64(s.Value.Index.value().(uint16))
	case arrow.INT32:
		v := s.Value.Index.value().(int32)
		negative, pos = v < 0, uint64(v)
	case arrow.UINT32:
		pos = uint64(s.Value.Index.value().(uint32))
	case arrow.INT64:
		v := s.Value.Index.value().(int64)
		negative, pos = v < 0, uint64(v)
	case arrow.UINT64:
		pos = s.Value.Index.value().(uint64)
	default:
		return nil, fmt.Errorf("unimplemented dictionary type %s", dt.IndexType)
	}

	// Reject out-of-range indices here rather than relying on GetScalar.
	if negative || pos >= uint64(s.Value.Dict.Len()) {
		return nil, fmt.Errorf("%s scalar index value out of bounds: %d", s.DataType(), s.Value.Index.value())
	}
	// pos < Dict.Len(), so the conversion to int cannot overflow.
	return GetScalar(s.Value.Dict, int(pos))
}
// value returns the index scalar's underlying value.
func (s *Dictionary) value() interface{} {
	index := s.Value.Index
	return index.value()
}
// Union is implemented by the union scalars in this file (*SparseUnion and
// *DenseUnion).
type Union interface {
	Scalar

	// Release calls Release on the child scalars that implement Releasable.
	Release()
	// ChildValue returns the scalar for the union's active child.
	ChildValue() Scalar
}
// SparseUnion is a union scalar that stores one scalar per field of its union
// type. TypeCode is the type code of the active child and ChildID is that
// child's position in Value.
type SparseUnion struct {
	scalar
	TypeCode arrow.UnionTypeCode
	Value    []Scalar
	ChildID  int
}

// equals reports whether the active child of rhs equals this scalar's active
// child, per Equals. rhs must be a *SparseUnion; the unchecked type assertion
// panics otherwise.
func (s *SparseUnion) equals(rhs Scalar) bool {
	other := rhs.(*SparseUnion)
	return Equals(s.ChildValue(), other.ChildValue())
}

// value returns the active child scalar as an untyped value.
func (s *SparseUnion) value() interface{} {
	return s.ChildValue()
}
// String renders the scalar as "union{<field> = <value>}", where <field> is
// the union field selected by TypeCode and <value> is the active child.
func (s *SparseUnion) String() string {
	unionType := s.Type.(*arrow.SparseUnionType)
	child := s.ChildValue()
	field := unionType.Fields()[unionType.ChildIDs()[s.TypeCode]]
	return "union{" + field.String() + " = " + child.String() + "}"
}
// Retain calls Retain on every child scalar that implements Releasable.
func (s *SparseUnion) Retain() {
	for _, child := range s.Value {
		r, ok := child.(Releasable)
		if !ok {
			continue
		}
		r.Retain()
	}
}

// Release calls Release on every child scalar that implements Releasable.
func (s *SparseUnion) Release() {
	for _, child := range s.Value {
		r, ok := child.(Releasable)
		if !ok {
			continue
		}
		r.Release()
	}
}
// Validate checks the sparse union scalar for internal consistency:
//
//   - Value holds exactly one scalar per field of the union type;
//   - TypeCode is non-negative, within the type's child-ID table, and maps to
//     a valid child ID;
//   - every child scalar has its field's data type and passes its own
//     Validate.
//
// The first violation found is returned as an error. The receiver's type must
// be a *arrow.SparseUnionType; the unchecked type assertion panics otherwise.
func (s *SparseUnion) Validate() (err error) {
	dt := s.Type.(*arrow.SparseUnionType)
	if dt.NumFields() != len(s.Value) {
		// Argument order matches the message: value count first, type count second.
		return fmt.Errorf("sparse union scalar value had %d fields but type has %d fields", len(s.Value), dt.NumFields())
	}

	if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
		return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
	}

	for i, f := range dt.Fields() {
		v := s.Value[i]
		if !arrow.TypeEqual(f.Type, v.DataType()) {
			return fmt.Errorf("%s value for field %s had incorrect type of %s", dt, f, v.DataType())
		}
		if err = v.Validate(); err != nil {
			return err
		}
	}
	return nil
}
// ValidateFull performs the same checks as Validate but calls ValidateFull on
// every child scalar:
//
//   - Value holds exactly one scalar per field of the union type;
//   - TypeCode is non-negative, within the type's child-ID table, and maps to
//     a valid child ID;
//   - every child scalar has its field's data type and passes its own
//     ValidateFull.
//
// The first violation found is returned as an error. The receiver's type must
// be a *arrow.SparseUnionType; the unchecked type assertion panics otherwise.
func (s *SparseUnion) ValidateFull() (err error) {
	dt := s.Type.(*arrow.SparseUnionType)
	if dt.NumFields() != len(s.Value) {
		// Argument order matches the message: value count first, type count second.
		return fmt.Errorf("sparse union scalar value had %d fields but type has %d fields", len(s.Value), dt.NumFields())
	}

	if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
		return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
	}

	for i, f := range dt.Fields() {
		v := s.Value[i]
		if !arrow.TypeEqual(f.Type, v.DataType()) {
			return fmt.Errorf("%s value for field %s had incorrect type of %s", dt, f, v.DataType())
		}
		if err = v.ValidateFull(); err != nil {
			return err
		}
	}
	return nil
}
// CastTo converts the union scalar to the requested type. An invalid (null)
// scalar becomes MakeNullScalar(to). A valid scalar can only be cast to a
// type whose ID is arrow.STRING or arrow.LARGE_STRING, in which case the
// result holds the scalar's String() form.
func (s *SparseUnion) CastTo(to arrow.DataType) (Scalar, error) {
	if !s.Valid {
		return MakeNullScalar(to), nil
	}

	id := to.ID()
	if id == arrow.STRING {
		return NewStringScalar(s.String()), nil
	}
	if id == arrow.LARGE_STRING {
		return NewLargeStringScalar(s.String()), nil
	}
	return nil, fmt.Errorf("cannot cast non-nil union to type other than string")
}

// ChildValue returns the child scalar at position ChildID in Value.
func (s *SparseUnion) ChildValue() Scalar {
	return s.Value[s.ChildID]
}
// NewSparseUnionScalar returns a SparseUnion scalar of type dt holding val as
// its children, with code as the active type code. ChildID is looked up from
// dt.ChildIDs() and the scalar's validity mirrors that of the active child.
// val is stored as given, not copied.
func NewSparseUnionScalar(val []Scalar, code arrow.UnionTypeCode, dt *arrow.SparseUnionType) *SparseUnion {
	childID := dt.ChildIDs()[code]
	return &SparseUnion{
		scalar:   scalar{Type: dt, Valid: val[childID].IsValid()},
		TypeCode: code,
		Value:    val,
		ChildID:  childID,
	}
}
// NewSparseUnionScalarFromValue returns a SparseUnion scalar of type dt in
// which the child at position idx is val and every other child is
// MakeNullScalar of the corresponding field type. The active type code is
// dt.TypeCodes()[idx].
func NewSparseUnionScalarFromValue(val Scalar, idx int, dt *arrow.SparseUnionType) *SparseUnion {
	code := dt.TypeCodes()[idx]

	children := make([]Scalar, dt.NumFields())
	for i, f := range dt.Fields() {
		if i != idx {
			children[i] = MakeNullScalar(f.Type)
			continue
		}
		children[i] = val
	}
	return NewSparseUnionScalar(children, code, dt)
}
// DenseUnion is a union scalar that stores only the active child scalar,
// together with that child's type code.
type DenseUnion struct {
	scalar
	TypeCode arrow.UnionTypeCode
	Value    Scalar
}

// equals reports whether the child scalar of rhs equals this scalar's child,
// per Equals. rhs must be a *DenseUnion; the unchecked type assertion panics
// otherwise.
func (s *DenseUnion) equals(rhs Scalar) bool {
	other := rhs.(*DenseUnion)
	return Equals(s.Value, other.Value)
}

// value returns the child scalar as an untyped value.
func (s *DenseUnion) value() interface{} {
	return s.ChildValue()
}
// String renders the scalar as "union{<field> = <value>}", where <field> is
// the union field selected by TypeCode and <value> is the child scalar.
func (s *DenseUnion) String() string {
	unionType := s.Type.(*arrow.DenseUnionType)
	field := unionType.Fields()[unionType.ChildIDs()[s.TypeCode]]
	return "union{" + field.String() + " = " + s.Value.String() + "}"
}

// Retain calls Retain on the child scalar when it implements Releasable.
func (s *DenseUnion) Retain() {
	child, ok := s.Value.(Releasable)
	if !ok {
		return
	}
	child.Retain()
}

// Release calls Release on the child scalar when it implements Releasable.
func (s *DenseUnion) Release() {
	child, ok := s.Value.(Releasable)
	if !ok {
		return
	}
	child.Release()
}
// Validate checks that TypeCode is non-negative, within the union type's
// child-ID table and mapped to a valid child ID, that the child scalar has
// the data type of the selected field, and finally that the child passes its
// own Validate. The receiver's type must be a *arrow.DenseUnionType; the
// unchecked type assertion panics otherwise.
func (s *DenseUnion) Validate() (err error) {
	dt := s.Type.(*arrow.DenseUnionType)
	code := s.TypeCode
	if code < 0 || int(code) >= len(dt.ChildIDs()) || dt.ChildIDs()[code] == arrow.InvalidUnionChildID {
		return fmt.Errorf("%s scalar has invalid type code %d", dt, code)
	}

	want := dt.Fields()[dt.ChildIDs()[code]].Type
	if got := s.Value.DataType(); !arrow.TypeEqual(want, got) {
		return fmt.Errorf("%s scalar with type code %d should have an underlying value of type %s, got %s",
			s.Type, code, want, got)
	}
	return s.Value.Validate()
}
// ValidateFull performs the same checks as Validate, but finishes by calling
// ValidateFull on the child scalar. The receiver's type must be a
// *arrow.DenseUnionType; the unchecked type assertion panics otherwise.
func (s *DenseUnion) ValidateFull() error {
	dt := s.Type.(*arrow.DenseUnionType)
	code := s.TypeCode
	if code < 0 || int(code) >= len(dt.ChildIDs()) || dt.ChildIDs()[code] == arrow.InvalidUnionChildID {
		return fmt.Errorf("%s scalar has invalid type code %d", dt, code)
	}

	want := dt.Fields()[dt.ChildIDs()[code]].Type
	if got := s.Value.DataType(); !arrow.TypeEqual(want, got) {
		return fmt.Errorf("%s scalar with type code %d should have an underlying value of type %s, got %s",
			s.Type, code, want, got)
	}
	return s.Value.ValidateFull()
}
// CastTo converts the union scalar to the requested type. An invalid (null)
// scalar becomes MakeNullScalar(to). A valid scalar can only be cast to a
// type whose ID is arrow.STRING or arrow.LARGE_STRING, in which case the
// result holds the scalar's String() form.
func (s *DenseUnion) CastTo(to arrow.DataType) (Scalar, error) {
	if !s.Valid {
		return MakeNullScalar(to), nil
	}

	id := to.ID()
	if id == arrow.STRING {
		return NewStringScalar(s.String()), nil
	}
	if id == arrow.LARGE_STRING {
		return NewLargeStringScalar(s.String()), nil
	}
	return nil, fmt.Errorf("cannot cast non-nil union to type other than string")
}

// ChildValue returns the child scalar.
func (s *DenseUnion) ChildValue() Scalar {
	return s.Value
}

// NewDenseUnionScalar returns a DenseUnion scalar of type dt holding v as its
// child with type code code. The scalar's validity mirrors v.IsValid().
func NewDenseUnionScalar(v Scalar, code arrow.UnionTypeCode, dt *arrow.DenseUnionType) *DenseUnion {
	return &DenseUnion{
		scalar:   scalar{Type: dt, Valid: v.IsValid()},
		TypeCode: code,
		Value:    v,
	}
}
// RunEndEncoded is a scalar of a run-end-encoded type that wraps a single
// underlying scalar.
type RunEndEncoded struct {
	scalar
	Value Scalar
}

// NewRunEndEncodedScalar returns a RunEndEncoded scalar of type dt wrapping v.
// The scalar's validity mirrors v.IsValid().
func NewRunEndEncodedScalar(v Scalar, dt *arrow.RunEndEncodedType) *RunEndEncoded {
	return &RunEndEncoded{
		scalar: scalar{Type: dt, Valid: v.IsValid()},
		Value:  v,
	}
}

// Release calls Release on the wrapped scalar when it implements Releasable.
func (s *RunEndEncoded) Release() {
	inner, ok := s.Value.(Releasable)
	if !ok {
		return
	}
	inner.Release()
}

// value returns the wrapped scalar's underlying value.
func (s *RunEndEncoded) value() interface{} {
	return s.Value.value()
}
// Validate validates the wrapped scalar and runs validateOptional on its
// underlying value. For a valid scalar it then checks that the scalar's type
// ID is arrow.RUN_END_ENCODED and that the wrapped scalar's data type equals
// the type's Encoded() type; failures of these two checks wrap
// arrow.ErrInvalid.
//
// NOTE(review): validateOptional is defined elsewhere; presumably it checks
// that the value is present exactly when the scalar is valid — confirm.
func (s *RunEndEncoded) Validate() (err error) {
	if err = s.Value.Validate(); err != nil {
		return err
	}
	if err = validateOptional(&s.scalar, s.value(), "value"); err != nil {
		return err
	}
	if !s.Valid {
		return nil
	}

	if s.Type.ID() != arrow.RUN_END_ENCODED {
		return fmt.Errorf("%w: run-end-encoded scalar should not have type %s",
			arrow.ErrInvalid, s.Type)
	}
	encoded := s.Type.(*arrow.RunEndEncodedType).Encoded()
	if got := s.Value.DataType(); !arrow.TypeEqual(got, encoded) {
		return fmt.Errorf("%w: run-end-encoded scalar value type %s does not match type %s",
			arrow.ErrInvalid, got, s.Type)
	}
	return nil
}

// ValidateFull performs the same checks as Validate.
func (s *RunEndEncoded) ValidateFull() error {
	return s.Validate()
}
// equals reports whether the scalar wrapped by rhs equals the one wrapped by
// s, per Equals. rhs must be a *RunEndEncoded; the unchecked type assertion
// panics otherwise.
func (s *RunEndEncoded) equals(rhs Scalar) bool {
	return Equals(s.Value, rhs.(*RunEndEncoded).Value)
}

// String returns the wrapped scalar's string form.
func (s *RunEndEncoded) String() string {
	inner := s.Value
	return inner.String()
}
// CastTo converts the scalar to the requested type.
//
//   - An invalid (null) scalar becomes MakeNullScalar(to).
//   - A cast to the scalar's own type returns the receiver unchanged.
//   - A cast to another run-end-encoded type casts the wrapped scalar to that
//     type's Encoded() type and wraps the result in a new RunEndEncoded.
//   - Any other target is delegated to the wrapped scalar's CastTo.
func (s *RunEndEncoded) CastTo(to arrow.DataType) (Scalar, error) {
	if !s.Valid {
		return MakeNullScalar(to), nil
	}
	if arrow.TypeEqual(s.Type, to) {
		return s, nil
	}

	reeType, ok := to.(*arrow.RunEndEncodedType)
	if !ok {
		return s.Value.CastTo(to)
	}

	encoded, err := s.Value.CastTo(reeType.Encoded())
	if err != nil {
		return nil, err
	}
	return NewRunEndEncodedScalar(encoded, reeType), nil
}