in arrow/compute/internal/kernels/base_arithmetic.go [146:376]
// getGoArithmeticOpIntegral returns a pure-Go kernel that applies op to
// integer inputs of type InT and produces outputs of type OutT.
//
// Unchecked operations wrap on overflow using Go's two's-complement
// semantics. Checked operations report errOverflow, errDivByZero or
// errNegativePower through the kernel error instead. Every binary operation
// that can report an error is built with ScalarBinaryNotNull, which is
// intended to evaluate only valid slots, so the unspecified values stored
// under nulls should not be able to produce a spurious error.
//
// Ops without an integral implementation here (including OpNegateChecked for
// unsigned InT) trip a debug assertion and yield a nil kernel.
func getGoArithmeticOpIntegral[InT, OutT arrow.UintType | arrow.IntType](op ArithmeticOp) exec.ArrayKernelExec {
	// An all-ones bit pattern is negative only for signed types.
	signed := ^InT(0) < 0

	// absUnsigned builds the absolute value kernel for unsigned inputs, which
	// is the identity: a raw byte copy when InT and OutT have the same width,
	// otherwise an element-wise static cast.
	absUnsigned := func() exec.ArrayKernelExec {
		if SizeOf[InT]() == SizeOf[OutT]() {
			return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error {
				copy(arrow.GetBytes(out), arrow.GetBytes(arg))
				return nil
			})
		}
		return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error {
			DoStaticCast(arg, out)
			return nil
		})
	}

	switch op {
	case OpAdd:
		return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(a + b) }))
	case OpSub:
		return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(a - b) }))
	case OpMul:
		return ScalarBinary(getGoArithmeticBinary(func(a, b InT, _ *error) OutT { return OutT(a * b) }))
	case OpDiv:
		// Even unchecked division must reject a zero divisor (Go would panic).
		// MinInt / -1 wraps back to MinInt, as the Go spec defines.
		return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) OutT {
			if b == 0 {
				*e = errDivByZero
				return 0
			}
			return OutT(a / b)
		})
	case OpAbsoluteValue:
		if !signed {
			return absUnsigned()
		}
		shiftBy := (SizeOf[InT]() * 8) - 1
		return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error {
			for i, v := range arg {
				// Branchless abs: the arithmetic shift yields 0 for v >= 0 and
				// all ones (-1) for v < 0, so (v + mask) ^ mask == -v for
				// negatives. MinInt wraps back to itself.
				mask := v >> shiftBy
				out[i] = OutT((v + mask) ^ mask)
			}
			return nil
		})
	case OpNegate:
		return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error {
			for i, v := range arg {
				out[i] = OutT(-v)
			}
			return nil
		})
	case OpSign:
		if signed {
			// Convert -1 through a non-constant so the expression also
			// type-checks for unsigned OutT instantiations.
			var minusOne int8 = -1
			neg := OutT(minusOne)
			return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error {
				for i, v := range arg {
					switch {
					case v > 0:
						out[i] = 1
					case v < 0:
						out[i] = neg
					default:
						out[i] = 0
					}
				}
				return nil
			})
		}
		return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error {
			for i, v := range arg {
				if v > 0 {
					out[i] = 1
				} else {
					out[i] = 0
				}
			}
			return nil
		})
	case OpPower:
		return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) OutT {
			if b < 0 {
				*e = errNegativePower
				return 0
			}
			// Right-to-left binary exponentiation in uint64. The products wrap
			// modulo 2^64, so their low bits equal those of a wrapping InT
			// computation.
			base, exp, pow := uint64(a), uint64(b), uint64(1)
			for exp != 0 {
				if exp&1 != 0 {
					pow *= base
				}
				base *= base
				exp >>= 1
			}
			return OutT(pow)
		})
	case OpAddChecked:
		if signed {
			return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) OutT {
				sum := a + b
				// Signed addition overflows iff both operands share a sign
				// and the wrapped sum has the opposite sign. This catches
				// both positive (MaxInt + 1) and negative (MinInt + -1)
				// overflow.
				if (a^sum)&(b^sum) < 0 {
					*e = errOverflow
				}
				return OutT(sum)
			})
		}
		return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) OutT {
			sum := a + b
			// An unsigned sum wrapped iff it ended up smaller than an operand.
			if sum < a {
				*e = errOverflow
			}
			return OutT(sum)
		})
	case OpSubChecked:
		if signed {
			return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) OutT {
				diff := a - b
				// Signed subtraction overflows iff the operands have
				// different signs and the wrapped difference's sign differs
				// from a's. This catches both MaxInt - -1 and MinInt - 1.
				if (a^b)&(a^diff) < 0 {
					*e = errOverflow
				}
				return OutT(diff)
			})
		}
		return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) OutT {
			// An unsigned difference borrows iff the subtrahend is larger.
			if b > a {
				*e = errOverflow
			}
			return OutT(a - b)
		})
	case OpMulChecked:
		return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) OutT {
			prod, err := mulWithOverflow(a, b)
			if err != nil {
				*e = err
			}
			return OutT(prod)
		})
	case OpDivChecked:
		minVal, negOne := MinOf[InT](), ^InT(0)
		return ScalarBinaryNotNull(func(_ *exec.KernelCtx, a, b InT, e *error) OutT {
			if b == 0 {
				*e = errDivByZero
				return 0
			}
			// MinInt / -1 is the one signed quotient that is not
			// representable. Go would silently wrap it back to MinInt.
			if signed && a == minVal && b == negOne {
				*e = errOverflow
				return 0
			}
			return OutT(a / b)
		})
	case OpAbsoluteValueChecked:
		if !signed {
			return absUnsigned()
		}
		shiftBy := (SizeOf[InT]() * 8) - 1
		minVal := MinOf[InT]()
		// NOTE(review): ScalarUnary appears to visit every slot, nulls
		// included. If so, a MinInt stored under a null raises errOverflow.
		// Confirm against ScalarUnary.
		return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error {
			for i, v := range arg {
				// |MinInt| is not representable.
				if v == minVal {
					return errOverflow
				}
				// Branchless abs, as in OpAbsoluteValue.
				mask := v >> shiftBy
				out[i] = OutT((v + mask) ^ mask)
			}
			return nil
		})
	case OpNegateChecked:
		// Only signed types get a checked negation. Unsigned inputs fall
		// through to the invalid-op assertion below.
		if signed {
			minVal := MinOf[InT]()
			return ScalarUnary(func(_ *exec.KernelCtx, arg []InT, out []OutT) error {
				for i, v := range arg {
					// -MinInt is not representable.
					if v == minVal {
						return errOverflow
					}
					out[i] = OutT(-v)
				}
				return nil
			})
		}
	case OpPowerChecked:
		return ScalarBinaryNotNull(func(_ *exec.KernelCtx, base, exp InT, e *error) OutT {
			switch {
			case exp < 0:
				*e = errNegativePower
				return 0
			case exp == 0:
				return 1
			}
			// Left-to-right binary exponentiation:
			//   - walk the bits of exp from the most significant set bit down;
			//   - square at every bit;
			//   - multiply in base where the bit is set.
			// Any overflow along the way is sticky.
			pow := InT(1)
			overflow := false
			for bitmask := uint64(1) << (63 - bits.LeadingZeros64(uint64(exp))); bitmask != 0; bitmask >>= 1 {
				var err error
				if pow, err = mulWithOverflow(pow, pow); err != nil {
					overflow = true
				}
				if uint64(exp)&bitmask != 0 {
					if pow, err = mulWithOverflow(pow, base); err != nil {
						overflow = true
					}
				}
			}
			if overflow {
				*e = errOverflow
			}
			return OutT(pow)
		})
	}
	debug.Assert(false, "invalid arithmetic op")
	return nil
}