/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include #include using namespace DB; namespace DB { namespace ErrorCodes { extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int TYPE_MISMATCH; } } namespace local_engine { template class SparkFunctionCastFloatToInt : public DB::IFunction { public: size_t getNumberOfArguments() const override { return 1; } static constexpr auto name = Name::name; static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } SparkFunctionCastFloatToInt() = default; ~SparkFunctionCastFloatToInt() override = default; String getName() const override { return name; } bool useDefaultImplementationForConstants() const override { return true; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes &) const override { if constexpr (std::is_integral_v) { return DB::makeNullable(std::make_shared>()); } else throw DB::Exception(DB::ErrorCodes::TYPE_MISMATCH, "Function {}'s return type should be Int", name); } DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t) const override { if (arguments.size() != 1) throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s arguments number must be 1", name); if (!isFloat(removeNullable(arguments[0].type))) throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s 1st argument must be float type", name); DB::ColumnPtr src_col = arguments[0].column; size_t size = src_col->size(); auto res_col = DB::ColumnVector::create(size); auto null_map_col = DB::ColumnUInt8::create(size, 0); switch(removeNullable(arguments[0].type)->getTypeId()) { case DB::TypeIndex::Float32: { executeInternal(src_col, res_col->getData(), null_map_col->getData()); break; } case DB::TypeIndex::Float64: { executeInternal(src_col, res_col->getData(), null_map_col->getData()); break; } } return DB::ColumnNullable::create(std::move(res_col), std::move(null_map_col)); } template void executeInternal(const DB::ColumnPtr & src, DB::PaddedPODArray & data, DB::PaddedPODArray & null_map_data) const { const DB::ColumnVector * src_vec = assert_cast *>(src.get()); for (size_t i = 0; i < src_vec->size(); ++i) { F element = src_vec->getElement(i); if (isNaN(element) || !isFinite(element)) { data[i] = 0; null_map_data[i] = 1; } else data[i] = static_cast(element); } } }; }