cpp-ch/local-engine/Functions/SparkParseURL.cpp (555 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnNullable.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionStringToString.h>
#include <Functions/FunctionsStringSearchToString.h>
#include <Functions/IFunction.h>
#include <Functions/URL/domain.h>
#include <Poco/Logger.h>
#include <memory>
/// Error codes thrown by the function wrappers in this file.
/// They are only declared here (`extern`); the definitions live elsewhere.
namespace DB
{
namespace ErrorCodes
{
/// Thrown when an argument column is not of the expected column class.
extern const int ILLEGAL_COLUMN;
/// Thrown when an argument data type is rejected during return-type resolution.
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
}
namespace local_engine
{
/// Row-wise substring extraction that is allowed to produce NULL.
///
/// `Extractor::execute` is called once per input row. When it reports a null start
/// pointer the row is marked NULL in `null_map`; otherwise the reported byte range is
/// copied into the result column and the row is marked non-NULL.
template <typename Extractor>
struct ExtractNullableSubstringImpl
{
    /// @param data         Bytes of the input string column.
    /// @param offsets      Per-row end offsets into `data`; the last byte of every row is
    ///                     excluded from the length handed to the extractor.
    /// @param res_data     Output bytes; each result row is followed by a zero byte.
    /// @param res_offsets  Output per-row end offsets into `res_data`.
    /// @param null_map     Receives 1 for NULL rows and 0 for all other rows.
    static void vector(const DB::ColumnString::Chars & data, const DB::ColumnString::Offsets & offsets,
        DB::ColumnString::Chars & res_data, DB::ColumnString::Offsets & res_offsets, DB::IColumn & null_map)
    {
        const size_t rows = offsets.size();
        res_offsets.resize(rows);
        res_data.reserve(rows * Extractor::getReserveLengthForElement());
        null_map.reserve(rows);

        size_t row_begin = 0;
        size_t written = 0;

        /// Range reported by the extractor for the current row.
        DB::Pos match_begin;
        size_t match_size;

        for (size_t row = 0; row < rows; ++row)
        {
            const size_t row_size = offsets[row] - row_begin - 1;
            Extractor::execute(reinterpret_cast<const char *>(&data[row_begin]), row_size, match_begin, match_size);

            /// Room for the extracted bytes plus the trailing zero byte.
            res_data.resize(res_data.size() + match_size + 1);

            const bool is_null = !match_begin;
            if (!is_null)
                memcpySmallAllowReadWriteOverflow15(&res_data[written], match_begin, match_size);
            null_map.insert(is_null ? 1 : 0);

            written += match_size + 1;
            res_data[written - 1] = 0;
            res_offsets[row] = written;
            row_begin = offsets[row];
        }
    }
};
/// Unary function wrapper: String -> Nullable(String).
///
/// `Impl::vector` does the per-row work and fills both the result string column and
/// the null map; this class only validates the argument and assembles the nullable column.
///
/// @tparam Impl          Provides `static void vector(chars, offsets, res_chars, res_offsets, null_map)`.
/// @tparam Name          Provides `static constexpr name`, returned by getName().
/// @tparam is_injective  Value reported by isInjective().
template <typename Impl, typename Name, bool is_injective = false>
class FunctionStringToNullableString : public DB::IFunction
{
public:
    static constexpr auto name = Name::name;

    static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<FunctionStringToNullableString>(); }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 1; }

    bool isInjective(const DB::ColumnsWithTypeAndName &) const override { return is_injective; }

    bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    bool useDefaultImplementationForConstants() const override { return true; }

    /// Accepts only String and returns the nullable version of the argument type.
    /// FixedString is rejected here because executeImpl only handles ColumnString:
    /// accepting it would merely postpone the failure to execution time and would
    /// declare a Nullable(FixedString) result that the produced ColumnString cannot satisfy.
    ///
    /// @throws DB::Exception(ILLEGAL_TYPE_OF_ARGUMENT) if the argument is not a String.
    DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override
    {
        if (!DB::isString(arguments[0]))
            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}",
                arguments[0]->getName(), getName());
        return DB::makeNullable(arguments[0]);
    }

    /// Runs `Impl::vector` over the argument column and wraps the result with its null map.
    ///
    /// @throws DB::Exception(ILLEGAL_COLUMN) if the argument column is not a ColumnString.
    DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t /*input_rows_count*/) const override
    {
        const DB::ColumnPtr & column = arguments[0].column;
        const DB::ColumnString * col = DB::checkAndGetColumn<DB::ColumnString>(column.get());
        if (!col)
            throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}",
                arguments[0].column->getName(), getName());

        auto col_res = DB::ColumnString::create();
        auto null_map = DB::DataTypeUInt8().createColumn();
        Impl::vector(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), *null_map);
        return DB::ColumnNullable::create(std::move(col_res), std::move(null_map));
    }
};
/// Binary function wrapper: (String haystack, constant String needle) -> Nullable(String).
///
/// `Impl::vector` receives the haystack column bytes/offsets and the needle value, and
/// fills the result string column together with the null map.
///
/// @tparam Impl  Provides `static void vector(chars, offsets, needle, res_chars, res_offsets, null_map)`.
/// @tparam Name  Provides `static constexpr name`, returned by getName().
template <typename Impl, typename Name>
class FunctionsStringSearchToNullableString : public DB::IFunction
{
public:
    static constexpr auto name = Name::name;

    static DB::FunctionPtr create(DB::ContextPtr)
    {
        return std::make_shared<FunctionsStringSearchToNullableString>();
    }

    String getName() const override
    {
        return name;
    }

    size_t getNumberOfArguments() const override
    {
        return 2;
    }

    bool useDefaultImplementationForConstants() const override
    {
        return true;
    }

    /// The needle (argument index 1) is declared as always constant.
    DB::ColumnNumbers getArgumentsThatAreAlwaysConstant() const override
    {
        return {1};
    }

    bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override
    {
        return true;
    }

    /// Both arguments must be String; the result is always Nullable(String).
    ///
    /// @throws DB::Exception(ILLEGAL_TYPE_OF_ARGUMENT) for the first non-String argument.
    DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override
    {
        for (size_t arg_index = 0; arg_index < 2; ++arg_index)
        {
            const auto & arg_type = arguments[arg_index];
            if (!isString(arg_type))
                throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}",
                    arg_type->getName(), getName());
        }
        return DB::makeNullable(std::make_shared<DB::DataTypeString>());
    }

    /// Validates the column classes (needle first, then haystack) and delegates to `Impl::vector`.
    ///
    /// @throws DB::Exception(ILLEGAL_COLUMN) if the needle is not a constant column or
    ///         the haystack is not a ColumnString.
    DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t /*input_rows_count*/) const override
    {
        const DB::ColumnPtr haystack_holder = arguments[0].column;
        const DB::ColumnPtr needle_holder = arguments[1].column;

        const auto * needle_const = typeid_cast<const DB::ColumnConst *>(needle_holder.get());
        if (needle_const == nullptr)
            throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be constant string", getName());

        const auto * haystack = DB::checkAndGetColumn<DB::ColumnString>(haystack_holder.get());
        if (haystack == nullptr)
            throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}",
                arguments[0].column->getName(), getName());

        auto result_strings = DB::ColumnString::create();
        auto result_nulls = DB::DataTypeUInt8().createColumn();
        Impl::vector(
            haystack->getChars(),
            haystack->getOffsets(),
            needle_const->getValue<String>(),
            result_strings->getChars(),
            result_strings->getOffsets(),
            *result_nulls);
        return DB::ColumnNullable::create(std::move(result_strings), std::move(result_nulls));
    }
};
/// Name tag for the whole-query-string extractor.
/// Unlike ClickHouse's extractURLParameters, which returns an array, this function yields a single string.
struct NameSparkExtractURLQuery
{
    static constexpr const char * name = "spark_parse_url_query";
};
struct SparkExtractURLQuery
{
static size_t getReserveLengthForElement() { return 15; }
static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
{
res_data = data;
res_size = 0;
DB::Pos pos = data;
DB::Pos end = data + size;
const static String protocol_delim = "://";
DB::Pos protocol_delim_pos = static_cast<DB::Pos>(memmem(pos, end - pos, protocol_delim.data(), protocol_delim.size()));
DB::Pos query_string_begin = nullptr;
if (protocol_delim_pos)
{
query_string_begin = find_first_symbols<'?', '#'>(pos, end);
}
else
{
query_string_begin = find_first_symbols<'?', '#', ':'>(pos, end);
}
if (query_string_begin && query_string_begin < end)
{
if (*query_string_begin != '?')
{
res_data = nullptr;
res_size = 0;
return;
}
res_data = query_string_begin + 1;
DB::Pos query_string_end = find_first_symbols<'#'>(res_data, end);
if (query_string_end && query_string_end < end)
{
res_size = query_string_end - res_data;
}
else
{
res_size = end - res_data;
}
}
else
{
res_data = nullptr;
res_size = 0;
}
}
};
/// Function type for the query-string part: SparkExtractURLQuery run per row through the nullable wrapper.
using SparkFunctionURLQuery = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLQuery>, NameSparkExtractURLQuery>;
/// Adds SparkFunctionURLQuery to the function factory.
REGISTER_FUNCTION(SparkFunctionURLQuery)
{
factory.registerFunction<SparkFunctionURLQuery>();
}
/// Name tag for the function that extracts the value of one query parameter.
struct NameSparkExtractURLOneQuery
{
    static constexpr const char * name = "spark_parse_url_one_query";
};
/// Extracts the value of a single query parameter (`pattern`) from every URL of a string column.
/// Rows without a '?' query, or without the requested parameter, become NULL.
struct SparkExtractURLOneQuery
{
    /// @param data         Bytes of the input string column.
    /// @param offsets      Per-row end offsets into `data`.
    /// @param pattern      Parameter name; "=" is appended to form the searched token.
    /// @param res_data     Output bytes; each result row is followed by a zero byte.
    /// @param res_offsets  Output per-row end offsets into `res_data`.
    /// @param null_map     Receives 1 for rows where no value was found, 0 otherwise.
    static void vector(const DB::ColumnString::Chars & data,
        const DB::ColumnString::Offsets & offsets,
        std::string pattern,
        DB::ColumnString::Chars & res_data, DB::ColumnString::Offsets & res_offsets, DB::IColumn & null_map)
    {
        const static String protocol_delim = "://";

        const size_t rows = offsets.size();
        res_data.reserve(data.size() / 5);
        res_offsets.resize(rows);
        null_map.reserve(rows);

        pattern += '=';
        const char * param_str = pattern.c_str();
        const size_t param_len = pattern.size();

        DB::ColumnString::Offset prev_offset = 0;
        DB::ColumnString::Offset res_offset = 0;

        for (size_t i = 0; i < rows; ++i)
        {
            const DB::ColumnString::Offset cur_offset = offsets[i];
            const char * str = reinterpret_cast<const char *>(&data[prev_offset]);
            /// One past the last byte stored for this row.
            const char * end = reinterpret_cast<const char *>(&data[cur_offset]);

            /// Locate the start of the query string. The scan always begins at the start of the
            /// row (same rule as SparkExtractURLQuery), so a "://" that merely occurs inside the
            /// query, e.g. "a?x=http://b", does not hide the '?'.
            /// Without "://" a ':' that precedes '?' disqualifies the input.
            const bool has_protocol = memmem(str, end - str, protocol_delim.data(), protocol_delim.size()) != nullptr;
            DB::Pos query_string_begin = has_protocol
                ? find_first_symbols<'?', '#'>(str, end)
                : find_first_symbols<'?', '#', ':'>(str, end);

            /// Only a '?' opens a query. find_first_symbols returns `end` on a miss, so that
            /// case is checked first: `end` is outside this row and must not be dereferenced.
            if (query_string_begin == end || *query_string_begin != '?')
                query_string_begin = end;

            /// Will point to the beginning of "name=value" pair. Then it will be reassigned to the beginning of "value".
            /// NOTE(review): the search range runs to the end of the row, so a "name=" placed right
            /// after a later '#' is matched as well - confirm this is intended.
            const char * param_begin = nullptr;

            if (query_string_begin + 1 < end)
            {
                param_begin = query_string_begin + 1;

                while (true)
                {
                    param_begin = static_cast<const char *>(memmem(param_begin, end - param_begin, param_str, param_len));
                    if (!param_begin)
                        break;

                    const char before = param_begin[-1];
                    param_begin += param_len;

                    /// A real parameter starts right after '?', '#' or '&'; anything else means a
                    /// different parameter name that merely ends with the same suffix.
                    if (before == '?' || before == '#' || before == '&')
                        break;
                }
            }

            if (param_begin)
            {
                const char * param_end = find_first_symbols<'&', '#'>(param_begin, end);
                /// No delimiter found: the value runs up to the first zero byte.
                if (param_end == end)
                    param_end = param_begin + strlen(param_begin);

                const size_t param_size = param_end - param_begin;
                res_data.resize(res_offset + param_size + 1);
                memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], param_begin, param_size);
                res_offset += param_size;
                null_map.insert(0);
            }
            else
            {
                /// No parameter found: store an empty string and mark the row NULL.
                res_data.resize(res_offset + 1);
                null_map.insert(1);
            }

            res_data[res_offset] = 0;
            ++res_offset;
            res_offsets[i] = res_offset;
            prev_offset = cur_offset;
        }
    }
};
/// Function type for single-parameter lookup: SparkExtractURLOneQuery behind the (haystack, constant needle) wrapper.
using SparkFunctionURLOneQuery = FunctionsStringSearchToNullableString<SparkExtractURLOneQuery, NameSparkExtractURLOneQuery>;
/// Adds SparkFunctionURLOneQuery to the function factory.
REGISTER_FUNCTION(SparkFunctionURLOneQuery)
{
factory.registerFunction<SparkFunctionURLOneQuery>();
}
struct SparkExtractURLHost
{
static size_t getReserveLengthForElement() { return 15; }
static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
{
DB::Pos end = data + size;
const static String protocol_delim = "://";
DB::Pos protocol_delim_start = static_cast<DB::Pos>(memmem(data, size, protocol_delim.data(), protocol_delim.size()));
if (!protocol_delim_start)
{
res_data = nullptr;
res_size = 0;
return;
}
DB::Pos userinfo_delim_pos = find_first_symbols<'@'>(protocol_delim_start + protocol_delim.size(), end);
std::string_view host;
if (userinfo_delim_pos && userinfo_delim_pos < end)
{
host = DB::getURLHost(userinfo_delim_pos + 1, end - userinfo_delim_pos);
}
else
{
host = DB::getURLHost(protocol_delim_start + protocol_delim.size() , end - protocol_delim_start - protocol_delim.size());
}
if (host.empty())
{
res_data = data;
res_size = 0;
}
else
{
res_data = host.data();
res_size = host.size();
}
}
};
/// Name tag for the host extractor.
struct NameSparkExtractURLHost
{
    static constexpr const char * name = "spark_parse_url_host";
};
/// Function type for the host part: SparkExtractURLHost run per row through the nullable wrapper.
using SparkFunctionURLHost = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLHost>, NameSparkExtractURLHost>;
/// Adds SparkFunctionURLHost to the function factory.
REGISTER_FUNCTION(SparkFunctionURLHost)
{
factory.registerFunction<SparkFunctionURLHost>();
}
/// Name tag for the path extractor.
struct NameSparkExtractURLPath
{
    static constexpr const char * name = "spark_parse_url_path";
};
struct SparkExtractURLPath
{
static size_t getReserveLengthForElement() { return 25; }
static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
{
res_data = data;
res_size = 0;
DB::Pos pos = data;
DB::Pos end = data + size;
const static String protocol_delim = "://";
const auto * start_pos = static_cast<DB::Pos>(memmem(pos, end - pos, protocol_delim.data(), protocol_delim.size()));
if (start_pos)
{
start_pos += protocol_delim.size();
const auto * path_start_pos = find_first_symbols<'/', '#', '?'>(start_pos, end);
if (path_start_pos && path_start_pos < end)
{
if (*path_start_pos != '/')
return;
res_data = path_start_pos;
const auto * path_end_pos = find_first_symbols<'?', '#'>(path_start_pos, end);
if (path_end_pos && path_end_pos < end)
{
res_size = path_end_pos - path_start_pos;
}
else
{
res_size = end - path_start_pos;
}
}
}
}
};
/// Function type for the path part. Unlike the other parts in this file it is built on
/// DB::FunctionStringToString / DB::ExtractSubstringImpl rather than the nullable wrappers defined here.
using SparkFunctionURLPath = DB::FunctionStringToString<DB::ExtractSubstringImpl<SparkExtractURLPath>, NameSparkExtractURLPath>;
/// Adds SparkFunctionURLPath to the function factory.
REGISTER_FUNCTION(SparkFunctionURLPath)
{
factory.registerFunction<SparkFunctionURLPath>();
}
/// Name tag for the user-info extractor.
struct NameSparkExtractUserInfo
{
    static constexpr const char * name = "spark_parse_url_userinfo";
};
struct SparkExtractURLUserInfo
{
static size_t getReserveLengthForElement() { return 25; }
static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
{
res_data = data;
res_size = 0;
DB::Pos pos = data;
DB::Pos end = data + size;
const static String protocol_delim = "://";
const static String userinfo_delim = "@";
DB::Pos protocol_delim_start = static_cast<DB::Pos>(memmem(pos, end - pos, protocol_delim.data(), protocol_delim.size()));
if (!protocol_delim_start)
{
res_data = nullptr;
res_size = 0;
return;
}
res_data = protocol_delim_start + protocol_delim.size();
DB::Pos userinfo_delim_start = find_first_symbols<'@'>(res_data, end);
if (!userinfo_delim_start || userinfo_delim_start >= end)
{
res_data = nullptr;
res_size = 0;
return;
}
res_size = userinfo_delim_start - res_data;
}
};
/// Function type for the user-info part: SparkExtractURLUserInfo run per row through the nullable wrapper.
using SparkFunctionURLUserInfo = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLUserInfo>, NameSparkExtractUserInfo>;
/// Adds SparkFunctionURLUserInfo to the function factory.
REGISTER_FUNCTION(SparkFunctionURLUserInfo)
{
factory.registerFunction<SparkFunctionURLUserInfo>();
}
/// Name tag for the fragment ("ref") extractor.
struct NameSparkExtractURLRef
{
    static constexpr const char * name = "spark_parse_url_ref";
};
struct SparkExtractURLRef
{
static size_t getReserveLengthForElement() { return 25; }
static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
{
res_data = data;
res_size = 0;
DB::Pos pos = data;
DB::Pos end = data + size;
const static String ref_delim = "#";
const auto * ref_delim_pos = find_first_symbols<'#'>(pos, end);
if (ref_delim_pos && ref_delim_pos < end)
{
res_data = ref_delim_pos + 1;
res_size = end - res_data;
}
else
{
res_data = nullptr;
res_size = 0;
}
}
};
/// Function type for the fragment part: SparkExtractURLRef run per row through the nullable wrapper.
using SparkFunctionURLRef = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLRef>, NameSparkExtractURLRef>;
/// Adds SparkFunctionURLRef to the function factory.
REGISTER_FUNCTION(SparkFunctionURLRef)
{
factory.registerFunction<SparkFunctionURLRef>();
}
/// Name tag for the file (path plus query) extractor.
struct NameSparkExtractURLFile
{
    static constexpr const char * name = "spark_parse_url_file";
};
struct SparkExtractURLFile
{
static size_t getReserveLengthForElement() { return 25; }
static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
{
res_data = data;
res_size = 0;
DB::Pos pos = data;
DB::Pos end = data + size;
const static String protocol_delim = "://";
const static String slash_delim = "/";
const static String query_delim = "?";
const auto * protocol_delim_pos = static_cast<DB::Pos>(memmem(pos, end - pos, protocol_delim.data(), protocol_delim.size()));
if (!protocol_delim_pos)
{
auto colon_pos = find_first_symbols<':'>(pos, end);
if (colon_pos && colon_pos + 1 < end)
{
res_data = nullptr;
return;
}
res_size = size;
return;
}
DB::Pos file_begin_pos = find_first_symbols<'/', '?', '#'>(protocol_delim_pos + protocol_delim.size(), end);
if (file_begin_pos && file_begin_pos < end)
{
if (*file_begin_pos == '#')
{
return;
}
res_data = file_begin_pos;
DB::Pos ref_delim_pos = find_first_symbols<'#'>(file_begin_pos + 1, end);
if (ref_delim_pos && ref_delim_pos < end)
{
res_size = ref_delim_pos - res_data;
}
else
{
res_size = end - res_data;
}
}
}
};
/// Function type for the file part: SparkExtractURLFile run per row through the nullable wrapper.
using SparkFunctionURLFile = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLFile>, NameSparkExtractURLFile>;
/// Adds SparkFunctionURLFile to the function factory.
REGISTER_FUNCTION(SparkFunctionURLFile)
{
factory.registerFunction<SparkFunctionURLFile>();
}
/// Name tag for the authority extractor.
struct NameSparkExtractURLAuthority
{
    static constexpr const char * name = "spark_parse_url_authority";
};
struct SparkExtractURLAuthority
{
static size_t getReserveLengthForElement() { return 25; }
static void execute(DB::Pos data, size_t size, DB::Pos & res_data, size_t & res_size)
{
res_data = data;
res_size = 0;
DB::Pos pos = data;
DB::Pos end = data + size;
const static String protocol_delim = "://";
const auto * protocol_delim_pos = static_cast<DB::Pos>(memmem(pos, end - pos, protocol_delim.data(), protocol_delim.size()));
if (!protocol_delim_pos)
{
res_data = nullptr;
res_size = 0;
return;
}
res_data = protocol_delim_pos + protocol_delim.size();
DB::Pos end_pos = find_first_symbols<'/', '?', '#'>(res_data, end);
if (end_pos)
{
res_size = end_pos - res_data;
}
else
{
res_size = end - res_data -1 ;
}
}
};
/// Function type for the authority part: SparkExtractURLAuthority run per row through the nullable wrapper.
using SparkFunctionURLAuthority
= FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLAuthority>, NameSparkExtractURLAuthority>;
/// Adds SparkFunctionURLAuthority to the function factory.
REGISTER_FUNCTION(SparkFunctionURLAuthority)
{
factory.registerFunction<SparkFunctionURLAuthority>();
}
/// Name tag for the extractor that always yields NULL.
struct NameSparkExtractURLInvalid
{
    static constexpr const char * name = "spark_parse_url_invalid";
};
/// Extractor that ignores its input and always reports a null pointer with zero length,
/// which ExtractNullableSubstringImpl records as a NULL row.
struct SparkExtractURLInvalid
{
    /// Per-row size hint used by the caller to reserve result memory.
    static size_t getReserveLengthForElement()
    {
        return 1;
    }

    /// @param res_data  Always set to nullptr.
    /// @param res_size  Always set to 0.
    static void execute(DB::Pos /*data*/, size_t /*size*/, DB::Pos & res_data, size_t & res_size)
    {
        res_size = 0;
        res_data = nullptr;
    }
};
/// Function type for the always-NULL variant: SparkExtractURLInvalid run per row through the nullable wrapper.
using SparkFunctionURLInvalid = FunctionStringToNullableString<ExtractNullableSubstringImpl<SparkExtractURLInvalid>, NameSparkExtractURLInvalid>;
/// Adds SparkFunctionURLInvalid to the function factory.
REGISTER_FUNCTION(SparkFunctionURLInvalid)
{
factory.registerFunction<SparkFunctionURLInvalid>();
}
}