cpp-ch/local-engine/Functions/SparkParseURL.cpp (555 lines of code) (raw):

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnNullable.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionStringToString.h>
#include <Functions/FunctionsStringSearchToString.h>
#include <Functions/IFunction.h>
#include <Functions/URL/domain.h>
#include <Poco/Logger.h>
#include <cstring>
#include <memory>
#include <string_view>

namespace DB
{
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
}

namespace local_engine
{

/// Adapts a URL-part extractor to a String column and produces a nullable String result.
///
/// `Extractor::execute(data, size, start, length)` reports, for one row, the slice of the input that
/// forms the requested URL part. A null `start` means "this part does not exist": the row gets an
/// empty string and a null-map value of 1. Otherwise the slice is copied and the null-map value is 0.
/// Every result string is followed by a terminating zero, matching the layout of the input rows.
template <typename Extractor>
struct ExtractNullableSubstringImpl
{
    static void vector(
        const DB::ColumnString::Chars & data,
        const DB::ColumnString::Offsets & offsets,
        DB::ColumnString::Chars & res_data,
        DB::ColumnString::Offsets & res_offsets,
        DB::IColumn & null_map)
    {
        const size_t size = offsets.size();
        res_offsets.resize(size);
        res_data.reserve(size * Extractor::getReserveLengthForElement());
        null_map.reserve(size);

        size_t prev_offset = 0;
        size_t res_offset = 0;

        /// Matched part.
        DB::Pos start;
        size_t length;

        for (size_t i = 0; i < size; ++i)
        {
            /// `offsets[i] - prev_offset - 1` excludes the row's terminating zero.
            Extractor::execute(reinterpret_cast<const char *>(&data[prev_offset]), offsets[i] - prev_offset - 1, start, length);

            /// A null slice is always emitted as an empty string, whatever length the extractor left behind.
            if (!start)
                length = 0;

            res_data.resize(res_data.size() + length + 1);
            if (start)
            {
                memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], start, length);
                null_map.insert(0);
            }
            else
            {
                null_map.insert(1);
            }

            res_offset += length + 1;
            res_data[res_offset - 1] = 0;
            res_offsets[i] = res_offset;
            prev_offset = offsets[i];
        }
    }
};

/// Single-argument String -> Nullable(String) function driven by `Impl::vector`
/// (typically an ExtractNullableSubstringImpl).
template <typename Impl, typename Name, bool is_injective = false>
class FunctionStringToNullableString : public DB::IFunction
{
public:
    static constexpr auto name = Name::name;

    static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<FunctionStringToNullableString>(); }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 1; }

    bool isInjective(const DB::ColumnsWithTypeAndName &) const override { return is_injective; }

    bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override
    {
        /// Only String is accepted: executeImpl handles ColumnString exclusively and always builds a
        /// ColumnString result, so admitting FixedString here would promise a type execution cannot deliver.
        if (!DB::isString(arguments[0]))
            throw DB::Exception(
                DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Illegal type {} of argument of function {}",
                arguments[0]->getName(),
                getName());

        return DB::makeNullable(arguments[0]);
    }

    bool useDefaultImplementationForConstants() const override { return true; }

    DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t /*input_rows_count*/) const override
    {
        const DB::ColumnPtr column = arguments[0].column;
        auto null_map = DB::DataTypeUInt8().createColumn();

        if (const DB::ColumnString * col = DB::checkAndGetColumn<DB::ColumnString>(column.get()))
        {
            auto col_res = DB::ColumnString::create();
            Impl::vector(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), *null_map);
            return DB::ColumnNullable::create(std::move(col_res), std::move(null_map));
        }

        throw DB::Exception(
            DB::ErrorCodes::ILLEGAL_COLUMN,
            "Illegal column {} of argument of function {}",
            arguments[0].column->getName(),
            getName());
    }
};

/// Two-argument (String, const String) -> Nullable(String) function driven by
/// `Impl::vector(chars, offsets, needle, res_chars, res_offsets, null_map)`.
template <typename Impl, typename Name>
class FunctionsStringSearchToNullableString : public DB::IFunction
{
public:
    static constexpr auto name = Name::name;

    static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<FunctionsStringSearchToNullableString>(); }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 2; }

    bool useDefaultImplementationForConstants() const override { return true; }

    DB::ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }

    bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override
    {
        if (!DB::isString(arguments[0]))
            throw DB::Exception(
                DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Illegal type {} of argument of function {}",
                arguments[0]->getName(),
                getName());

        if (!DB::isString(arguments[1]))
            throw DB::Exception(
                DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Illegal type {} of argument of function {}",
                arguments[1]->getName(),
                getName());

        return DB::makeNullable(std::make_shared<DB::DataTypeString>());
    }

    DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t /*input_rows_count*/) const override
    {
        const DB::ColumnPtr column = arguments[0].column;
        const DB::ColumnPtr column_needle = arguments[1].column;

        const DB::ColumnConst * col_needle = typeid_cast<const DB::ColumnConst *>(&*column_needle);
        if (!col_needle)
            throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be constant string", getName());

        if (const DB::ColumnString * col = DB::checkAndGetColumn<DB::ColumnString>(column.get()))
        {
            auto col_res = DB::ColumnString::create();
            auto null_map = DB::DataTypeUInt8().createColumn();

            DB::ColumnString::Chars & vec_res = col_res->getChars();
            DB::ColumnString::Offsets & offsets_res = col_res->getOffsets();
            Impl::vector(col->getChars(), col->getOffsets(), col_needle->getValue<String>(), vec_res, offsets_res, *null_map);

            return DB::ColumnNullable::create(std::move(col_res), std::move(null_map));
        }

        throw DB::Exception(
            DB::ErrorCodes::ILLEGAL_COLUMN,
            "Illegal column {} of argument of function {}",
            arguments[0].column->getName(),
            getName());
    }
};

namespace
{
/// Separator between the scheme and the authority of a hierarchical URL.
constexpr std::string_view protocol_delim = "://";

/// Returns a pointer to the first "://" in [begin, end), or nullptr when there is none.
DB::Pos findProtocolDelim(DB::Pos begin, DB::Pos end)
{
    return static_cast<DB::Pos>(memmem(begin, end - begin, protocol_delim.data(), protocol_delim.size()));
}

/// Given the first character after "://", returns where the authority ends: the first '/', '?' or '#',
/// or `end` when none of them occurs.
DB::Pos findAuthorityEnd(DB::Pos authority_begin, DB::Pos end)
{
    return find_first_symbols<'/', '?', '#'>(authority_begin, end);
}
}

/// Different from CH extractURLParameters which returns an array result.
struct NameSparkExtractURLQuery
{
    static constexpr auto name = "spark_parse_url_query";
};

/// Extracts the query string: the text after the first '?' up to the next '#' (or the end).
/// Yields NULL when there is no '?', when a '#' comes first, or, for URLs without "://", when a ':'
/// comes first (an opaque URI such as "mailto:a?b").
struct SparkExtractURLQuery
{
    static size_t getReserveLengthForElement() { return 15; }

    static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
    {
        res_data = nullptr;
        res_size = 0;
        const DB::Pos end = data + size;

        const DB::Pos query_mark = findProtocolDelim(data, end)
            ? find_first_symbols<'?', '#'>(data, end)
            : find_first_symbols<'?', '#', ':'>(data, end);

        if (query_mark == end || *query_mark != '?')
            return;

        res_data = query_mark + 1;
        res_size = find_first_symbols<'#'>(res_data, end) - res_data;
    }
};

using SparkFunctionURLQuery = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLQuery>, NameSparkExtractURLQuery>;

REGISTER_FUNCTION(SparkFunctionURLQuery)
{
    factory.registerFunction<SparkFunctionURLQuery>();
}

struct NameSparkExtractURLOneQuery
{
    static constexpr auto name = "spark_parse_url_one_query";
};

/// Extracts the value of one query parameter, following Spark's `(&|^)key=([^&]*)` matching on the
/// query string: the key must start the query or follow a '&', and the value runs until the next '&'
/// or the end of the query. The fragment ('#...') is not part of the query and is never searched.
/// Rows with no query or no matching key yield NULL.
struct SparkExtractURLOneQuery
{
    static void vector(
        const DB::ColumnString::Chars & data,
        const DB::ColumnString::Offsets & offsets,
        std::string pattern,
        DB::ColumnString::Chars & res_data,
        DB::ColumnString::Offsets & res_offsets,
        DB::IColumn & null_map)
    {
        res_data.reserve(data.size() / 5);
        res_offsets.resize(offsets.size());
        null_map.reserve(offsets.size());

        /// Search for "key=" so that a match always covers the whole key and its '='.
        pattern += '=';
        const char * param_str = pattern.data();
        const size_t param_len = pattern.size();

        DB::ColumnString::Offset prev_offset = 0;
        DB::ColumnString::Offset res_offset = 0;

        for (size_t i = 0; i < offsets.size(); ++i)
        {
            const DB::ColumnString::Offset cur_offset = offsets[i];
            const char * str = reinterpret_cast<const char *>(&data[prev_offset]);
            /// `cur_offset - 1` is the row's terminating zero, which is not part of the URL.
            const char * end = reinterpret_cast<const char *>(&data[cur_offset - 1]);

            const char * value_begin = nullptr;
            const char * value_end = nullptr;

            /// Without "://", a ':' that precedes any '?' marks an opaque URI, which has no query.
            const DB::Pos protocol_delim_pos = findProtocolDelim(str, end);
            const DB::Pos query_mark = protocol_delim_pos
                ? find_first_symbols<'?', '#'>(protocol_delim_pos, end)
                : find_first_symbols<'?', '#', ':'>(str, end);

            if (query_mark != end && *query_mark == '?')
            {
                const char * query_begin = query_mark + 1;
                const char * query_end = find_first_symbols<'#'>(query_begin, end);
                value_begin = findParameterValue(query_begin, query_end, param_str, param_len);
                if (value_begin)
                    value_end = find_first_symbols<'&'>(value_begin, query_end);
            }

            if (value_begin)
            {
                const size_t value_size = value_end - value_begin;
                res_data.resize(res_offset + value_size + 1);
                memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], value_begin, value_size);
                res_offset += value_size;
                null_map.insert(0);
            }
            else
            {
                /// No parameter found, put empty string in result.
                res_data.resize(res_offset + 1);
                null_map.insert(1);
            }

            res_data[res_offset] = 0;
            ++res_offset;
            res_offsets[i] = res_offset;

            prev_offset = cur_offset;
        }
    }

private:
    /// Returns the start of the value of the first "key=" in [query_begin, query_end) that begins the
    /// query or directly follows a '&', or nullptr when there is no such pair.
    static const char * findParameterValue(const char * query_begin, const char * query_end, const char * param_str, size_t param_len)
    {
        const char * pos = query_begin;
        while (static_cast<size_t>(query_end - pos) >= param_len)
        {
            const char * match = static_cast<const char *>(memmem(pos, query_end - pos, param_str, param_len));
            if (!match)
                return nullptr;

            if (match == query_begin || match[-1] == '&')
                return match + param_len;

            /// Only the suffix of a longer key matched (e.g. "xa=" while looking for "a="); keep searching.
            pos = match + 1;
        }
        return nullptr;
    }
};

using SparkFunctionURLOneQuery = FunctionsStringSearchToNullableString<SparkExtractURLOneQuery, NameSparkExtractURLOneQuery>;

REGISTER_FUNCTION(SparkFunctionURLOneQuery)
{
    factory.registerFunction<SparkFunctionURLOneQuery>();
}

/// Extracts the host of a URL that contains "://". Yields NULL without "://", and an empty string when
/// DB::getURLHost cannot find a valid host.
struct SparkExtractURLHost
{
    static size_t getReserveLengthForElement() { return 15; }

    static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
    {
        const DB::Pos end = data + size;

        const DB::Pos protocol_delim_pos = findProtocolDelim(data, end);
        if (!protocol_delim_pos)
        {
            res_data = nullptr;
            res_size = 0;
            return;
        }

        /// Skip a leading "userinfo@" because getURLHost stops at the ':' in "user:password@host".
        /// Only an '@' inside the authority counts; one in the path, query or fragment must be ignored.
        DB::Pos host_begin = protocol_delim_pos + protocol_delim.size();
        const DB::Pos authority_end = findAuthorityEnd(host_begin, end);
        const DB::Pos userinfo_delim_pos = find_first_symbols<'@'>(host_begin, authority_end);
        if (userinfo_delim_pos != authority_end)
            host_begin = userinfo_delim_pos + 1;

        const std::string_view host = DB::getURLHost(host_begin, end - host_begin);

        if (host.empty())
        {
            /// NOTE(review): this yields an empty string rather than NULL; Spark's URI#getHost would
            /// return null here — confirm which is intended.
            res_data = data;
            res_size = 0;
        }
        else
        {
            res_data = host.data();
            res_size = host.size();
        }
    }
};

struct NameSparkExtractURLHost
{
    static constexpr auto name = "spark_parse_url_host";
};

using SparkFunctionURLHost = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLHost>, NameSparkExtractURLHost>;

REGISTER_FUNCTION(SparkFunctionURLHost)
{
    factory.registerFunction<SparkFunctionURLHost>();
}

struct NameSparkExtractURLPath
{
    static constexpr auto name = "spark_parse_url_path";
};

/// Extracts the path of a URL containing "://": from the first '/' after the authority up to the next
/// '?' or '#'. Non-nullable: yields an empty string (never a null slice) when there is no path.
struct SparkExtractURLPath
{
    static size_t getReserveLengthForElement() { return 25; }

    static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
    {
        /// DB::ExtractSubstringImpl copies from res_data unconditionally, so it must stay non-null.
        res_data = data;
        res_size = 0;
        const DB::Pos end = data + size;

        const DB::Pos protocol_delim_pos = findProtocolDelim(data, end);
        if (!protocol_delim_pos)
            return;

        const DB::Pos path_begin = findAuthorityEnd(protocol_delim_pos + protocol_delim.size(), end);
        if (path_begin == end || *path_begin != '/')
            return;

        res_data = path_begin;
        res_size = find_first_symbols<'?', '#'>(path_begin, end) - path_begin;
    }
};

using SparkFunctionURLPath = DB::FunctionStringToString<DB::ExtractSubstringImpl<SparkExtractURLPath>, NameSparkExtractURLPath>;

REGISTER_FUNCTION(SparkFunctionURLPath)
{
    factory.registerFunction<SparkFunctionURLPath>();
}

struct NameSparkExtractUserInfo
{
    static constexpr auto name = "spark_parse_url_userinfo";
};

/// Extracts the userinfo: the text between "://" and an '@' that lies inside the authority.
/// Yields NULL without "://" or when the authority has no '@'.
struct SparkExtractURLUserInfo
{
    static size_t getReserveLengthForElement() { return 25; }

    static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
    {
        res_data = nullptr;
        res_size = 0;
        const DB::Pos end = data + size;

        const DB::Pos protocol_delim_pos = findProtocolDelim(data, end);
        if (!protocol_delim_pos)
            return;

        const DB::Pos authority_begin = protocol_delim_pos + protocol_delim.size();
        const DB::Pos authority_end = findAuthorityEnd(authority_begin, end);
        /// An '@' in the path, query or fragment does not delimit userinfo.
        const DB::Pos userinfo_end = find_first_symbols<'@'>(authority_begin, authority_end);
        if (userinfo_end == authority_end)
            return;

        res_data = authority_begin;
        res_size = userinfo_end - authority_begin;
    }
};

using SparkFunctionURLUserInfo = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLUserInfo>, NameSparkExtractUserInfo>;

REGISTER_FUNCTION(SparkFunctionURLUserInfo)
{
    factory.registerFunction<SparkFunctionURLUserInfo>();
}

struct NameSparkExtractURLRef
{
    static constexpr auto name = "spark_parse_url_ref";
};

/// Extracts the fragment: everything after the first '#'. Yields NULL when there is no '#'.
struct SparkExtractURLRef
{
    static size_t getReserveLengthForElement() { return 25; }

    static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
    {
        const DB::Pos end = data + size;

        const DB::Pos ref_delim_pos = find_first_symbols<'#'>(data, end);
        if (ref_delim_pos == end)
        {
            res_data = nullptr;
            res_size = 0;
            return;
        }

        res_data = ref_delim_pos + 1;
        res_size = end - res_data;
    }
};

using SparkFunctionURLRef = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLRef>, NameSparkExtractURLRef>;

REGISTER_FUNCTION(SparkFunctionURLRef)
{
    factory.registerFunction<SparkFunctionURLRef>();
}

struct NameSparkExtractURLFile
{
    static constexpr auto name = "spark_parse_url_file";
};

/// Extracts the "file" part: path plus query (everything after the authority, up to the next '#').
///  - Without "://": NULL if the string contains a ':' followed by more text, otherwise the whole string.
///  - With "://": empty string when nothing follows the authority or a '#' comes first.
struct SparkExtractURLFile
{
    static size_t getReserveLengthForElement() { return 25; }

    static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
    {
        res_data = data;
        res_size = 0;
        const DB::Pos end = data + size;

        const DB::Pos protocol_delim_pos = findProtocolDelim(data, end);
        if (!protocol_delim_pos)
        {
            const DB::Pos colon_pos = find_first_symbols<':'>(data, end);
            if (end - colon_pos > 1)
            {
                res_data = nullptr;
                return;
            }
            res_size = size;
            return;
        }

        const DB::Pos file_begin = findAuthorityEnd(protocol_delim_pos + protocol_delim.size(), end);
        if (file_begin == end || *file_begin == '#')
            return;

        res_data = file_begin;
        res_size = find_first_symbols<'#'>(file_begin + 1, end) - file_begin;
    }
};

using SparkFunctionURLFile = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLFile>, NameSparkExtractURLFile>;

REGISTER_FUNCTION(SparkFunctionURLFile)
{
    factory.registerFunction<SparkFunctionURLFile>();
}

struct NameSparkExtractURLAuthority
{
    static constexpr auto name = "spark_parse_url_authority";
};

/// Extracts the authority: the text after "://" up to the first '/', '?' or '#' (or the end).
/// Yields NULL without "://".
struct SparkExtractURLAuthority
{
    static size_t getReserveLengthForElement() { return 25; }

    static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
    {
        const DB::Pos end = data + size;

        const DB::Pos protocol_delim_pos = findProtocolDelim(data, end);
        if (!protocol_delim_pos)
        {
            res_data = nullptr;
            res_size = 0;
            return;
        }

        res_data = protocol_delim_pos + protocol_delim.size();
        /// find_first_symbols returns `end` (never nullptr) when no delimiter is present.
        res_size = findAuthorityEnd(res_data, end) - res_data;
    }
};

using SparkFunctionURLAuthority = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLAuthority>, NameSparkExtractURLAuthority>;

REGISTER_FUNCTION(SparkFunctionURLAuthority)
{
    factory.registerFunction<SparkFunctionURLAuthority>();
}

struct NameSparkExtractURLInvalid
{
    static constexpr auto name = "spark_parse_url_invalid";
};

/// Always yields NULL, whatever the input.
struct SparkExtractURLInvalid
{
    static size_t getReserveLengthForElement() { return 1; }

    static void execute(DB::Pos, size_t, DB::Pos & res_data, size_t & res_size)
    {
        res_data = nullptr;
        res_size = 0;
    }
};

using SparkFunctionURLInvalid = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLInvalid>, NameSparkExtractURLInvalid>;

REGISTER_FUNCTION(SparkFunctionURLInvalid)
{
    factory.registerFunction<SparkFunctionURLInvalid>();
}

}