func executeScalarBatch()

in arrow/compute/exprs/exec.go [485:643]


// executeScalarBatch evaluates the scalar expression exp against input and
// returns the resulting datum.
//
// The expression kinds handled are literals, field references, casts and
// scalar functions. Cast inputs and function arguments are evaluated by
// recursive calls to executeScalarBatch. ext is used to resolve Substrait
// types and function references.
//
// Errors:
//   - arrow.ErrInvalid (wrapped) when exp is not scalar, a cast has no input
//     or an unspecified/unknown failure behavior, a function argument is not
//     array-like, or the resolved function is missing from the registry or
//     is not a scalar function.
//   - arrow.ErrNotImplemented (wrapped) for the "return nil" cast behavior,
//     function arguments that are neither enums nor expressions, function
//     references that ext cannot decode, and any other expression kind.
//   - any error produced by the helpers and kernels invoked along the way.
//
// The datum returned on success is not released here.
func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
	if !exp.IsScalar() {
		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
			arrow.ErrInvalid)
	}

	switch e := exp.(type) {
	case expr.Literal:
		return literalToDatum(compute.GetAllocator(ctx), e, ext)
	case *expr.FieldReference:
		return execFieldRef(ctx, e, input, ext)
	case *expr.Cast:
		if e.Input == nil {
			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
		}

		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
		if err != nil {
			return nil, err
		}
		defer arg.Release()

		dt, _, err := FromSubstraitType(e.Type, ext)
		if err != nil {
			return nil, fmt.Errorf("%w: could not determine type for cast", err)
		}

		var opts *compute.CastOptions
		switch e.FailureBehavior {
		case types.BehaviorThrowException:
			// NOTE(review): the "throw exception" behavior is mapped to the
			// unsafe cast options; confirm this is intended rather than
			// compute.SafeCastOptions.
			opts = compute.UnsafeCastOptions(dt)
		case types.BehaviorUnspecified:
			return nil, fmt.Errorf("%w: cast behavior unspecified", arrow.ErrInvalid)
		case types.BehaviorReturnNil:
			return nil, fmt.Errorf("%w: cast behavior return nil", arrow.ErrNotImplemented)
		default:
			// Without this branch an unrecognized value would leave opts nil
			// and hand a nil options pointer to CastDatum.
			return nil, fmt.Errorf("%w: unsupported cast failure behavior %v",
				arrow.ErrInvalid, e.FailureBehavior)
		}
		return compute.CastDatum(ctx, arg, opts)
	case *expr.ScalarFunction:
		var (
			err       error
			allScalar = true
			args      = make([]compute.Datum, e.NArgs())
		)
		for i := 0; i < e.NArgs(); i++ {
			switch v := e.Arg(i).(type) {
			case types.Enum:
				// Enum arguments are passed along as string scalars.
				args[i] = compute.NewDatum(scalar.NewStringScalar(string(v)))
			case expr.Expression:
				args[i], err = executeScalarBatch(ctx, input, v, ext)
				if err != nil {
					return nil, err
				}
				// Deferred (not released per iteration) because the evaluated
				// arguments must stay alive until the function has executed.
				defer args[i].Release()

				if args[i].Kind() != compute.KindScalar {
					allScalar = false
				}
			default:
				return nil, fmt.Errorf("%w: unsupported argument %d of type %T for function %s",
					arrow.ErrNotImplemented, i, v, e.Name())
			}
		}

		_, conv, ok := ext.DecodeFunction(e.FuncRef())
		if !ok {
			return nil, fmt.Errorf("%w: %s", arrow.ErrNotImplemented, e.Name())
		}

		// The converter yields the compute function name, the (possibly
		// rewritten) argument list and the function options.
		fname, args, opts, err := conv(e, args)
		if err != nil {
			return nil, err
		}

		argTypes := make([]arrow.DataType, len(args))
		for i, arg := range args {
			typed, isArrayLike := arg.(compute.ArrayLikeDatum)
			if !isArrayLike {
				// Checked assertion: a nil or non-array-like datum is reported
				// as an error instead of panicking.
				return nil, fmt.Errorf("%w: argument %d of function %q is not array-like",
					arrow.ErrInvalid, i, fname)
			}
			argTypes[i] = typed.Type()
		}

		ectx := compute.GetExecCtx(ctx)
		fn, ok := ectx.Registry.GetFunction(fname)
		if !ok {
			return nil, fmt.Errorf("%w: function %q not found in registry", arrow.ErrInvalid, fname)
		}

		if fn.Kind() != compute.FuncScalar {
			return nil, fmt.Errorf("%w: function %q is not a scalar function", arrow.ErrInvalid, fname)
		}

		// The comparison in the loop below relies on DispatchBest having
		// updated argTypes in place with the types the selected kernel expects.
		k, err := fn.DispatchBest(argTypes...)
		if err != nil {
			return nil, err
		}

		var newArgs []compute.Datum
		// cast arguments if necessary
		for i, arg := range args {
			// The assertion cannot fail: every arg was validated above.
			if !arrow.TypeEqual(argTypes[i], arg.(compute.ArrayLikeDatum).Type()) {
				if newArgs == nil {
					// Copy lazily so the common no-cast path allocates nothing.
					newArgs = make([]compute.Datum, len(args))
					copy(newArgs, args)
				}
				newArgs[i], err = compute.CastDatum(ctx, arg, compute.SafeCastOptions(argTypes[i]))
				if err != nil {
					return nil, err
				}
				defer newArgs[i].Release()
			}
		}
		if newArgs != nil {
			args = newArgs
		}

		kctx := &exec.KernelCtx{Ctx: ctx, Kernel: k}
		kinitArgs := exec.KernelInitArgs{Kernel: k, Inputs: argTypes, Options: opts}
		if initFn := k.GetInitFn(); initFn != nil {
			kctx.State, err = initFn(kctx, kinitArgs)
			if err != nil {
				return nil, err
			}
		}

		executor := compute.NewScalarExecutor()
		if err = executor.Init(kctx, kinitArgs); err != nil {
			return nil, err
		}

		// When every expression argument evaluated to a scalar the batch has
		// length 1; otherwise it takes the length of the input batch.
		batch := compute.ExecBatch{Values: args}
		if allScalar {
			batch.Len = 1
		} else {
			batch.Len = input.Len
		}

		// Execution runs under its own context derived from
		// context.Background(), not from the caller's ctx, so it is only
		// cancelled by a failing Execute below or by the deferred cancel.
		execCtx, cancel := context.WithCancel(context.Background())
		defer cancel()

		ch := make(chan compute.Datum, ectx.ExecChannelSize)
		go func() {
			defer close(ch)
			// err is assigned before cancel() and before the channel is
			// closed, and it is only read below after WrapResults returns.
			if err = executor.Execute(execCtx, &batch, ch); err != nil {
				cancel()
			}
		}()

		result := executor.WrapResults(execCtx, ch, false)
		if err == nil {
			debug.Assert(executor.CheckResultType(result) == nil, "invalid result type")
		}

		// A cancelled execution context at this point means Execute failed;
		// drop whatever partial result was collected.
		if execCtx.Err() == context.Canceled && result != nil {
			result.Release()
			result = nil
		}

		return result, err
	}

	return nil, arrow.ErrNotImplemented
}