arrow/compute/internal/kernels/scalar_set_lookup.go (248 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package kernels
import (
"fmt"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/bitutil"
"github.com/apache/arrow-go/v18/arrow/compute/exec"
"github.com/apache/arrow-go/v18/arrow/internal/debug"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/internal/bitutils"
"github.com/apache/arrow-go/v18/internal/hashing"
)
type NullMatchingBehavior int8
const (
NullMatchingMatch NullMatchingBehavior = iota
NullMatchingSkip
NullMatchingEmitNull
NullMatchingInconclusive
)
func visitBinary[OffsetT int32 | int64](data *exec.ArraySpan, valid func([]byte) error, null func() error) error {
if data.Len == 0 {
return nil
}
rawBytes := data.Buffers[2].Buf
offsets := exec.GetSpanOffsets[OffsetT](data, 1)
return bitutils.VisitBitBlocksShort(data.Buffers[0].Buf, data.Offset, data.Len,
func(pos int64) error {
return valid(rawBytes[offsets[pos]:offsets[pos+1]])
}, null)
}
func visitNumeric[T arrow.FixedWidthType](data *exec.ArraySpan, valid func(T) error, null func() error) error {
if data.Len == 0 {
return nil
}
values := exec.GetSpanValues[T](data, 1)
return bitutils.VisitBitBlocksShort(data.Buffers[0].Buf, data.Offset, data.Len,
func(pos int64) error {
return valid(values[pos])
}, null)
}
func visitFSB(data *exec.ArraySpan, valid func([]byte) error, null func() error) error {
if data.Len == 0 {
return nil
}
sz := int64(data.Type.(arrow.FixedWidthDataType).Bytes())
rawBytes := data.Buffers[1].Buf
return bitutils.VisitBitBlocksShort(data.Buffers[0].Buf, data.Offset, data.Len,
func(pos int64) error {
return valid(rawBytes[pos*sz : (pos+1)*sz])
}, null)
}
type SetLookupOptions struct {
ValueSetType arrow.DataType
TotalLen int64
ValueSet []exec.ArraySpan
NullBehavior NullMatchingBehavior
}
type lookupState interface {
Init(SetLookupOptions) error
}
func CreateSetLookupState(opts SetLookupOptions, alloc memory.Allocator) (exec.KernelState, error) {
valueSetType := opts.ValueSetType
if valueSetType.ID() == arrow.EXTENSION {
valueSetType = valueSetType.(arrow.ExtensionType).StorageType()
}
var state lookupState
switch ty := valueSetType.(type) {
case arrow.BinaryDataType:
switch ty.Layout().Buffers[1].ByteWidth {
case 4:
state = &SetLookupState[[]byte]{
Alloc: alloc,
visitFn: visitBinary[int32],
}
case 8:
state = &SetLookupState[[]byte]{
Alloc: alloc,
visitFn: visitBinary[int64],
}
}
case arrow.FixedWidthDataType:
switch ty.Bytes() {
case 1:
state = &SetLookupState[uint8]{
Alloc: alloc,
visitFn: visitNumeric[uint8],
}
case 2:
state = &SetLookupState[uint16]{
Alloc: alloc,
visitFn: visitNumeric[uint16],
}
case 4:
state = &SetLookupState[uint32]{
Alloc: alloc,
visitFn: visitNumeric[uint32],
}
case 8:
state = &SetLookupState[uint64]{
Alloc: alloc,
visitFn: visitNumeric[uint64],
}
default:
state = &SetLookupState[[]byte]{
Alloc: alloc,
visitFn: visitFSB,
}
}
default:
return nil, fmt.Errorf("%w: unsupported type %s for SetLookup functions", arrow.ErrInvalid, opts.ValueSetType)
}
return state, state.Init(opts)
}
type SetLookupState[T hashing.MemoTypes] struct {
visitFn func(*exec.ArraySpan, func(T) error, func() error) error
ValueSetType arrow.DataType
Alloc memory.Allocator
Lookup hashing.TypedMemoTable[T]
// When there are duplicates in value set, memotable indices
// must be mapped back to indices in the value set
MemoIndexToValueIndex []int32
NullIndex int32
NullBehavior NullMatchingBehavior
}
func (s *SetLookupState[T]) ValueType() arrow.DataType {
return s.ValueSetType
}
func (s *SetLookupState[T]) Init(opts SetLookupOptions) error {
s.ValueSetType = opts.ValueSetType
s.NullBehavior = opts.NullBehavior
s.MemoIndexToValueIndex = make([]int32, 0, opts.TotalLen)
s.NullIndex = -1
memoType := s.ValueSetType.ID()
if memoType == arrow.EXTENSION {
memoType = s.ValueSetType.(arrow.ExtensionType).StorageType().ID()
}
lookup, err := newMemoTable(s.Alloc, memoType)
if err != nil {
return err
}
s.Lookup = lookup.(hashing.TypedMemoTable[T])
if s.Lookup == nil {
return fmt.Errorf("unsupported type %s for SetLookup functions", s.ValueSetType)
}
var offset int64
for _, c := range opts.ValueSet {
if err := s.AddArrayValueSet(&c, offset); err != nil {
return err
}
offset += c.Len
}
lookupNull, _ := s.Lookup.GetNull()
if s.NullBehavior != NullMatchingSkip && lookupNull >= 0 {
s.NullIndex = int32(lookupNull)
}
return nil
}
func (s *SetLookupState[T]) AddArrayValueSet(data *exec.ArraySpan, startIdx int64) error {
idx := startIdx
return s.visitFn(data,
func(v T) error {
memoSize := len(s.MemoIndexToValueIndex)
memoIdx, found, err := s.Lookup.InsertOrGet(v)
if err != nil {
return err
}
if !found {
debug.Assert(memoIdx == memoSize, "inconsistent memo index and size")
s.MemoIndexToValueIndex = append(s.MemoIndexToValueIndex, int32(idx))
} else {
debug.Assert(memoIdx < memoSize, "inconsistent memo index and size")
}
idx++
return nil
}, func() error {
memoSize := len(s.MemoIndexToValueIndex)
nullIdx, found := s.Lookup.GetOrInsertNull()
if !found {
debug.Assert(nullIdx == memoSize, "inconsistent memo index and size")
s.MemoIndexToValueIndex = append(s.MemoIndexToValueIndex, int32(idx))
} else {
debug.Assert(nullIdx < memoSize, "inconsistent memo index and size")
}
idx++
return nil
})
}
func DispatchIsIn(state lookupState, in *exec.ArraySpan, out *exec.ExecResult) error {
inType := in.Type
if inType.ID() == arrow.EXTENSION {
inType = inType.(arrow.ExtensionType).StorageType()
}
switch ty := inType.(type) {
case arrow.BinaryDataType:
return isInKernelExec(state.(*SetLookupState[[]byte]), in, out)
case arrow.FixedWidthDataType:
switch ty.Bytes() {
case 1:
return isInKernelExec(state.(*SetLookupState[uint8]), in, out)
case 2:
return isInKernelExec(state.(*SetLookupState[uint16]), in, out)
case 4:
return isInKernelExec(state.(*SetLookupState[uint32]), in, out)
case 8:
return isInKernelExec(state.(*SetLookupState[uint64]), in, out)
default:
return isInKernelExec(state.(*SetLookupState[[]byte]), in, out)
}
default:
return fmt.Errorf("%w: unsupported type %s for is_in function", arrow.ErrInvalid, in.Type)
}
}
func isInKernelExec[T hashing.MemoTypes](state *SetLookupState[T], in *exec.ArraySpan, out *exec.ExecResult) error {
writerBool := bitutil.NewBitmapWriter(out.Buffers[1].Buf, int(out.Offset), int(out.Len))
defer writerBool.Finish()
writerNulls := bitutil.NewBitmapWriter(out.Buffers[0].Buf, int(out.Offset), int(out.Len))
defer writerNulls.Finish()
valueSetHasNull := state.NullIndex != -1
return state.visitFn(in,
func(v T) error {
switch {
case state.Lookup.Exists(v):
writerBool.Set()
writerNulls.Set()
case state.NullBehavior == NullMatchingInconclusive && valueSetHasNull:
writerBool.Clear()
writerNulls.Clear()
default:
writerBool.Clear()
writerNulls.Set()
}
writerBool.Next()
writerNulls.Next()
return nil
}, func() error {
switch {
case state.NullBehavior == NullMatchingMatch && valueSetHasNull:
writerBool.Set()
writerNulls.Set()
case state.NullBehavior == NullMatchingSkip || (!valueSetHasNull && state.NullBehavior == NullMatchingMatch):
writerBool.Clear()
writerNulls.Set()
default:
writerBool.Clear()
writerNulls.Clear()
}
writerBool.Next()
writerNulls.Next()
return nil
})
}