arrow/compute/internal/kernels/_lib/base_arithmetic.cc (405 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arch.h>
#include <math.h>
#include <stdint.h>
#include <limits.h>
#include "types.h"
// Operation codes passed in from Go. These must mirror, value for value and
// in the same order, the ArithmeticOp enum in base_arithmetic.go. Every
// exported entry point takes the code as an int8 and switches on it: this
// keeps the implementation simple at the cost of one extra argument and a
// switch per call.
//
// For small arrays that overhead is negligible; for large arrays the time
// saved by running the SIMD-compiled loops is well worth it.
//
// The explicit values below spell out that shared encoding so that an
// accidental insertion or reordering is easy to spot in review.
enum class optype : int8_t {
  ADD = 0,
  SUB = 1,
  MUL = 2,
  DIV = 3,
  ABSOLUTE_VALUE = 4,
  NEGATE = 5,
  SQRT = 6,
  POWER = 7,
  SIN = 8,
  COS = 9,
  TAN = 10,
  ASIN = 11,
  ACOS = 12,
  ATAN = 13,
  ATAN2 = 14,
  LN = 15,
  LOG10 = 16,
  LOG2 = 17,
  LOG1P = 18,
  LOGB = 19,
  SIGN = 20,
  // The *_CHECKED kernels in this file do not perform any overflow checks:
  // overflow must only be checked on non-null entries, so presumably the
  // caller does that separately.
  ADD_CHECKED = 21,
  SUB_CHECKED = 22,
  MUL_CHECKED = 23,
  DIV_CHECKED = 24,
  ABSOLUTE_VALUE_CHECKED = 25,
  NEGATE_CHECKED = 26,
  SQRT_CHECKED = 27,
  POWER_CHECKED = 28,
  SIN_CHECKED = 29,
  COS_CHECKED = 30,
  TAN_CHECKED = 31,
  ASIN_CHECKED = 32,
  ACOS_CHECKED = 33,
  LN_CHECKED = 34,
  LOG10_CHECKED = 35,
  LOG2_CHECKED = 36,
  LOG1P_CHECKED = 37,
  LOGB_CHECKED = 38,
};
// Element-wise addition.
//
// Floating point and unsigned operands use the plain operator. Unsigned
// types either wrap modularly or are promoted to int without overflow.
// Signed integers are added in their unsigned counterpart and converted back,
// so overflow wraps in two's complement instead of being undefined behavior.
// int8/int16 would be safe via promotion to int, but int32/int64 are not.
struct Add {
  template <typename T, typename Arg0, typename Arg1>
  static constexpr T Call(Arg0 left, Arg1 right) {
    if constexpr (is_integral_v<T> && is_signed_v<T>) {
      using U = make_unsigned_t<T>;
      return static_cast<T>(static_cast<U>(left) + static_cast<U>(right));
    } else if constexpr (is_arithmetic_v<T>) {
      return left + right;
    }
  }
};

// Element-wise subtraction; follows the same wrapping rules as Add.
struct Sub {
  template <typename T, typename Arg0, typename Arg1>
  static constexpr T Call(Arg0 left, Arg1 right) {
    if constexpr (is_integral_v<T> && is_signed_v<T>) {
      using U = make_unsigned_t<T>;
      return static_cast<T>(static_cast<U>(left) - static_cast<U>(right));
    } else if constexpr (is_arithmetic_v<T>) {
      return left - right;
    }
  }
};

// "Checked" addition. Overflow is deliberately NOT detected here: overflow
// checks must only run on non-null entries, which this kernel cannot see.
// The result is therefore the same well-defined wrapping value as Add.
struct AddChecked {
  template <typename T, typename Arg0, typename Arg1>
  static constexpr T Call(Arg0 left, Arg1 right) {
    static_assert(is_same_v<T, Arg0> && is_same_v<T, Arg1>, "");
    return Add::Call<T, Arg0, Arg1>(left, right);
  }
};

// "Checked" subtraction; like AddChecked, it computes the wrapping result
// without detecting overflow.
struct SubChecked {
  template <typename T, typename Arg0, typename Arg1>
  static constexpr T Call(Arg0 left, Arg1 right) {
    static_assert(is_same_v<T, Arg0> && is_same_v<T, Arg1>, "");
    return Sub::Call<T, Arg0, Arg1>(left, right);
  }
};
// Unsigned counterpart of T for every integral type except bool; any other
// type (bool, floating point) is passed through unchanged.
template <typename T>
using maybe_make_unsigned =
    conditional_t<!is_same_v<T, bool> && is_integral_v<T>, make_unsigned_t<T>, T>;

// Converts a value to its unsigned counterpart (a modular conversion, so
// negative inputs map to their two's complement bit pattern).
template <typename T, typename Unsigned = maybe_make_unsigned<T>>
constexpr Unsigned to_unsigned(T signed_) {
  const Unsigned as_unsigned = static_cast<Unsigned>(signed_);
  return as_unsigned;
}

// Element-wise multiplication that wraps on overflow rather than invoking
// signed-overflow undefined behavior.
//
// The static_asserts pin down the integral promotion rules the branches rely
// on. Operands narrower than 32 bits are promoted to signed int before
// multiplying, while 32/64-bit operands keep their own type.
struct Multiply {
  static_assert(is_same_v<decltype(uint8_t() * uint8_t()), int32_t>, "");
  static_assert(is_same_v<decltype(int8_t() * int8_t()), int32_t>, "");
  static_assert(is_same_v<decltype(uint16_t() * uint16_t()), int32_t>, "");
  static_assert(is_same_v<decltype(int16_t() * int16_t()), int32_t>, "");
  static_assert(is_same_v<decltype(uint32_t() * uint32_t()), uint32_t>, "");
  static_assert(is_same_v<decltype(int32_t() * int32_t()), int32_t>, "");
  static_assert(is_same_v<decltype(uint64_t() * uint64_t()), uint64_t>, "");
  static_assert(is_same_v<decltype(int64_t() * int64_t()), int64_t>, "");

  template <typename T, typename Arg0, typename Arg1>
  static constexpr T Call(Arg0 left, Arg1 right) {
    static_assert(is_same_v<T, Arg0> && is_same_v<T, Arg1>, "");
    if constexpr (is_same_v<T, int16_t> || is_same_v<T, uint16_t>) {
      // 16-bit operands promote to signed int, and e.g. 65535 * 65535 does
      // not fit in int. Widen to uint32_t first so the product wraps
      // modularly instead of overflowing.
      return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
    } else if constexpr (is_floating_point_v<T> || is_unsigned_v<T>) {
      // Floats cannot hit signed-overflow UB; uint8 products always fit in
      // int, and uint32/uint64 arithmetic is modular.
      return left * right;
    } else if constexpr (is_signed_v<T>) {
      // int8/int32/int64: multiply the unsigned bit patterns, then narrow.
      return to_unsigned(left) * to_unsigned(right);
    }
  }
};
// "Checked" multiplication. Overflow is deliberately NOT detected here:
// overflow checks must only run on non-null entries, which this kernel cannot
// see. The result is the wrapped product.
//
// A plain `left * right` is not safe for integers. int32/int64 overflow is
// undefined behavior. uint16/int16 (and 8-bit) operands are promoted to
// signed int, and 65535 * 65535 already overflows int. Integers are therefore
// multiplied in an unsigned type at least as wide as `unsigned int`, where
// wrap-around is well defined, and then narrowed back to T.
struct MultiplyChecked {
  template <typename T, typename Arg0, typename Arg1>
  static constexpr T Call(Arg0 left, Arg1 right) {
    static_assert(is_same_v<T, Arg0> && is_same_v<T, Arg1>, "");
    if constexpr (is_floating_point_v<T>) {
      return left * right;
    } else if constexpr (is_integral_v<T>) {
      // Widen sub-int types to unsigned int so integral promotion cannot
      // turn the operands back into signed int.
      using U = conditional_t<(sizeof(T) < sizeof(unsigned int)), unsigned int,
                              make_unsigned_t<T>>;
      return static_cast<T>(static_cast<U>(left) * static_cast<U>(right));
    }
  }
};
// Element-wise absolute value.
//
// * float / double: clear the IEEE-754 sign bit with the compiler builtin
//   (this also maps -0.0 to +0.0 and strips the sign of NaN). This avoids
//   type-punning through an `int*`, which violates strict aliasing and
//   depends on the word order of a double in memory. The `__builtin_` form is
//   used so that no libm symbol is referenced.
// * unsigned integers: identity.
// * signed integers: branch-free `(x + mask) ^ mask` trick, computed in the
//   unsigned counterpart so the minimum value wraps (|INT_MIN| == INT_MIN)
//   instead of overflowing a signed integer (undefined behavior).
struct AbsoluteValue {
  template <typename T, typename Arg>
  static constexpr T Call(Arg input) {
    if constexpr (is_same_v<Arg, float>) {
      return __builtin_fabsf(input);
    } else if constexpr (is_same_v<Arg, double>) {
      return __builtin_fabs(input);
    } else if constexpr (is_unsigned_v<Arg>) {
      return input;
    } else {
      using U = make_unsigned_t<Arg>;
      // All ones for negative inputs, zero otherwise (arithmetic right shift).
      const U mask = static_cast<U>(input >> (sizeof(Arg) * CHAR_BIT - 1));
      return static_cast<T>((static_cast<U>(input) + mask) ^ mask);
    }
  }
};

// "Checked" absolute value. Overflow (|INT_MIN|) is deliberately NOT detected
// here because overflow checks must only run on non-null entries. It yields
// the same well-defined result as AbsoluteValue.
struct AbsoluteValueChecked {
  template <typename T, typename Arg>
  static constexpr T Call(Arg input) {
    return AbsoluteValue::Call<T, Arg>(input);
  }
};
// Element-wise negation.
//
// * floating point: plain unary minus.
// * unsigned integers: two's complement negation (~x + 1), i.e. 2^N - x.
// * signed integers: negated in the unsigned counterpart and converted back,
//   so that -INT_MIN wraps to INT_MIN instead of overflowing a signed integer
//   (undefined behavior for int32/int64).
struct Negate {
  template <typename T, typename Arg>
  static constexpr T Call(Arg input) {
    if constexpr (is_floating_point_v<Arg>) {
      return -input;
    } else if constexpr (is_unsigned_v<Arg>) {
      return static_cast<T>(~input + 1);
    } else {
      using U = make_unsigned_t<Arg>;
      return static_cast<T>(U{0} - static_cast<U>(input));
    }
  }
};

// "Checked" negation. Overflow (-INT_MIN) is deliberately NOT detected here
// because overflow checks must only run on non-null entries; signed and
// floating point inputs give the same result as Negate.
struct NegateChecked {
  template <typename T, typename Arg>
  static constexpr T Call(Arg input) {
    static_assert(is_same_v<T, Arg>, "");
    if constexpr (is_unsigned_v<Arg>) {
      // Always 0 for unsigned inputs. NOTE(review): presumably checked
      // negation is never requested for unsigned types, and this branch only
      // exists so every element type in the dispatch switch instantiates;
      // confirm against the Go caller.
      return 0;
    } else {
      return Negate::Call<T, Arg>(input);
    }
  }
};
// Element-wise sign: -1, 0 or +1 converted to the output type T.
//
// Floating point: NaN is passed through unchanged; +0.0 and -0.0 both yield
// 0 (a positive zero once converted to T); otherwise the sign bit decides.
// Unsigned integers can only produce 0 or 1.
struct Sign {
  template <typename T, typename Arg>
  static constexpr T Call(Arg input) {
    if constexpr (is_floating_point_v<Arg>) {
      if (isnan(input)) {
        return input;
      }
      if (input == 0) {
        return 0;
      }
      return signbit(input) ? -1 : 1;
    } else if constexpr (is_unsigned_v<Arg>) {
      return (input != 0) ? 1 : 0;
    } else if constexpr (is_signed_v<Arg>) {
      if (input < 0) {
        return -1;
      }
      return (input > 0) ? 1 : 0;
    }
  }
};
// Kernel loop for `array (op) array`: out[i] = Op(left[i], right[i]) for
// i in [0, len). Both inputs are read as T and the output is written as OutT.
// There is no bounds checking, so every buffer must hold at least `len`
// elements. A non-positive `len` does no work.
template <typename T, typename Op, typename OutT = T>
struct arithmetic_op_arr_arr_impl {
  static inline void exec(const void* in_left, const void* in_right, void* out, const int len) {
    const T* const lhs = static_cast<const T*>(in_left);
    const T* const rhs = static_cast<const T*>(in_right);
    OutT* const dst = static_cast<OutT*>(out);
    for (int idx = 0; idx < len; ++idx) {
      dst[idx] = Op::template Call<OutT, T, T>(lhs[idx], rhs[idx]);
    }
  }
};
// Kernel loop for `array (op) scalar`: out[i] = Op(left[i], *scalar_right)
// for i in [0, len). The scalar is loaded once before the loop. `in_left`
// and `out` must hold at least `len` elements; a non-positive `len` does no
// work.
template <typename T, typename Op, typename OutT = T>
struct arithmetic_op_arr_scalar_impl {
  static inline void exec(const void* in_left, const void* scalar_right, void* out, const int len) {
    const T rhs_value = *static_cast<const T*>(scalar_right);
    const T* const lhs = static_cast<const T*>(in_left);
    OutT* const dst = static_cast<OutT*>(out);
    for (int idx = 0; idx < len; ++idx) {
      dst[idx] = Op::template Call<OutT, T, T>(lhs[idx], rhs_value);
    }
  }
};
// Kernel loop for `scalar (op) array`: out[i] = Op(*scalar_left, right[i])
// for i in [0, len). The scalar is loaded once before the loop. `in_right`
// and `out` must hold at least `len` elements; a non-positive `len` does no
// work.
template <typename T, typename Op, typename OutT = T>
struct arithmetic_op_scalar_arr_impl {
  static inline void exec(const void* scalar_left, const void* in_right, void* out, const int len) {
    const T lhs_value = *static_cast<const T*>(scalar_left);
    const T* const rhs = static_cast<const T*>(in_right);
    OutT* const dst = static_cast<OutT*>(out);
    for (int idx = 0; idx < len; ++idx) {
      dst[idx] = Op::template Call<OutT, T, T>(lhs_value, rhs[idx]);
    }
  }
};
// Kernel loop for unary operations: out[i] = Op(arg[i]) for i in [0, len).
// The input is read as T and the output written as OutT, which may differ
// (e.g. for the different-type sign kernel). No bounds checking is done; a
// non-positive `len` does no work.
template <typename T, typename Op, typename OutT = T>
struct arithmetic_unary_op_impl {
  static inline void exec(const void* arg, void* out, const int len) {
    const T* const src = static_cast<const T*>(arg);
    OutT* const dst = static_cast<OutT*>(out);
    for (int idx = 0; idx < len; ++idx) {
      dst[idx] = Op::template Call<OutT, T>(src[idx]);
    }
  }
};
// Binary element-type dispatch. It maps the runtime arrtype code `type` onto
// a concrete element type and runs Impl<elem, Op>::exec over both inputs
// (output type == input type). Unlisted codes fall into `default` and leave
// the output untouched.
template <typename Op, template<typename...> typename Impl>
static inline void arithmetic_op(const int type, const void* in_left, const void* in_right, void* output, const int len) {
  switch (static_cast<arrtype>(type)) {
    case arrtype::UINT8:
      Impl<uint8_t, Op>::exec(in_left, in_right, output, len);
      break;
    case arrtype::INT8:
      Impl<int8_t, Op>::exec(in_left, in_right, output, len);
      break;
    case arrtype::UINT16:
      Impl<uint16_t, Op>::exec(in_left, in_right, output, len);
      break;
    case arrtype::INT16:
      Impl<int16_t, Op>::exec(in_left, in_right, output, len);
      break;
    case arrtype::UINT32:
      Impl<uint32_t, Op>::exec(in_left, in_right, output, len);
      break;
    case arrtype::INT32:
      Impl<int32_t, Op>::exec(in_left, in_right, output, len);
      break;
    case arrtype::UINT64:
      Impl<uint64_t, Op>::exec(in_left, in_right, output, len);
      break;
    case arrtype::INT64:
      Impl<int64_t, Op>::exec(in_left, in_right, output, len);
      break;
    case arrtype::FLOAT32:
      Impl<float, Op>::exec(in_left, in_right, output, len);
      break;
    case arrtype::FLOAT64:
      Impl<double, Op>::exec(in_left, in_right, output, len);
      break;
    default:
      // unsupported element type: nothing is written
      break;
  }
}
// Output-type dispatch for a unary op whose input element type `Input` is
// already fixed at compile time. The runtime code `otype` selects OutT, and
// Impl<Input, Op, OutT>::exec is run. Unlisted codes leave the output
// untouched.
template <typename Op, template <typename...> typename Impl, typename Input>
static inline void arithmetic_op(const int otype, const void* input, void* output, const int len) {
  switch (static_cast<arrtype>(otype)) {
    case arrtype::UINT8:
      Impl<Input, Op, uint8_t>::exec(input, output, len);
      break;
    case arrtype::INT8:
      Impl<Input, Op, int8_t>::exec(input, output, len);
      break;
    case arrtype::UINT16:
      Impl<Input, Op, uint16_t>::exec(input, output, len);
      break;
    case arrtype::INT16:
      Impl<Input, Op, int16_t>::exec(input, output, len);
      break;
    case arrtype::UINT32:
      Impl<Input, Op, uint32_t>::exec(input, output, len);
      break;
    case arrtype::INT32:
      Impl<Input, Op, int32_t>::exec(input, output, len);
      break;
    case arrtype::UINT64:
      Impl<Input, Op, uint64_t>::exec(input, output, len);
      break;
    case arrtype::INT64:
      Impl<Input, Op, int64_t>::exec(input, output, len);
      break;
    case arrtype::FLOAT32:
      Impl<Input, Op, float>::exec(input, output, len);
      break;
    case arrtype::FLOAT64:
      Impl<Input, Op, double>::exec(input, output, len);
      break;
    default:
      // unsupported output type: nothing is written
      break;
  }
}
// Same-type unary dispatch. The runtime code `type` selects one element type
// used for both input and output (Impl's OutT defaults to the input type).
// Unlisted codes leave the output untouched.
template <typename Op, template <typename...> typename Impl>
static inline void arithmetic_op(const int type, const void* input, void* output, const int len) {
  switch (static_cast<arrtype>(type)) {
    case arrtype::UINT8:
      Impl<uint8_t, Op>::exec(input, output, len);
      break;
    case arrtype::INT8:
      Impl<int8_t, Op>::exec(input, output, len);
      break;
    case arrtype::UINT16:
      Impl<uint16_t, Op>::exec(input, output, len);
      break;
    case arrtype::INT16:
      Impl<int16_t, Op>::exec(input, output, len);
      break;
    case arrtype::UINT32:
      Impl<uint32_t, Op>::exec(input, output, len);
      break;
    case arrtype::INT32:
      Impl<int32_t, Op>::exec(input, output, len);
      break;
    case arrtype::UINT64:
      Impl<uint64_t, Op>::exec(input, output, len);
      break;
    case arrtype::INT64:
      Impl<int64_t, Op>::exec(input, output, len);
      break;
    case arrtype::FLOAT32:
      Impl<float, Op>::exec(input, output, len);
      break;
    case arrtype::FLOAT64:
      Impl<double, Op>::exec(input, output, len);
      break;
    default:
      // unsupported element type: nothing is written
      break;
  }
}
// Two-level unary dispatch for ops whose input and output types differ.
// `itype` fixes the input element type here; the output-type dispatcher then
// resolves `otype`. Unlisted input codes leave the output untouched.
template <typename Op, template <typename...> typename Impl>
static inline void arithmetic_op(const int itype, const int otype, const void* input, void* output, const int len) {
  switch (static_cast<arrtype>(itype)) {
    case arrtype::UINT8:
      arithmetic_op<Op, Impl, uint8_t>(otype, input, output, len);
      break;
    case arrtype::INT8:
      arithmetic_op<Op, Impl, int8_t>(otype, input, output, len);
      break;
    case arrtype::UINT16:
      arithmetic_op<Op, Impl, uint16_t>(otype, input, output, len);
      break;
    case arrtype::INT16:
      arithmetic_op<Op, Impl, int16_t>(otype, input, output, len);
      break;
    case arrtype::UINT32:
      arithmetic_op<Op, Impl, uint32_t>(otype, input, output, len);
      break;
    case arrtype::INT32:
      arithmetic_op<Op, Impl, int32_t>(otype, input, output, len);
      break;
    case arrtype::UINT64:
      arithmetic_op<Op, Impl, uint64_t>(otype, input, output, len);
      break;
    case arrtype::INT64:
      arithmetic_op<Op, Impl, int64_t>(otype, input, output, len);
      break;
    case arrtype::FLOAT32:
      arithmetic_op<Op, Impl, float>(otype, input, output, len);
      break;
    case arrtype::FLOAT64:
      arithmetic_op<Op, Impl, double>(otype, input, output, len);
      break;
    default:
      // unsupported input type: nothing is written
      break;
  }
}
template <template <typename...> class Impl>
static inline void arithmetic_unary_impl_same_types(const int type, const int8_t op, const void* input, void* output, const int len) {
const auto opt = static_cast<optype>(op);
switch (opt) {
case optype::ABSOLUTE_VALUE:
return arithmetic_op<AbsoluteValue, Impl>(type, input, output, len);
case optype::ABSOLUTE_VALUE_CHECKED:
return arithmetic_op<AbsoluteValueChecked, Impl>(type, input, output, len);
case optype::NEGATE:
return arithmetic_op<Negate, Impl>(type, input, output, len);
case optype::NEGATE_CHECKED:
return arithmetic_op<NegateChecked, Impl>(type, input, output, len);
case optype::SIGN:
return arithmetic_op<Sign, Impl>(type, input, output, len);
default:
break;
}
}
// Operation dispatch for unary kernels whose input (`itype`) and output
// (`otype`) element types differ. Only optype::SIGN is supported; any other
// code is a no-op.
template <template <typename...> class Impl>
static inline void arithmetic_unary_impl(const int itype, const int otype, const int8_t op, const void* input, void* output, const int len) {
  if (static_cast<optype>(op) == optype::SIGN) {
    arithmetic_op<Sign, Impl>(itype, otype, input, output, len);
  }
}
template <template <typename...> class Impl>
static inline void arithmetic_binary_impl(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) {
const auto opt = static_cast<optype>(op);
switch (opt) {
case optype::ADD:
return arithmetic_op<Add, Impl>(type, in_left, in_right, out, len);
case optype::ADD_CHECKED:
return arithmetic_op<AddChecked, Impl>(type, in_left, in_right, out, len);
case optype::SUB:
return arithmetic_op<Sub, Impl>(type, in_left, in_right, out, len);
case optype::SUB_CHECKED:
return arithmetic_op<SubChecked, Impl>(type, in_left, in_right, out, len);
case optype::MUL:
return arithmetic_op<Multiply, Impl>(type, in_left, in_right, out, len);
case optype::MUL_CHECKED:
return arithmetic_op<MultiplyChecked, Impl>(type, in_left, in_right, out, len);
default:
// don't implement divide here as we can only divide on non-null entries
// so we can avoid dividing by zero
break;
}
}
// C-linkage entry points. FULL_NAME is a project macro (presumably from
// arch.h) that decorates each symbol name; NOTE(review): confirm the exact
// suffix scheme there. Each function only selects the loop shape and forwards
// to the matching dispatcher.

// Binary op, array (op) array.
extern "C" void FULL_NAME(arithmetic_binary)(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) {
  return arithmetic_binary_impl<arithmetic_op_arr_arr_impl>(type, op, in_left, in_right, out, len);
}

// Binary op, array (op) scalar: `in_right` points to a single value.
extern "C" void FULL_NAME(arithmetic_arr_scalar)(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) {
  return arithmetic_binary_impl<arithmetic_op_arr_scalar_impl>(type, op, in_left, in_right, out, len);
}

// Binary op, scalar (op) array: `in_left` points to a single value.
extern "C" void FULL_NAME(arithmetic_scalar_arr)(const int type, const int8_t op, const void* in_left, const void* in_right, void* out, const int len) {
  return arithmetic_binary_impl<arithmetic_op_scalar_arr_impl>(type, op, in_left, in_right, out, len);
}

// Unary op whose input and output share the element type `type`.
extern "C" void FULL_NAME(arithmetic_unary_same_types)(const int type, const int8_t op, const void* input, void* output, const int len) {
  return arithmetic_unary_impl_same_types<arithmetic_unary_op_impl>(type, op, input, output, len);
}

// Unary op reading element type `itype` and writing element type `otype`.
extern "C" void FULL_NAME(arithmetic_unary_diff_type)(const int itype, const int otype, const int8_t op, const void* input, void* output, const int len) {
  return arithmetic_unary_impl<arithmetic_unary_op_impl>(itype, otype, op, input, output, len);
}