backend/query/ann_functions_rewriter.cc (79 lines of code) (raw):
//
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "backend/query/ann_functions_rewriter.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "zetasql/public/function_signature.h"
#include "zetasql/public/json_value.h"
#include "zetasql/public/value.h"
#include "zetasql/resolved_ast/resolved_ast.h"
#include "absl/status/status.h"
#include "common/errors.h"
#include "zetasql/base/ret_check.h"
#include "zetasql/base/status_macros.h"
namespace google {
namespace spanner {
namespace emulator {
namespace backend {
// Returns true when `function_name` is exactly one of the approximate
// distance function names listed below. The comparison is a plain string
// equality, so it is case-sensitive.
bool IsANNFunction(std::string function_name) {
  static constexpr const char* kApproxDistanceFunctionNames[] = {
      "approx_cosine_distance",
      "approx_dot_product",
      "approx_euclidean_distance",
  };
  for (const char* candidate : kApproxDistanceFunctionNames) {
    if (function_name == candidate) {
      return true;
    }
  }
  return false;
}
// Rewrites calls to the functions recognised by IsANNFunction(); every other
// function call is copied unchanged via CopyVisitResolvedFunctionCall().
//
// For a recognised call this visitor:
//   1. Returns ApproxDistanceFunctionOptionsRequired when fewer than three
//      arguments are present, and RET_CHECKs that there are exactly three.
//   2. Returns ApproxDistanceFunctionOptionMustBeLiteral when the last
//      argument is not a ResolvedLiteral.
//   3. When that literal has type JSON, returns
//      ApproxDistanceFunctionInvalidJsonOption unless it has a
//      "num_leaves_to_search" member for which IsUInt64() holds.
//   4. Builds a replacement call that keeps every argument (and signature
//      entry) except the last one, records the new node in `ann_functions_`,
//      and pushes it onto the visitor stack.
//
// NOTE(review): a literal last argument whose type is not JSON passes through
// step 3 without any validation — confirm this is intended.
absl::Status ANNFunctionsRewriter::VisitResolvedFunctionCall(
    const zetasql::ResolvedFunctionCall* node) {
  const auto& function_name = node->function()->Name();
  if (!IsANNFunction(function_name)) {
    return CopyVisitResolvedFunctionCall(node);
  }

  if (node->argument_list_size() < 3) {
    return error::ApproxDistanceFunctionOptionsRequired(function_name);
  }
  ZETASQL_RET_CHECK(node->argument_list_size() == 3);

  // The trailing argument carries the options; it must be a literal.
  const zetasql::ResolvedExpr* argument_for_placeholder =
      node->argument_list().back().get();
  if (!argument_for_placeholder->Is<zetasql::ResolvedLiteral>()) {
    return error::ApproxDistanceFunctionOptionMustBeLiteral(function_name);
  }
  const zetasql::Value& placeholder_value =
      argument_for_placeholder->GetAs<zetasql::ResolvedLiteral>()->value();
  ZETASQL_RET_CHECK(placeholder_value.has_content());

  if (placeholder_value.type_kind() == zetasql::TYPE_JSON) {
    zetasql::JSONValueConstRef json_value = placeholder_value.json_value();
    if (!json_value.HasMember("num_leaves_to_search")) {
      return error::ApproxDistanceFunctionInvalidJsonOption(function_name);
    }
    zetasql::JSONValueConstRef leaves_json =
        json_value.GetMember("num_leaves_to_search");
    if (!leaves_json.IsUInt64()) {
      return error::ApproxDistanceFunctionInvalidJsonOption(function_name);
    }
  }

  // Copy every argument except the last one. The bound is computed as a
  // signed value so that an empty signature argument list yields -1 (no
  // iterations) instead of wrapping around to a huge unsigned bound.
  const int kept_argument_count =
      static_cast<int>(node->signature().arguments().size()) - 1;
  std::vector<std::unique_ptr<zetasql::ResolvedExpr>> argument_list;
  zetasql::FunctionArgumentTypeList argument_types;
  for (int i = 0; i < kept_argument_count; ++i) {
    ZETASQL_ASSIGN_OR_RETURN(argument_list.emplace_back(),
                             Copy(node->argument_list(i)));
    argument_types.push_back(node->signature().argument(i));
  }

  zetasql::FunctionSignature new_signature(
      node->signature().result_type(), argument_types,
      node->signature().context_id(), node->signature().options());
  std::unique_ptr<zetasql::ResolvedFunctionCall> new_node =
      zetasql::MakeResolvedFunctionCall(
          node->type(), node->function(), new_signature,
          std::move(argument_list), node->error_mode());

  // Record the raw pointer before ownership is handed to the visitor stack.
  ann_functions_.insert(new_node.get());
  PushNodeToStack(std::move(new_node));
  return absl::OkStatus();
}
} // namespace backend
} // namespace emulator
} // namespace spanner
} // namespace google