backend/query/ann_validator.cc (357 lines of code) (raw):

// // Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include "backend/query/ann_validator.h" #include <string> #include <vector> #include "zetasql/public/function.h" #include "zetasql/public/value.h" #include "zetasql/resolved_ast/resolved_ast.h" #include "zetasql/resolved_ast/resolved_column.h" #include "zetasql/resolved_ast/resolved_node.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "backend/query/ann_functions_rewriter.h" #include "backend/schema/catalog/column.h" #include "backend/schema/catalog/index.h" #include "backend/schema/ddl/operations.pb.h" #include "common/errors.h" #include "zetasql/base/ret_check.h" #include "zetasql/base/status_macros.h" namespace google { namespace spanner { namespace emulator { namespace backend { ddl::VectorIndexOptionsProto::DistanceType ANNFunctionsToDistanceType( const std::string& function_name) { if (function_name == "approx_cosine_distance") { return ddl::VectorIndexOptionsProto::COSINE; } else if (function_name == "approx_dot_product") { return ddl::VectorIndexOptionsProto::DOT_PRODUCT; } else if (function_name == "approx_euclidean_distance") { return ddl::VectorIndexOptionsProto::EUCLIDEAN; } ABSL_LOG(FATAL) << "Invalid ANN function: " << function_name; // Crash OK } absl::Status GetAnnFunctionCall( const zetasql::ResolvedLimitOffsetScan* node, const zetasql::ResolvedFunctionCall*& ann_func, std::vector<const zetasql::ResolvedNode*>& child_nodes) { if (!node->input_scan()->Is<zetasql::ResolvedOrderByScan>()) { return absl::InvalidArgumentError("No order by clause."); } const zetasql::ResolvedOrderByScan* orderby_scan = node->input_scan()->GetAs<zetasql::ResolvedOrderByScan>(); if (orderby_scan->order_by_item_list().size() != 1 || !orderby_scan->order_by_item_list()[0] ->Is<zetasql::ResolvedOrderByItem>()) { return absl::InvalidArgumentError("Invalid order by clause."); } zetasql::ResolvedColumn order_by_column = orderby_scan->order_by_item_list()[0] ->GetAs<zetasql::ResolvedOrderByItem>() ->column_ref() ->column(); if (!orderby_scan->input_scan()->Is<zetasql::ResolvedProjectScan>()) { return absl::InvalidArgumentError("Input scan is not a project scan."); } const zetasql::ResolvedProjectScan* project_scan = node->input_scan() ->GetAs<zetasql::ResolvedOrderByScan>() ->input_scan() ->GetAs<zetasql::ResolvedProjectScan>(); project_scan->GetChildNodes(&child_nodes); for (auto child : child_nodes) { if (child->Is<zetasql::ResolvedComputedColumn>()) { const zetasql::ResolvedComputedColumn* cc = child->GetAs<zetasql::ResolvedComputedColumn>(); if (order_by_column != cc->column()) { return absl::InvalidArgumentError("Invalid order by clause."); } std::vector<const zetasql::ResolvedNode*> computed_children; cc->GetChildNodes(&computed_children); for (auto computed_child : computed_children) { if (computed_child->Is<zetasql::ResolvedFunctionCall>()) { const zetasql::ResolvedFunctionCall* func = computed_child->GetAs<zetasql::ResolvedFunctionCall>(); ZETASQL_RET_CHECK(func->function() != nullptr); if (IsANNFunction(func->function()->Name())) { ZETASQL_RET_CHECK(ann_func == nullptr); ann_func = func; } } } } } return absl::OkStatus(); } absl::Status GetNotNullColumns( std::vector<const zetasql::ResolvedNode*>& child_nodes, const zetasql::ResolvedScan*& scan, std::vector<zetasql::ResolvedColumn>& not_null_columns) { for (auto child : child_nodes) { if (child->Is<zetasql::ResolvedScan>()) { scan = child->GetAs<zetasql::ResolvedScan>(); if (scan->Is<zetasql::ResolvedFilterScan>()) { const zetasql::ResolvedFilterScan* filter_scan = scan->GetAs<zetasql::ResolvedFilterScan>(); scan = filter_scan->input_scan(); if (filter_scan->filter_expr() != nullptr && filter_scan->filter_expr()->Is<zetasql::ResolvedFunctionCall>()) { const zetasql::ResolvedFunctionCall* func = filter_scan->filter_expr() ->GetAs<zetasql::ResolvedFunctionCall>(); if (func->function()->Name() == "$and") { for (const auto& arg : func->argument_list()) { if (arg->Is<zetasql::ResolvedFunctionCall>()) { const zetasql::ResolvedFunctionCall* not_func = arg->GetAs<zetasql::ResolvedFunctionCall>(); if (not_func->function()->Name() == "$not" && not_func->argument_list_size() == 1 && not_func->argument_list(0) ->Is<zetasql::ResolvedFunctionCall>() && not_func->argument_list(0) ->GetAs<zetasql::ResolvedFunctionCall>() ->function() ->Name() == "$is_null") { const zetasql::ResolvedFunctionCall* is_null_func = not_func->argument_list(0) ->GetAs<zetasql::ResolvedFunctionCall>(); if (is_null_func->argument_list_size() == 1 && is_null_func->argument_list(0) ->Is<zetasql::ResolvedColumnRef>()) { not_null_columns.push_back( is_null_func->argument_list(0) ->GetAs<zetasql::ResolvedColumnRef>() ->column()); } } } } } else if (func->function()->Name() == "$not") { if (func->argument_list_size() == 1 && func->argument_list(0)->Is<zetasql::ResolvedFunctionCall>() && func->argument_list(0) ->GetAs<zetasql::ResolvedFunctionCall>() ->function() ->Name() == "$is_null") { const zetasql::ResolvedFunctionCall* is_null_func = func->argument_list(0) ->GetAs<zetasql::ResolvedFunctionCall>(); if (is_null_func->argument_list_size() == 1 && is_null_func->argument_list(0) ->Is<zetasql::ResolvedColumnRef>()) { not_null_columns.push_back( is_null_func->argument_list(0) ->GetAs<zetasql::ResolvedColumnRef>() ->column()); } } } } } } } return absl::OkStatus(); } absl::Status GetANNFunctionArguments( const zetasql::ResolvedFunctionCall* last_ann_func, zetasql::ResolvedColumn& ann_func_column, zetasql::Value& ann_func_value) { if (last_ann_func->argument_list(0)->Is<zetasql::ResolvedColumnRef>()) { ann_func_column = last_ann_func->argument_list()[0] ->GetAs<zetasql::ResolvedColumnRef>() ->column(); if (last_ann_func->argument_list(1)->Is<zetasql::ResolvedLiteral>()) { ann_func_value = last_ann_func->argument_list()[1] ->GetAs<zetasql::ResolvedLiteral>() ->value(); } else if (!last_ann_func->argument_list(1) ->Is<zetasql::ResolvedParameter>()) { return error::ApproxDistanceInvalidShape( last_ann_func->function()->Name()); } } else if (last_ann_func->argument_list(1) ->Is<zetasql::ResolvedColumnRef>()) { ann_func_column = last_ann_func->argument_list()[1] ->GetAs<zetasql::ResolvedColumnRef>() ->column(); if (last_ann_func->argument_list(0)->Is<zetasql::ResolvedLiteral>()) { ann_func_value = last_ann_func->argument_list()[0] ->GetAs<zetasql::ResolvedLiteral>() ->value(); } else if (!last_ann_func->argument_list(0) ->Is<zetasql::ResolvedParameter>()) { return error::ApproxDistanceInvalidShape( last_ann_func->function()->Name()); } } else { return error::ApproxDistanceInvalidShape(last_ann_func->function()->Name()); } return absl::OkStatus(); } absl::Status ValidateFunctionArguments( const zetasql::Value& ann_func_value, const zetasql::ResolvedFunctionCall* last_ann_func) { if (ann_func_value.is_valid()) { if (ann_func_value.is_null() || !ann_func_value.type()->IsArray()) { return error::ApproxDistanceInvalidShape( last_ann_func->function()->Name()); } std::vector<zetasql::Value> elements = ann_func_value.elements(); bool is_all_zero = true; for (const auto& element : elements) { if (element.is_null() || (!element.type()->IsDouble() && !element.type()->IsFloat())) { return error::ApproxDistanceInvalidShape( last_ann_func->function()->Name()); } double value = element.ToDouble(); if (value != 0) { is_all_zero = false; } } if (is_all_zero && last_ann_func->function()->Name() == "approx_cosine_distance") { return absl::InvalidArgumentError( "Cannot compute cosine distance against zero vector."); } } return absl::OkStatus(); } absl::StatusOr<const Index*> FindVectorIndex( const std::vector<const Index*>& indexes, const zetasql::ResolvedColumn& ann_func_column, const zetasql::Value& ann_func_value, const zetasql::ResolvedFunctionCall* last_ann_func, bool is_force_index) { ddl::VectorIndexOptionsProto::DistanceType distance_type = ANNFunctionsToDistanceType(last_ann_func->function()->Name()); int i = 0; bool found_column = false; for (; i < indexes.size(); ++i) { const Index* index = indexes[i]; ZETASQL_RET_CHECK(index->key_columns().size() == 1); const Column* key_column = index->key_columns()[0]->column(); if (index->indexed_table()->Name() == ann_func_column.table_name() && key_column->Name() == ann_func_column.name()) { found_column = true; if (!key_column->has_vector_length()) { return error::ApproxDistanceInvalidShape( last_ann_func->function()->Name()); } if (ann_func_value.is_valid()) { int vector_length = *key_column->vector_length(); if (vector_length != ann_func_value.elements().size()) { return error::ApproxDistanceLengthMismatch( last_ann_func->function()->Name(), ann_func_value.elements().size(), vector_length); } } ddl::VectorIndexOptionsProto::DistanceType index_distance_type; if (!index->vector_index_options().has_distance_type() || !ddl::VectorIndexOptionsProto::DistanceType_Parse( index->vector_index_options().distance_type(), &index_distance_type) || index_distance_type == ddl::VectorIndexOptionsProto::DISTANCE_TYPE_UNSPECIFIED || index_distance_type == distance_type) { break; } } } if (i == indexes.size()) { if (is_force_index) { if (found_column) { return error::VectorIndexesUnusableForceIndexWrongDistanceType( indexes[0]->Name(), indexes[0]->vector_index_options().distance_type(), last_ann_func->function()->Name(), ann_func_column.name()); } else { return error::VectorIndexesUnusableForceIndexWrongColumn( indexes[0]->Name(), last_ann_func->function()->Name(), ann_func_column.name()); } } return error::VectorIndexesUnusable( ddl::VectorIndexOptionsProto::DistanceType_Name(distance_type), ann_func_column.name(), last_ann_func->function()->Name()); } return indexes[i]; } absl::Status ANNValidator::VisitResolvedLimitOffsetScan( const zetasql::ResolvedLimitOffsetScan* node) { std::vector<const zetasql::ResolvedNode*> child_nodes; const zetasql::ResolvedFunctionCall* last_ann_func = nullptr; if (!GetAnnFunctionCall(node, last_ann_func, child_nodes).ok()) { return zetasql::ResolvedASTVisitor::DefaultVisit(node); } ann_functions_.insert(last_ann_func); std::vector<const Index*> indexes; std::vector<zetasql::ResolvedColumn> not_null_columns; bool is_force_index = false; const zetasql::ResolvedScan* scan = nullptr; ZETASQL_RETURN_IF_ERROR(GetNotNullColumns(child_nodes, scan, not_null_columns)); if (scan->Is<zetasql::ResolvedJoinScan>() && last_ann_func != nullptr) { return error::ApproxDistanceInvalidShape(last_ann_func->function()->Name()); } if (scan->Is<zetasql::ResolvedTableScan>() && !scan->GetAs<zetasql::ResolvedTableScan>()->hint_list().empty()) { for (const auto& hint : scan->GetAs<zetasql::ResolvedTableScan>()->hint_list()) { if (absl::EqualsIgnoreCase(hint->name(), "force_index")) { ZETASQL_RET_CHECK(hint->value()->Is<zetasql::ResolvedLiteral>()); indexes = schema_->FindIndexesUnderName( hint->value() ->GetAs<zetasql::ResolvedLiteral>() ->value() .string_value()); ZETASQL_RET_CHECK(indexes.size() == 1); is_force_index = true; if (!indexes[0]->is_vector_index() && last_ann_func != nullptr) { return error::NotVectorIndexes(indexes[0]->Name()); } } } } if (last_ann_func == nullptr) { return zetasql::ResolvedASTVisitor::DefaultVisit(node); } zetasql::ResolvedColumn ann_func_column; zetasql::Value ann_func_value; ZETASQL_RETURN_IF_ERROR( GetANNFunctionArguments(last_ann_func, ann_func_column, ann_func_value)); ZETASQL_RETURN_IF_ERROR(ValidateFunctionArguments(ann_func_value, last_ann_func)); if (indexes.empty()) { indexes = schema_->vector_indexes(); } ZETASQL_ASSIGN_OR_RETURN(const Index* vec_index, FindVectorIndex(indexes, ann_func_column, ann_func_value, last_ann_func, is_force_index)); ZETASQL_RET_CHECK(vec_index->key_columns().size() == 1); const KeyColumn* key_column = vec_index->key_columns()[0]; bool is_key_null_filtered = false; for (const auto* column : vec_index->null_filtered_columns()) { if (key_column->column()->Name() == column->Name()) { is_key_null_filtered = true; break; } } if (is_key_null_filtered) { bool is_not_null_column_found = false; for (const auto& not_null_column : not_null_columns) { if (vec_index->indexed_table()->Name() == not_null_column.table_name() && key_column->column()->Name() == not_null_column.name()) { is_not_null_column_found = true; break; } } if (!is_not_null_column_found) { return error::VectorIndexesUnusableNotNullFiltered( vec_index->Name(), key_column->column()->Name()); } } return zetasql::ResolvedASTVisitor::DefaultVisit(node); } } // namespace backend } // namespace emulator } // namespace spanner } // namespace google