arrow/compute/internal/kernels/_lib/scalar_comparison.cc (192 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include <arch.h> #include <stdint.h> #include "types.h" // pack integers into a bitmap in batches of 8 template <int batch_size> inline void pack_bits(const uint32_t* values, uint8_t* out) { for (int i = 0; i < batch_size / 8; ++i) { *out++ = (values[0] | values[1]<<1 | values[2]<<2 | values[3]<<3 | values[4]<<4 | values[5]<<5 | values[6]<<6 | values[7]<<7); values += 8; } } struct Equal { template <typename T> static constexpr bool Call(const T& left, const T& right) { return left == right; } }; struct NotEqual { template <typename T> static constexpr bool Call(const T& left, const T& right) { return left != right; } }; struct Greater { template <typename T> static constexpr bool Call(const T& left, const T& right) { return left > right; } }; struct GreaterEqual { template <typename T> static constexpr bool Call(const T& left, const T& right) { return left >= right; } }; static inline void set_bit_to(uint8_t* bits, int64_t i, bool bit_is_set) { bits[i/8] ^= static_cast<uint8_t>(-static_cast<uint8_t>(bit_is_set) ^ bits[i / 8]) & static_cast<uint8_t>(1 << (i % 8)); } template <typename T, typename Op> struct compare_primitive_arr_arr { static inline void Exec(const void* left_void, const void* right_void, int64_t length, void* out_void, const int offset) { const T* left = reinterpret_cast<const T*>(left_void); const T* right = reinterpret_cast<const T*>(right_void); uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_void); static constexpr int kBatchSize = 32; int64_t num_batches = length / kBatchSize; uint32_t temp_output[kBatchSize]; if (int prefix = offset % 8) { for (int i = prefix; i < 8; ++i) { set_bit_to(out_bitmap, i, Op::template Call<T>(*left++, *right++)); } out_bitmap++; } for (int64_t j = 0; j < num_batches; ++j) { for (int i = 0; i < kBatchSize; ++i) { temp_output[i] = Op::template Call<T>(*left++, *right++); } pack_bits<kBatchSize>(temp_output, out_bitmap); out_bitmap += kBatchSize / 8; } int64_t bit_index = 0; for (int64_t j = kBatchSize * num_batches; j < length; ++j) { set_bit_to(out_bitmap, bit_index++, Op::template Call<T>(*left++, *right++)); } } }; template <typename T, typename Op> struct compare_primitive_arr_scalar { static inline void Exec(const void* left_void, const void* right_void, int64_t length, void* out_void, const int offset) { const T* left = reinterpret_cast<const T*>(left_void); const T right = *reinterpret_cast<const T*>(right_void); uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_void); static constexpr int kBatchSize = 32; int64_t num_batches = length / kBatchSize; uint32_t temp_output[kBatchSize]; if (int prefix = offset % 8) { for (int i = prefix; i < 8; ++i) { set_bit_to(out_bitmap, i, Op::template Call<T>(*left++, right)); } out_bitmap++; } for (int64_t j = 0; j < num_batches; ++j) { for (int i = 0; i < kBatchSize; ++i) { temp_output[i] = Op::template Call<T>(*left++, right); } pack_bits<kBatchSize>(temp_output, out_bitmap); out_bitmap += kBatchSize / 8; } int64_t bit_index = 0; for (int64_t j = kBatchSize * num_batches; j < length; ++j) { set_bit_to(out_bitmap, bit_index++, Op::template Call<T>(*left++, right)); } } }; template <typename T, typename Op> struct compare_primitive_scalar_arr { static inline void Exec(const void* left_void, const void* right_void, int64_t length, void* out_void, const int offset) { const T left = *reinterpret_cast<const T*>(left_void); const T* right = reinterpret_cast<const T*>(right_void); uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_void); static constexpr int kBatchSize = 32; int64_t num_batches = length / kBatchSize; uint32_t temp_output[kBatchSize]; if (int prefix = offset % 8) { for (int i = prefix; i < 8; ++i) { set_bit_to(out_bitmap, i, Op::template Call<T>(left, *right++)); } out_bitmap++; } for (int64_t j = 0; j < num_batches; ++j) { for (int i = 0; i < kBatchSize; ++i) { temp_output[i] = Op::template Call<T>(left, *right++); } pack_bits<kBatchSize>(temp_output, out_bitmap); out_bitmap += kBatchSize / 8; } int64_t bit_index = 0; for (int64_t j = kBatchSize * num_batches; j < length; ++j) { set_bit_to(out_bitmap, bit_index++, Op::template Call<T>(left, *right++)); } } }; enum class cmpop : int8_t { EQUAL, NOT_EQUAL, GREATER, GREATER_EQUAL, // LESS and LESS_EQUAL are handled by doing flipped // versions of GREATER and GREATER_EQUAL }; template <typename Op, template <typename...> typename Impl> static inline void comparison_exec(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { const auto ty = static_cast<arrtype>(type); switch (ty) { case arrtype::UINT8: return Impl<uint8_t, Op>::Exec(left, right, length, out, offset); case arrtype::INT8: return Impl<int8_t, Op>::Exec(left, right, length, out, offset); case arrtype::UINT16: return Impl<uint16_t, Op>::Exec(left, right, length, out, offset); case arrtype::INT16: return Impl<int16_t, Op>::Exec(left, right, length, out, offset); case arrtype::UINT32: return Impl<uint32_t, Op>::Exec(left, right, length, out, offset); case arrtype::INT32: return Impl<int32_t, Op>::Exec(left, right, length, out, offset); case arrtype::UINT64: return Impl<uint64_t, Op>::Exec(left, right, length, out, offset); case arrtype::INT64: return Impl<int64_t, Op>::Exec(left, right, length, out, offset); case arrtype::FLOAT32: return Impl<float, Op>::Exec(left, right, length, out, offset); case arrtype::FLOAT64: return Impl<double, Op>::Exec(left, right, length, out, offset); default: break; } } extern "C" void FULL_NAME(comparison_equal_arr_arr)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<Equal, compare_primitive_arr_arr>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_equal_arr_scalar)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<Equal, compare_primitive_arr_scalar>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_equal_scalar_arr)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<Equal, compare_primitive_scalar_arr>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_not_equal_arr_arr)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<NotEqual, compare_primitive_arr_arr>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_not_equal_arr_scalar)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<NotEqual, compare_primitive_arr_scalar>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_not_equal_scalar_arr)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<NotEqual, compare_primitive_scalar_arr>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_greater_arr_arr)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<Greater, compare_primitive_arr_arr>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_greater_arr_scalar)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<Greater, compare_primitive_arr_scalar>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_greater_scalar_arr)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<Greater, compare_primitive_scalar_arr>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_greater_equal_arr_arr)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<GreaterEqual, compare_primitive_arr_arr>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_greater_equal_arr_scalar)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<GreaterEqual, compare_primitive_arr_scalar>(type, left, right, out, length, offset); } extern "C" void FULL_NAME(comparison_greater_equal_scalar_arr)(const int type, const void* left, const void* right, void* out, const int64_t length, const int offset) { comparison_exec<GreaterEqual, compare_primitive_scalar_arr>(type, left, right, out, length, offset); }