cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h (299 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include <Functions/FunctionsRound.h> namespace DB::ErrorCodes { extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } namespace local_engine { template <typename T, DB::Vectorize vectorize> class BaseFloatRoundingHalfUpComputation; #ifdef __SSE4_1__ /// vectorized implementation for x86 template <> class BaseFloatRoundingHalfUpComputation<Float32, DB::Vectorize::Yes> { public: using ScalarType = Float32; using VectorType = __m128; static const size_t data_count = 4; static VectorType load(const ScalarType * in) { return _mm_loadu_ps(in); } static VectorType load1(const ScalarType in) { return _mm_load1_ps(&in); } static void store(ScalarType * out, VectorType val) { _mm_storeu_ps(out, val); } static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_ps(val, scale); } static VectorType divide(VectorType val, VectorType scale) { return _mm_div_ps(val, scale); } template <DB::RoundingMode mode> static VectorType apply(VectorType val) { ScalarType tempFloatsIn[data_count]; ScalarType tempFloatsOut[data_count]; store(tempFloatsIn, val); for (size_t i = 0; i < data_count; ++i) tempFloatsOut[i] = std::roundf(tempFloatsIn[i]); return load(tempFloatsOut); } static VectorType prepare(size_t scale) { return load1(scale); } }; template <> class BaseFloatRoundingHalfUpComputation<Float64, DB::Vectorize::Yes> { public: using ScalarType = Float64; using VectorType = __m128d; static const size_t data_count = 2; static VectorType load(const ScalarType * in) { return _mm_loadu_pd(in); } static VectorType load1(const ScalarType in) { return _mm_load1_pd(&in); } static void store(ScalarType * out, VectorType val) { _mm_storeu_pd(out, val); } static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_pd(val, scale); } static VectorType divide(VectorType val, VectorType scale) { return _mm_div_pd(val, scale); } template <DB::RoundingMode mode> static VectorType apply(VectorType val) { ScalarType tempFloatsIn[data_count]; ScalarType tempFloatsOut[data_count]; store(tempFloatsIn, val); for (size_t i = 0; i < data_count; ++i) tempFloatsOut[i] = std::round(tempFloatsIn[i]); return load(tempFloatsOut); } static VectorType prepare(size_t scale) { return load1(scale); } }; /// end __SSE4_1__ #endif /// Sequential implementation for ARM. Also used for scalar arguments template <typename T> class BaseFloatRoundingHalfUpComputation<T, DB::Vectorize::No> { public: using ScalarType = T; using VectorType = T; static const size_t data_count = 1; static VectorType load(const ScalarType * in) { return *in; } static VectorType load1(const ScalarType in) { return in; } static VectorType store(ScalarType * out, ScalarType val) { return *out = val;} static VectorType multiply(VectorType val, VectorType scale) { return val * scale; } static VectorType divide(VectorType val, VectorType scale) { return val / scale; } template <DB::RoundingMode mode> static VectorType apply(VectorType val) { if constexpr (std::is_same_v<ScalarType, Float32>) { return std::roundf(val); } else { return std::round(val); } } static VectorType prepare(size_t scale) { return load1(scale); } }; template <> class BaseFloatRoundingHalfUpComputation<BFloat16, DB::Vectorize::No> { public: using ScalarType = BFloat16; using VectorType = BFloat16; static const size_t data_count = 1; static VectorType load(const ScalarType * in) { return *in; } static VectorType load1(const ScalarType in) { return in; } static VectorType store(ScalarType * out, ScalarType val) { return *out = val;} static VectorType multiply(VectorType val, VectorType scale) { return val * scale; } static VectorType divide(VectorType val, VectorType scale) { return val / scale; } template <DB::RoundingMode mode> static VectorType apply(VectorType val) { return BFloat16(std::roundf(static_cast<Float32>(val))); } static VectorType prepare(size_t scale) { return load1(BFloat16(static_cast<Float32>(scale))); } }; /** Implementation of low-level round-off functions for floating-point values. */ template <typename T, DB::RoundingMode rounding_mode, DB::ScaleMode scale_mode, DB::Vectorize vectorize> class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T, vectorize> { using Base = BaseFloatRoundingHalfUpComputation<T, vectorize>; public: static inline void compute(const T * __restrict in, const typename Base::VectorType & scale, T * __restrict out) { auto val = Base::load(in); if (scale_mode == DB::ScaleMode::Positive) val = Base::multiply(val, scale); else if (scale_mode == DB::ScaleMode::Negative) val = Base::divide(val, scale); val = Base::template apply<rounding_mode>(val); if (scale_mode == DB::ScaleMode::Positive) val = Base::divide(val, scale); else if (scale_mode == DB::ScaleMode::Negative) val = Base::multiply(val, scale); Base::store(out, val); } }; /** Implementing high-level rounding functions. */ template <typename T, DB::RoundingMode rounding_mode, DB::ScaleMode scale_mode> struct FloatRoundingHalfUpImpl { private: static_assert(!DB::is_decimal<T>); template <DB::Vectorize vectorize = #ifdef __SSE4_1__ std::is_same_v<T, BFloat16> ? DB::Vectorize::No : DB::Vectorize::Yes #else DB::Vectorize::No #endif > using Op = FloatRoundingHalfUpComputation<T, rounding_mode, scale_mode, vectorize>; using Data = std::array<T, Op<>::data_count>; using ColumnType = DB::ColumnVector<T>; using Container = typename ColumnType::Container; public: static NO_INLINE void apply(const Container & in, size_t scale, Container & out) { auto mm_scale = Op<>::prepare(scale); const size_t data_count = std::tuple_size<Data>(); const T * end_in = in.data() + in.size(); const T * limit = in.data() + in.size() / data_count * data_count; const T * __restrict p_in = in.data(); T * __restrict p_out = out.data(); while (p_in < limit) { Op<>::compute(p_in, mm_scale, p_out); p_in += data_count; p_out += data_count; } if (p_in < end_in) { Data tmp_src{{}}; Data tmp_dst; size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in); memcpy(&tmp_src, p_in, tail_size_bytes); Op<>::compute(reinterpret_cast<T *>(&tmp_src), mm_scale, reinterpret_cast<T *>(&tmp_dst)); memcpy(p_out, &tmp_dst, tail_size_bytes); } } }; /** Select the appropriate processing algorithm depending on the scale. */ template <typename T, DB::RoundingMode rounding_mode, DB::TieBreakingMode tie_breaking_mode> struct DispatcherRoundingHalfUp { template <DB::ScaleMode scale_mode> using FunctionRoundingImpl = std::conditional_t< std::is_floating_point_v<T> || std::is_same_v<T, BFloat16>, FloatRoundingHalfUpImpl<T, rounding_mode, scale_mode>, DB::IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>; static DB::ColumnPtr apply(const DB::IColumn * col_general, DB::Scale scale_arg) { const auto * const col = checkAndGetColumn<DB::ColumnVector<T>>(col_general); auto col_res = DB::ColumnVector<T>::create(); typename DB::ColumnVector<T>::Container & vec_res = col_res->getData(); vec_res.resize_exact(col->getData().size()); if (!vec_res.empty()) { if (scale_arg == 0) { size_t scale = 1; FunctionRoundingImpl<DB::ScaleMode::Zero>::apply(col->getData(), scale, vec_res); } else if (scale_arg > 0) { size_t scale = intExp10(scale_arg); FunctionRoundingImpl<DB::ScaleMode::Positive>::apply(col->getData(), scale, vec_res); } else { size_t scale = intExp10(-scale_arg); FunctionRoundingImpl<DB::ScaleMode::Negative>::apply(col->getData(), scale, vec_res); } } return col_res; } }; template <DB::is_decimal T, DB::RoundingMode rounding_mode, DB::TieBreakingMode tie_breaking_mode> struct DispatcherRoundingHalfUp<T, rounding_mode, tie_breaking_mode> { public: static DB::ColumnPtr apply(const DB::IColumn * col_general, DB::Scale scale_arg) { const auto * const col = checkAndGetColumn<DB::ColumnDecimal<T>>(col_general); const typename DB::ColumnDecimal<T>::Container & vec_src = col->getData(); auto col_res = DB::ColumnDecimal<T>::create(vec_src.size(), col->getScale()); auto & vec_res = col_res->getData(); if (!vec_res.empty()) DB::DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(col->getData(), col->getScale(), vec_res, scale_arg); return col_res; } }; /** A template for functions that round the value of an input parameter of type * (U)Int8/16/32/64, Float32/64 or Decimal32/64/128, and accept an additional optional parameter (default is 0). */ template <typename Name, DB::RoundingMode rounding_mode, DB::TieBreakingMode tie_breaking_mode> class FunctionRoundingHalfUp : public DB::IFunction { public: static constexpr auto name = "roundHalfUp"; static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<FunctionRoundingHalfUp>(); } String getName() const override { return name; } bool isVariadic() const override { return true; } size_t getNumberOfArguments() const override { return 0; } bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return false; } /// Get result types by argument types. If the function does not apply to these arguments, throw an exception. DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override { if ((arguments.empty()) || (arguments.size() > 2)) throw DB::Exception( DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Number of arguments for function {} doesn't match: passed {}, should be 1 or 2.", getName(), arguments.size()); for (const auto & type : arguments) if (!isNumber(type)) throw DB::Exception( DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}", arguments[0]->getName(), getName()); return arguments[0]; } static DB::Scale getScaleArg(const DB::ColumnsWithTypeAndName & arguments) { if (arguments.size() == 2) { const DB::IColumn & scale_column = *arguments[1].column; if (!isColumnConst(scale_column)) throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "DB::Scale argument for rounding functions must be constant"); DB::Field scale_field = assert_cast<const DB::ColumnConst &>(scale_column).getField(); if (scale_field.getType() != DB::Field::Types::UInt64 && scale_field.getType() != DB::Field::Types::Int64) throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "DB::Scale argument for rounding functions must have integer type"); Int64 scale64 = scale_field.safeGet<Int64>(); if (scale64 > std::numeric_limits<DB::Scale>::max() || scale64 < std::numeric_limits<DB::Scale>::min()) throw DB::Exception(DB::ErrorCodes::ARGUMENT_OUT_OF_BOUND, "DB::Scale argument for rounding function is too large"); return scale64; } return 0; } bool useDefaultImplementationForConstants() const override { return true; } DB::ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t /*input_rows_count*/) const override { const DB::ColumnWithTypeAndName & column = arguments[0]; DB::Scale scale_arg = getScaleArg(arguments); DB::ColumnPtr res; auto call = [&](const auto & types) -> bool { using Types = std::decay_t<decltype(types)>; using DataType = typename Types::LeftType; if constexpr (DB::IsDataTypeNumber<DataType> || DB::IsDataTypeDecimal<DataType>) { using FieldType = typename DataType::FieldType; res = DispatcherRoundingHalfUp<FieldType, rounding_mode, tie_breaking_mode>::apply(column.column.get(), scale_arg); return true; } return false; }; if (!callOnIndexAndDataType<void>(column.type->getTypeId(), call)) throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}", column.name, getName()); return res; } bool hasInformationAboutMonotonicity() const override { return true; } Monotonicity getMonotonicityForRange(const DB::IDataType &, const DB::Field &, const DB::Field &) const override { return {.is_monotonic = true, .is_always_monotonic = true}; } }; struct NameRoundHalfUp { static constexpr auto name = "roundHalfUp"; }; using FunctionRoundHalfUp = FunctionRoundingHalfUp<NameRoundHalfUp, DB::RoundingMode::Round, DB::TieBreakingMode::Auto>; }