arrow/compute/internal/kernels/base_arithmetic.go (826 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //go:build go1.18 package kernels import ( "fmt" "math" "math/bits" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/compute/exec" "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/decimal256" "github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/internal/utils" "golang.org/x/exp/constraints" ) type ArithmeticOp int8 const ( OpAdd ArithmeticOp = iota OpSub OpMul OpDiv OpAbsoluteValue OpNegate // NO SIMD for the following yet OpSqrt OpPower OpSin OpCos OpTan OpAsin OpAcos OpAtan OpAtan2 OpLn OpLog10 OpLog2 OpLog1p OpLogb // End NO SIMD OpSign // Checked versions will not use SIMD except for float32/float64 impls OpAddChecked OpSubChecked OpMulChecked OpDivChecked OpAbsoluteValueChecked OpNegateChecked // No SIMD impls for the rest of these yet OpSqrtChecked OpPowerChecked OpSinChecked OpCosChecked OpTanChecked OpAsinChecked OpAcosChecked OpLnChecked OpLog10Checked OpLog2Checked OpLog1pChecked OpLogbChecked ) func mulWithOverflow[T arrow.IntType | arrow.UintType](a, b T) (T, error) { min, max := MinOf[T](), MaxOf[T]() switch { case a > 0: if b > 0 { if a > (max / b) { return 0, errOverflow } } else { if b < (min / a) { return 0, errOverflow } } case b > 0: if a < (min / b) { return 0, errOverflow } default: if (a != 0) && (b < (max / a)) { return 0, errOverflow } } return a * b, nil } func getGoArithmeticBinary[OutT, Arg0T, Arg1T arrow.NumericType](op func(a Arg0T, b Arg1T, e *error) OutT) binaryOps[OutT, Arg0T, Arg1T] { return binaryOps[OutT, Arg0T, Arg1T]{ arrArr: func(_ *exec.KernelCtx, left []Arg0T, right []Arg1T, out []OutT) error { var err error for i := range out { out[i] = op(left[i], right[i], &err) } return err }, arrScalar: func(_ *exec.KernelCtx, left []Arg0T, right Arg1T, out []OutT) error { var err error for i := range out { out[i] = op(left[i], right, &err) } return err }, scalarArr: func(_ *exec.KernelCtx, left Arg0T, right []Arg1T, out []OutT) error { var err error for i := range out { out[i] = op(left, right[i], &err) } return err }, } } var ( errOverflow = fmt.Errorf("%w: overflow", arrow.ErrInvalid) errDivByZero = fmt.Errorf("%w: divide by zero", arrow.ErrInvalid) errNegativeSqrt = fmt.Errorf("%w: square root of negative number", arrow.ErrInvalid) errNegativePower = fmt.Errorf("%w: integers to negative integer powers are not allowed", arrow.ErrInvalid) errDomainErr = fmt.Errorf("%w: domain error", arrow.ErrInvalid) errLogZero = fmt.Errorf("%w: logarithm of zero", arrow.ErrInvalid) errLogNeg = fmt.Errorf("%w: logarithm of negative number", arrow.ErrInvalid) ) func getGoArithmeticOpIntegral[InT, OutT arrow.UintType | arrow.IntType](op ArithmeticOp) exec.ArrayKernelExec { switch op { case OpAdd: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(a + b) })) case OpSub: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(a - b) })) case OpMul: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(a * b) })) case OpDiv: return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) OutT { if b == 0 { *e = errDivByZero return 0 } return OutT(a / b) }) case OpAbsoluteValue: if ones := ^InT(0); ones < 0 { shiftBy := (SizeOf[InT]() * 8) - 1 return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { // get abs without branching for i, v := range arg { // right shift (sign check) mask := v >> shiftBy // add the mask '+' and '-' balance v = v + mask // invert and return out[i] = OutT(v ^ mask) } return nil }) } if SizeOf[InT]() == SizeOf[OutT]() { return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { in, output := arrow.GetBytes(arg), arrow.GetBytes(out) copy(output, in) return nil }) } else { return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { DoStaticCast(arg, out) return nil }) } case OpNegate: return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { for i, v := range arg { out[i] = OutT(-v) } return nil }) case OpSign: if ^InT(0) < 0 { var neg int8 = -1 return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { neg := OutT(neg) for i, v := range arg { switch { case v > 0: out[i] = 1 case v < 0: out[i] = neg default: out[i] = 0 } } return nil }) } return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { for i, v := range arg { if v > 0 { out[i] = 1 } else { out[i] = 0 } } return nil }) case OpPower: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, err *error) OutT { if b < 0 { *err = errNegativePower return 0 } // integer power var ( base = uint64(a) exp = uint64(b) pow uint64 = 1 ) // right to left 0(logn) power for exp != 0 { if exp&1 != 0 { pow *= base } base *= base exp >>= 1 } return OutT(pow) })) case OpAddChecked: shiftBy := (SizeOf[InT]() * 8) - 1 // ie: uint32 does a >> 31 at the end, int32 does >> 30 if ^InT(0) < 0 { shiftBy-- } return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) (out OutT) { out = OutT(a + b) // see math/bits/bits.go Add64 for explanation of logic carry := (OutT(a&b) | (OutT(a|b) &^ out)) >> shiftBy if carry > 0 { *e = errOverflow } return }) case OpSubChecked: shiftBy := (SizeOf[InT]() * 8) - 1 // ie: uint32 does a >> 31 at the end, int32 does >> 30 if ^InT(0) < 0 { shiftBy-- } return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) (out OutT) { out = OutT(a - b) // see math/bits/bits.go Sub64 for explanation of bit logic carry := (OutT(^a&b) | (^OutT(a^b) & out)) >> shiftBy if carry > 0 { *e = errOverflow } return }) case OpMulChecked: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, e *error) (out OutT) { o, err := mulWithOverflow(a, b) if err != nil { *e = err } return OutT(o) })) case OpDivChecked: return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) (out OutT) { if b == 0 { *e = errDivByZero return } return OutT(a / b) }) case OpAbsoluteValueChecked: if ones := ^InT(0); ones < 0 { shiftBy := (SizeOf[InT]() * 8) - 1 min := MinOf[InT]() return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { for i, v := range arg { if v == min { return errOverflow } // right shift (sign check) mask := v >> shiftBy // add the mask '+' and '-' balance v = v + mask // invert and return out[i] = OutT(v ^ mask) } return nil }) } if SizeOf[InT]() == SizeOf[OutT]() { return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { in, output := arrow.GetBytes(arg), arrow.GetBytes(out) copy(output, in) return nil }) } else { return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { DoStaticCast(arg, out) return nil }) } case OpNegateChecked: if ones := ^InT(0); ones < 0 { min := MinOf[InT]() // signed return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { for i, v := range arg { if v != min { out[i] = OutT(-v) } else { return errOverflow } } return nil }) } case OpPowerChecked: return ScalarBinaryNotNull(func(_ *exec.KernelCtx, base, exp InT, e *error) OutT { if exp < 0 { *e = errNegativePower return 0 } else if exp == 0 { return 1 } // left to right 0(logn) power with overflow checks var ( overflow bool bitmask = uint64(1) << (63 - bits.LeadingZeros64(uint64(exp))) pow InT = 1 err error ) for bitmask != 0 { pow, err = mulWithOverflow(pow, pow) overflow = overflow || (err != nil) if uint64(exp)&bitmask != 0 { pow, err = mulWithOverflow(pow, base) overflow = overflow || (err != nil) } bitmask >>= 1 } if overflow { *e = errOverflow } return OutT(pow) }) } debug.Assert(false, "invalid arithmetic op") return nil } func getGoArithmeticOpFloating[InT, OutT constraints.Float](op ArithmeticOp) exec.ArrayKernelExec { switch op { case OpAdd, OpAddChecked: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(a + b) })) case OpSub, OpSubChecked: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(a - b) })) case OpMul, OpMulChecked: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(a * b) })) case OpDiv: return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) (out OutT) { return OutT(a / b) }) case OpDivChecked: return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) (out OutT) { if b == 0 { *e = errDivByZero return } return OutT(a / b) }) case OpAbsoluteValue, OpAbsoluteValueChecked: return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { for i, v := range arg { out[i] = OutT(math.Abs(float64(v))) } return nil }) case OpNegate, OpNegateChecked: return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { for i, v := range arg { out[i] = OutT(-v) } return nil }) case OpSqrt: return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { for i, v := range arg { out[i] = OutT(math.Sqrt(float64(v))) } return nil }) case OpSqrtChecked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { if arg < 0 { *e = errNegativeSqrt return OutT(math.NaN()) } return OutT(math.Sqrt(float64(arg))) }) case OpSign: return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error { for i, v := range arg { switch { case math.IsNaN(float64(v)): out[i] = OutT(v) case v == 0: out[i] = 0 case math.Signbit(float64(v)): out[i] = -1 default: out[i] = 1 } } return nil }) case OpPower, OpPowerChecked: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(math.Pow(float64(a), float64(b))) })) case OpSin: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Sin(float64(v))) } return nil }) case OpSinChecked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { if math.IsInf(float64(arg), 0) { *e = errDomainErr return OutT(arg) } return OutT(math.Sin(float64(arg))) }) case OpCos: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Cos(float64(v))) } return nil }) case OpCosChecked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { if math.IsInf(float64(arg), 0) { *e = errDomainErr return OutT(arg) } return OutT(math.Cos(float64(arg))) }) case OpTan: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Tan(float64(v))) } return nil }) case OpTanChecked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { if math.IsInf(float64(arg), 0) { *e = errDomainErr return OutT(arg) } return OutT(math.Tan(float64(arg))) }) case OpAsin: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Asin(float64(v))) } return nil }) case OpAsinChecked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { if arg < -1 || arg > 1 { *e = errDomainErr return OutT(arg) } return OutT(math.Asin(float64(arg))) }) case OpAcos: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Acos(float64(v))) } return nil }) case OpAcosChecked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { if arg < -1 || arg > 1 { *e = errDomainErr return OutT(arg) } return OutT(math.Acos(float64(arg))) }) case OpAtan: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Atan(float64(v))) } return nil }) case OpAtan2: return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(math.Atan2(float64(a), float64(b))) })) case OpLn: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Log(float64(v))) } return nil }) case OpLnChecked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { switch { case arg == 0: *e = errLogZero return OutT(arg) case arg < 0: *e = errLogNeg return OutT(arg) } return OutT(math.Log(float64(arg))) }) case OpLog10: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Log10(float64(v))) } return nil }) case OpLog10Checked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { switch { case arg == 0: *e = errLogZero return OutT(arg) case arg < 0: *e = errLogNeg return OutT(arg) } return OutT(math.Log10(float64(arg))) }) case OpLog2: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Log2(float64(v))) } return nil }) case OpLog2Checked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { switch { case arg == 0: *e = errLogZero return OutT(arg) case arg < 0: *e = errLogNeg return OutT(arg) } return OutT(math.Log2(float64(arg))) }) case OpLog1p: return ScalarUnary(func(_ *exec.KernelCtx, vals []InT, out []OutT) error { for i, v := range vals { out[i] = OutT(math.Log1p(float64(v))) } return nil }) case OpLog1pChecked: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg InT, e *error) OutT { switch { case arg == -1: *e = errLogZero return OutT(arg) case arg < -1: *e = errLogNeg return OutT(arg) } return OutT(math.Log1p(float64(arg))) }) case OpLogb: return ScalarBinary(getGoArithmeticBinary(func(x, base InT, _ *error) OutT { if x == 0 { if base == 0 || base < 0 { return OutT(math.NaN()) } else { return OutT(math.Inf(-1)) } } else if x < 0 { return OutT(math.NaN()) } return OutT(math.Log(float64(x)) / math.Log(float64(base))) })) case OpLogbChecked: return ScalarBinaryNotNull((func(_ *exec.KernelCtx, x, base InT, e *error) OutT { if x == 0 || base == 0 { *e = errLogZero return OutT(x) } else if x < 0 || base < 0 { *e = errLogNeg return OutT(x) } return OutT(math.Log(float64(x)) / math.Log(float64(base))) })) } debug.Assert(false, "invalid arithmetic op") return nil } func timeDurationOp[OutT, Arg0T, Arg1T ~int32 | ~int64](multiple int64, op ArithmeticOp) exec.ArrayKernelExec { switch op { case OpAdd: return ScalarBinary(getGoArithmeticBinary(func(a Arg0T, b Arg1T, e *error) OutT { result := OutT(a) + OutT(b) if result < 0 || multiple <= int64(result) { *e = fmt.Errorf("%w: %d is not within acceptable range of [0, %d) s", arrow.ErrInvalid, result, multiple) } return result })) case OpSub: return ScalarBinary(getGoArithmeticBinary(func(a Arg0T, b Arg1T, e *error) OutT { result := OutT(a) - OutT(b) if result < 0 || multiple <= int64(result) { *e = fmt.Errorf("%w: %d is not within acceptable range of [0, %d) s", arrow.ErrInvalid, result, multiple) } return result })) case OpAddChecked: shiftBy := (SizeOf[OutT]() * 8) - 1 // ie: uint32 does a >> 31 at the end, int32 does >> 30 if ^OutT(0) < 0 { shiftBy-- } return ScalarBinary(getGoArithmeticBinary(func(a Arg0T, b Arg1T, e *error) (result OutT) { left, right := OutT(a), OutT(b) result = left + right carry := ((left & right) | ((left | right) &^ result)) >> shiftBy if carry > 0 { *e = errOverflow return } if result < 0 || multiple <= int64(result) { *e = fmt.Errorf("%w: %d is not within acceptable range of [0, %d) s", arrow.ErrInvalid, result, multiple) } return })) case OpSubChecked: shiftBy := (SizeOf[OutT]() * 8) - 1 // ie: uint32 does a >> 31 at the end, int32 does >> 30 if ^OutT(0) < 0 { shiftBy-- } return ScalarBinary(getGoArithmeticBinary(func(a Arg0T, b Arg1T, e *error) (result OutT) { left, right := OutT(a), OutT(b) result = left - right carry := ((^left & right) | (^(left ^ right) & result)) >> shiftBy if carry > 0 { *e = errOverflow return } if result < 0 || multiple <= int64(result) { *e = fmt.Errorf("%w: %d is not within acceptable range of [0, %d) s", arrow.ErrInvalid, result, multiple) } return })) } return nil } func SubtractDate32(op ArithmeticOp) exec.ArrayKernelExec { const secondsPerDay = 86400 switch op { case OpSub: return ScalarBinary(getGoArithmeticBinary(func(a, b arrow.Time32, e *error) (result arrow.Duration) { return arrow.Duration((a - b) * secondsPerDay) })) case OpSubChecked: return ScalarBinary(getGoArithmeticBinary(func(a, b arrow.Time32, e *error) (result arrow.Duration) { result = arrow.Duration(a) - arrow.Duration(b) val, ok := utils.Mul64(int64(result), secondsPerDay) if !ok { *e = errOverflow } return arrow.Duration(val) })) } panic("invalid op for subtractDate32") } type decOps[T decimal128.Num | decimal256.Num] struct { Add func(T, T) T Sub func(T, T) T Div func(T, T) T Mul func(T, T) T Abs func(T) T Neg func(T) T Sign func(T) int } var dec128Ops = decOps[decimal128.Num]{ Add: func(a, b decimal128.Num) decimal128.Num { return a.Add(b) }, Sub: func(a, b decimal128.Num) decimal128.Num { return a.Sub(b) }, Mul: func(a, b decimal128.Num) decimal128.Num { return a.Mul(b) }, Div: func(a, b decimal128.Num) decimal128.Num { a, _ = a.Div(b) return a }, Abs: func(a decimal128.Num) decimal128.Num { return a.Abs() }, Neg: func(a decimal128.Num) decimal128.Num { return a.Negate() }, Sign: func(a decimal128.Num) int { return a.Sign() }, } var dec256Ops = decOps[decimal256.Num]{ Add: func(a, b decimal256.Num) decimal256.Num { return a.Add(b) }, Sub: func(a, b decimal256.Num) decimal256.Num { return a.Sub(b) }, Mul: func(a, b decimal256.Num) decimal256.Num { return a.Mul(b) }, Div: func(a, b decimal256.Num) decimal256.Num { a, _ = a.Div(b) return a }, Abs: func(a decimal256.Num) decimal256.Num { return a.Abs() }, Neg: func(a decimal256.Num) decimal256.Num { return a.Negate() }, Sign: func(a decimal256.Num) int { return a.Sign() }, } func getArithmeticOpDecimalImpl[T decimal128.Num | decimal256.Num](op ArithmeticOp, fns decOps[T]) exec.ArrayKernelExec { if op >= OpAddChecked { op -= OpAddChecked // decimal128/256 checked is the same as unchecked } switch op { case OpAdd: return ScalarBinaryNotNull(func(_ *exec.KernelCtx, arg0, arg1 T, _ *error) T { return fns.Add(arg0, arg1) }) case OpSub: return ScalarBinaryNotNull(func(_ *exec.KernelCtx, arg0, arg1 T, _ *error) T { return fns.Sub(arg0, arg1) }) case OpMul: return ScalarBinaryNotNull(func(_ *exec.KernelCtx, arg0, arg1 T, _ *error) T { return fns.Mul(arg0, arg1) }) case OpDiv: var z T return ScalarBinaryNotNull(func(_ *exec.KernelCtx, arg0, arg1 T, e *error) (out T) { if arg1 == z { *e = errDivByZero return } return fns.Div(arg0, arg1) }) case OpAbsoluteValue: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg T, _ *error) T { return fns.Abs(arg) }) case OpNegate: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg T, _ *error) T { return fns.Neg(arg) }) case OpSign: return ScalarUnaryNotNull(func(_ *exec.KernelCtx, arg T, _ *error) int64 { return int64(fns.Sign(arg)) }) } debug.Assert(false, "unimplemented arithmetic op") return nil } func getArithmeticDecimal[T decimal128.Num | decimal256.Num](op ArithmeticOp) exec.ArrayKernelExec { var def T switch any(def).(type) { case decimal128.Num: return getArithmeticOpDecimalImpl(op, dec128Ops) case decimal256.Num: return getArithmeticOpDecimalImpl(op, dec256Ops) } panic("should never get here") } func ArithmeticExecSameType(ty arrow.Type, op ArithmeticOp) exec.ArrayKernelExec { switch ty { case arrow.INT8: return getArithmeticOpIntegral[int8, int8](op) case arrow.UINT8: return getArithmeticOpIntegral[uint8, uint8](op) case arrow.INT16: return getArithmeticOpIntegral[int16, int16](op) case arrow.UINT16: return getArithmeticOpIntegral[uint16, uint16](op) case arrow.INT32, arrow.TIME32: return getArithmeticOpIntegral[int32, int32](op) case arrow.UINT32: return getArithmeticOpIntegral[uint32, uint32](op) case arrow.INT64, arrow.TIME64, arrow.DATE64, arrow.TIMESTAMP, arrow.DURATION: return getArithmeticOpIntegral[int64, int64](op) case arrow.UINT64: return getArithmeticOpIntegral[uint64, uint64](op) case arrow.FLOAT32: return getArithmeticOpFloating[float32, float32](op) case arrow.FLOAT64: return getArithmeticOpFloating[float64, float64](op) } debug.Assert(false, "invalid arithmetic type") return nil } func arithmeticExec[InT arrow.IntType | arrow.UintType](oty arrow.Type, op ArithmeticOp) exec.ArrayKernelExec { switch oty { case arrow.INT8: return getArithmeticOpIntegral[InT, int8](op) case arrow.UINT8: return getArithmeticOpIntegral[InT, uint8](op) case arrow.INT16: return getArithmeticOpIntegral[InT, int16](op) case arrow.UINT16: return getArithmeticOpIntegral[InT, uint16](op) case arrow.INT32, arrow.TIME32: return getArithmeticOpIntegral[InT, int32](op) case arrow.UINT32: return getArithmeticOpIntegral[InT, uint32](op) case arrow.INT64, arrow.TIME64, arrow.DATE64, arrow.TIMESTAMP, arrow.DURATION: return getArithmeticOpIntegral[InT, int64](op) case arrow.UINT64: return getArithmeticOpIntegral[InT, uint64](op) } debug.Assert(false, "arithmetic integral to floating not implemented") return nil } func ArithmeticExec(ity, oty arrow.Type, op ArithmeticOp) exec.ArrayKernelExec { if ity == oty { return ArithmeticExecSameType(ity, op) } switch ity { case arrow.INT8: return arithmeticExec[int8](oty, op) case arrow.UINT8: return arithmeticExec[uint8](oty, op) case arrow.INT16: return arithmeticExec[int16](oty, op) case arrow.UINT16: return arithmeticExec[uint16](oty, op) case arrow.INT32, arrow.TIME32: return arithmeticExec[int32](oty, op) case arrow.UINT32: return arithmeticExec[uint32](oty, op) case arrow.INT64, arrow.TIME64, arrow.DATE64, arrow.TIMESTAMP, arrow.DURATION: return arithmeticExec[int64](oty, op) case arrow.UINT64: return arithmeticExec[uint64](oty, op) case arrow.FLOAT32: if oty == arrow.FLOAT32 { return getArithmeticOpFloating[float32, float32](op) } return getArithmeticOpFloating[float32, float64](op) case arrow.FLOAT64: if oty == arrow.FLOAT32 { return getArithmeticOpFloating[float64, float32](op) } return getArithmeticOpFloating[float64, float64](op) } return nil }