in arrow/compute/exprs/exec.go [485:643]
// executeScalarBatch evaluates the scalar expression exp against input and
// returns the resulting datum. The caller owns the returned datum and is
// responsible for releasing it.
//
// Supported expression kinds:
//   - expr.Literal: converted with literalToDatum.
//   - *expr.FieldReference: resolved with execFieldRef.
//   - *expr.Cast: the input is evaluated recursively and cast to the target
//     type derived from the Substrait type.
//   - *expr.ScalarFunction: the arguments are evaluated recursively, the
//     function is resolved through ext and the compute registry, the
//     arguments are cast to the types chosen by kernel dispatch, and the
//     kernel is run through a scalar executor.
//
// Any other expression kind yields arrow.ErrNotImplemented. Non-scalar
// expressions yield an error wrapping arrow.ErrInvalid.
func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
	if !exp.IsScalar() {
		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
			arrow.ErrInvalid)
	}

	switch e := exp.(type) {
	case expr.Literal:
		return literalToDatum(compute.GetAllocator(ctx), e, ext)
	case *expr.FieldReference:
		return execFieldRef(ctx, e, input, ext)
	case *expr.Cast:
		if e.Input == nil {
			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
		}

		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
		if err != nil {
			return nil, err
		}
		defer arg.Release()

		dt, _, err := FromSubstraitType(e.Type, ext)
		if err != nil {
			return nil, fmt.Errorf("%w: could not determine type for cast", err)
		}

		var opts *compute.CastOptions
		switch e.FailureBehavior {
		case types.BehaviorThrowException:
			// NOTE(review): "throw exception" is mapped to the unsafe cast
			// options here; confirm that is intended rather than the safe
			// (error-on-lossy-conversion) options.
			opts = compute.UnsafeCastOptions(dt)
		case types.BehaviorUnspecified:
			return nil, fmt.Errorf("%w: cast behavior unspecified", arrow.ErrInvalid)
		case types.BehaviorReturnNil:
			return nil, fmt.Errorf("%w: cast behavior return nil", arrow.ErrNotImplemented)
		default:
			// Never hand nil options to the cast: reject values outside the
			// known set instead.
			return nil, fmt.Errorf("%w: unknown cast failure behavior %v",
				arrow.ErrInvalid, e.FailureBehavior)
		}
		return compute.CastDatum(ctx, arg, opts)
	case *expr.ScalarFunction:
		var (
			err       error
			allScalar = true
			args      = make([]compute.Datum, e.NArgs())
		)
		for i := 0; i < e.NArgs(); i++ {
			switch v := e.Arg(i).(type) {
			case types.Enum:
				// Enum arguments are passed to the function as string scalars.
				args[i] = compute.NewDatum(scalar.NewStringScalar(string(v)))
			case expr.Expression:
				args[i], err = executeScalarBatch(ctx, input, v, ext)
				if err != nil {
					return nil, err
				}
				// Deferred (not released per iteration) because the argument
				// must stay alive until the kernel has finished executing.
				defer args[i].Release()

				if args[i].Kind() != compute.KindScalar {
					allScalar = false
				}
			default:
				return nil, arrow.ErrNotImplemented
			}
		}

		_, conv, ok := ext.DecodeFunction(e.FuncRef())
		if !ok {
			return nil, fmt.Errorf("%w: %s", arrow.ErrNotImplemented, e.Name())
		}

		fname, args, opts, err := conv(e, args)
		if err != nil {
			return nil, err
		}

		// origTypes keeps the types of the arguments as supplied; argTypes is
		// handed to DispatchBest and compared against origTypes afterwards to
		// detect which arguments need an implicit cast.
		origTypes := make([]arrow.DataType, len(args))
		argTypes := make([]arrow.DataType, len(args))
		for i, arg := range args {
			typed, ok := arg.(compute.ArrayLikeDatum)
			if !ok {
				// conv is supplied by the extension set, so guard against a
				// nil or non array-like datum instead of panicking.
				return nil, fmt.Errorf("%w: argument %d of function %q is not an array-like datum",
					arrow.ErrInvalid, i, fname)
			}
			origTypes[i] = typed.Type()
			argTypes[i] = origTypes[i]
		}

		ectx := compute.GetExecCtx(ctx)
		fn, ok := ectx.Registry.GetFunction(fname)
		if !ok {
			return nil, fmt.Errorf("%w: function %q not found in registry", arrow.ErrInvalid, fname)
		}

		if fn.Kind() != compute.FuncScalar {
			return nil, fmt.Errorf("%w: function %q is not a scalar function", arrow.ErrInvalid, fname)
		}

		k, err := fn.DispatchBest(argTypes...)
		if err != nil {
			return nil, err
		}

		// Cast arguments whose type differs from the one selected by dispatch.
		// args is copied lazily so the slice returned by conv is not mutated.
		var newArgs []compute.Datum
		for i, arg := range args {
			if arrow.TypeEqual(argTypes[i], origTypes[i]) {
				continue
			}
			if newArgs == nil {
				newArgs = make([]compute.Datum, len(args))
				copy(newArgs, args)
			}
			newArgs[i], err = compute.CastDatum(ctx, arg, compute.SafeCastOptions(argTypes[i]))
			if err != nil {
				return nil, err
			}
			defer newArgs[i].Release()
		}
		if newArgs != nil {
			args = newArgs
		}

		kctx := &exec.KernelCtx{Ctx: ctx, Kernel: k}
		kinitArgs := exec.KernelInitArgs{Kernel: k, Inputs: argTypes, Options: opts}
		if initFn := k.GetInitFn(); initFn != nil {
			kctx.State, err = initFn(kctx, kinitArgs)
			if err != nil {
				return nil, err
			}
		}

		executor := compute.NewScalarExecutor()
		if err := executor.Init(kctx, kinitArgs); err != nil {
			return nil, err
		}

		batch := compute.ExecBatch{Values: args}
		if allScalar {
			batch.Len = 1
		} else {
			batch.Len = input.Len
		}

		// NOTE(review): the execution context is derived from
		// context.Background(), so cancellation of the caller's ctx does not
		// reach the executor. Kept as in the original; confirm whether that
		// is deliberate before deriving it from ctx.
		execCtx, cancel := context.WithCancel(context.Background())
		defer cancel()

		ch := make(chan compute.Datum, ectx.ExecChannelSize)
		go func() {
			// err is written before the channel is closed or the context is
			// cancelled, and only read below after WrapResults has returned.
			defer close(ch)
			if err = executor.Execute(execCtx, &batch, ch); err != nil {
				cancel()
			}
		}()

		result := executor.WrapResults(execCtx, ch, false)
		if err == nil {
			debug.Assert(executor.CheckResultType(result) == nil, "invalid result type")
		}

		// Execution failed part-way: drop whatever partial result was gathered.
		if execCtx.Err() == context.Canceled && result != nil {
			result.Release()
			result = nil
		}

		return result, err
	}

	return nil, arrow.ErrNotImplemented
}