arrow/compute/scalar_set_lookup.go (167 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package compute import ( "context" "errors" "fmt" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/compute/exec" "github.com/apache/arrow-go/v18/arrow/compute/internal/kernels" "github.com/apache/arrow-go/v18/arrow/extensions" "github.com/apache/arrow-go/v18/internal/hashing" ) var ( isinDoc = FunctionDoc{ Summary: "Find each element in a set of values", Description: `For each element in "values", return true if it is found in a given set, false otherwise`, ArgNames: []string{"values"}, OptionsType: "SetOptions", OptionsRequired: true, } ) type NullMatchingBehavior = kernels.NullMatchingBehavior const ( NullMatchingMatch = kernels.NullMatchingMatch NullMatchingSkip = kernels.NullMatchingSkip NullMatchingEmitNull = kernels.NullMatchingEmitNull NullMatchingInconclusive = kernels.NullMatchingInconclusive ) type setLookupFunc struct { ScalarFunction } func (fn *setLookupFunc) Execute(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) { return execInternal(ctx, fn, opts, -1, args...) } func (fn *setLookupFunc) DispatchBest(vals ...arrow.DataType) (exec.Kernel, error) { ensureDictionaryDecoded(vals...) return fn.DispatchExact(vals...) } type SetOptions struct { ValueSet Datum NullBehavior NullMatchingBehavior } func (*SetOptions) TypeName() string { return "SetOptions" } func initSetLookup(ctx *exec.KernelCtx, args exec.KernelInitArgs) (exec.KernelState, error) { if args.Options == nil { return nil, fmt.Errorf("%w: calling a set lookup function without SetOptions", ErrInvalid) } opts, ok := args.Options.(*SetOptions) if !ok { return nil, fmt.Errorf("%w: expected SetOptions, got %T", ErrInvalid, args.Options) } valueset, ok := opts.ValueSet.(ArrayLikeDatum) if !ok { return nil, fmt.Errorf("%w: expected array-like datum, got %T", ErrInvalid, opts.ValueSet) } argType := args.Inputs[0] if (argType.ID() == arrow.STRING || argType.ID() == arrow.LARGE_STRING) && !arrow.IsBaseBinary(valueset.Type().ID()) { // don't implicitly cast from a non-binary type to string // since most types support casting to string and that may lead to // surprises. However we do want most other implicit casts return nil, fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", ErrInvalid, argType, valueset.Type()) } if !arrow.TypeEqual(valueset.Type(), argType) { result, err := CastDatum(ctx.Ctx, valueset, SafeCastOptions(argType)) if err == nil { defer result.Release() valueset = result.(ArrayLikeDatum) } else if CanCast(argType, valueset.Type()) { // avoid casting from non-binary types to string like above // otherwise will try to cast input array to valueset during // execution if (valueset.Type().ID() == arrow.STRING || valueset.Type().ID() == arrow.LARGE_STRING) && !arrow.IsBaseBinary(argType.ID()) { return nil, fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", ErrInvalid, argType, valueset.Type()) } } else { return nil, fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", ErrInvalid, argType, valueset.Type()) } } internalOpts := kernels.SetLookupOptions{ ValueSet: make([]exec.ArraySpan, 1), TotalLen: opts.ValueSet.Len(), NullBehavior: opts.NullBehavior, } switch valueset.Kind() { case KindArray: internalOpts.ValueSet[0].SetMembers(valueset.(*ArrayDatum).Value) internalOpts.ValueSetType = valueset.(*ArrayDatum).Type() case KindChunked: chnked := valueset.(*ChunkedDatum).Value internalOpts.ValueSetType = chnked.DataType() internalOpts.ValueSet = make([]exec.ArraySpan, len(chnked.Chunks())) for i, c := range chnked.Chunks() { internalOpts.ValueSet[i].SetMembers(c.Data()) } default: return nil, fmt.Errorf("%w: expected array or chunked array, got %s", ErrInvalid, opts.ValueSet.Kind()) } return kernels.CreateSetLookupState(internalOpts, exec.GetAllocator(ctx.Ctx)) } type setLookupState interface { Init(kernels.SetLookupOptions) error ValueType() arrow.DataType } func execIsIn(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { state := ctx.State.(setLookupState) ctx.Kernel.(*exec.ScalarKernel).Data = state in := batch.Values[0] if !arrow.TypeEqual(in.Type(), state.ValueType()) { materialized := in.Array.MakeArray() defer materialized.Release() castResult, err := CastArray(ctx.Ctx, materialized, SafeCastOptions(state.ValueType())) if err != nil { if errors.Is(err, arrow.ErrNotImplemented) { return fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", ErrInvalid, in.Type(), state.ValueType()) } return err } defer castResult.Release() var casted exec.ArraySpan casted.SetMembers(castResult.Data()) return kernels.DispatchIsIn(state, &casted, out) } return kernels.DispatchIsIn(state, &in.Array, out) } func IsIn(ctx context.Context, opts SetOptions, values Datum) (Datum, error) { return CallFunction(ctx, "is_in", &opts, values) } func IsInSet(ctx context.Context, valueSet, values Datum) (Datum, error) { return IsIn(ctx, SetOptions{ValueSet: valueSet}, values) } func RegisterScalarSetLookup(reg FunctionRegistry) { inBase := NewScalarFunction("is_in", Unary(), isinDoc) types := []exec.InputType{ exec.NewMatchedInput(exec.Primitive()), exec.NewIDInput(arrow.DECIMAL32), exec.NewIDInput(arrow.DECIMAL64), } outType := exec.NewOutputType(arrow.FixedWidthTypes.Boolean) for _, ty := range types { kn := exec.NewScalarKernel([]exec.InputType{ty}, outType, execIsIn, initSetLookup) kn.MemAlloc = exec.MemPrealloc kn.NullHandling = exec.NullComputedPrealloc if err := inBase.AddKernel(kn); err != nil { panic(err) } } binaryTypes := []exec.InputType{ exec.NewMatchedInput(exec.BinaryLike()), exec.NewMatchedInput(exec.LargeBinaryLike()), exec.NewExactInput(extensions.NewUUIDType()), exec.NewIDInput(arrow.FIXED_SIZE_BINARY), exec.NewIDInput(arrow.DECIMAL128), exec.NewIDInput(arrow.DECIMAL256), } for _, ty := range binaryTypes { kn := exec.NewScalarKernel([]exec.InputType{ty}, outType, execIsIn, initSetLookup) kn.MemAlloc = exec.MemPrealloc kn.NullHandling = exec.NullComputedPrealloc kn.CleanupFn = func(state exec.KernelState) error { s := state.(*kernels.SetLookupState[[]byte]) s.Lookup.(*hashing.BinaryMemoTable).Release() return nil } if err := inBase.AddKernel(kn); err != nil { panic(err) } } reg.AddFunction(&setLookupFunc{*inBase}, false) }