flex/engines/graph_db/runtime/execute/ops/retrieve/procedure_call.cc (360 lines of code) (raw):

/** Copyright 2020 Alibaba Group Holding Limited. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "flex/engines/graph_db/runtime/execute/ops/retrieve/procedure_call.h" #include "flex/engines/graph_db/database/graph_db.h" #include "flex/engines/graph_db/database/graph_db_session.h" #include "flex/engines/graph_db/runtime/common/columns/i_context_column.h" #include "flex/engines/graph_db/runtime/common/columns/value_columns.h" #include "flex/engines/graph_db/runtime/common/context.h" #include "flex/engines/graph_db/runtime/common/leaf_utils.h" #include "flex/engines/graph_db/runtime/common/rt_any.h" #include "flex/engines/graph_db/runtime/utils/opr_timer.h" #include "flex/proto_generated_gie/algebra.pb.h" #include "flex/proto_generated_gie/physical.pb.h" namespace gs { namespace runtime { namespace ops { std::shared_ptr<IContextColumn> any_vec_to_column( const std::vector<RTAny>& any_vec) { if (any_vec.empty()) { return nullptr; } auto first = any_vec[0].type(); if (first == RTAnyType::kBoolValue) { ValueColumnBuilder<bool> builder; for (auto& any : any_vec) { builder.push_back_opt(any.as_bool()); } return builder.finish(nullptr); } else if (first == RTAnyType::kI32Value) { ValueColumnBuilder<int32_t> builder; for (auto& any : any_vec) { builder.push_back_opt(any.as_int32()); } return builder.finish(nullptr); } else if (first == RTAnyType::kI64Value) { ValueColumnBuilder<int64_t> builder; for (auto& any : any_vec) { builder.push_back_opt(any.as_int64()); } return builder.finish(nullptr); } else if (first == RTAnyType::kU64Value) { ValueColumnBuilder<uint64_t> builder; for (auto& any : any_vec) { builder.push_back_opt(any.as_uint64()); } return builder.finish(nullptr); } else if (first == RTAnyType::kF64Value) { ValueColumnBuilder<double> builder; for (auto& any : any_vec) { builder.push_back_opt(any.as_double()); } return builder.finish(nullptr); } else if (first == RTAnyType::kStringValue) { ValueColumnBuilder<std::string_view> builder; std::shared_ptr<Arena> arena = std::make_shared<Arena>(); for (auto& any : any_vec) { auto ptr = StringImpl::make_string_impl(std::string(any.as_string())); auto sv = ptr->str_view(); arena->emplace_back(std::move(ptr)); builder.push_back_opt(sv); } return builder.finish(arena); } else if (first == RTAnyType::kTimestamp) { ValueColumnBuilder<Date> builder; for (auto& any : any_vec) { builder.push_back_opt(any.as_timestamp()); } return builder.finish(nullptr); } else { LOG(FATAL) << "Unsupported RTAny type: " << static_cast<int>(first); } } RTAny object_to_rt_any(const common::Value& val) { if (val.item_case() == common::Value::kBoolean) { return RTAny::from_bool(val.boolean()); } else if (val.item_case() == common::Value::kI32) { return RTAny::from_int32(val.i32()); } else if (val.item_case() == common::Value::kI64) { return RTAny::from_int64(val.i64()); } else if (val.item_case() == common::Value::kF64) { return RTAny::from_double(val.f64()); } else if (val.item_case() == common::Value::kStr) { return RTAny::from_string(val.str()); } else { LOG(FATAL) << "Unsupported value type: " << val.item_case(); } } Any property_to_any(const results::Property& prop) { // We just need the value; const auto& val = prop.value(); Any res; if (val.item_case() == common::Value::kBoolean) { res.set_bool(val.boolean()); } else if (val.item_case() == common::Value::kI32) { res.set_i32(val.i32()); } else if (val.item_case() == common::Value::kI64) { res.set_i64(val.i64()); } else if (val.item_case() == common::Value::kF64) { res.set_double(val.f64()); } else if (val.item_case() == common::Value::kStr) { res.set_string_view(std::string_view(val.str())); } else { LOG(FATAL) << "Unsupported value type: " << val.item_case(); } return res; } RTAny vertex_to_rt_any(const results::Vertex& vertex) { auto label_id = vertex.label().id(); auto label_id_vid = decode_unique_vertex_id(vertex.id()); CHECK(label_id == label_id_vid.first) << "Inconsistent label id."; return RTAny::from_vertex(label_id, label_id_vid.second); } RTAny edge_to_rt_any(const results::Edge& edge) { LOG(FATAL) << "Not implemented."; label_t src_label_id = (label_t) edge.src_label().id(); label_t dst_label_id = (label_t) edge.dst_label().id(); auto edge_triplet_tuple = decode_edge_label_id(edge.label().id()); CHECK((src_label_id == std::get<0>(edge_triplet_tuple)) && (dst_label_id == std::get<1>(edge_triplet_tuple))) << "Inconsistent src label id."; auto src_vertex_id = edge.src_id(); auto dst_vertex_id = edge.dst_id(); auto [_, src_vid] = decode_unique_vertex_id(src_vertex_id); auto [__, dst_vid] = decode_unique_vertex_id(dst_vertex_id); // properties auto properties = edge.properties(); LabelTriplet label_triplet{src_label_id, dst_label_id, std::get<2>(edge_triplet_tuple)}; if (properties.size() == 0) { EdgeRecord edge_record{label_triplet, src_vid, dst_vid, Any(), Direction::kOut}; return RTAny::from_edge(edge_record); } else if (properties.size() == 1) { LOG(FATAL) << "Not implemented."; EdgeRecord edge_record{label_triplet, src_vid, dst_vid, property_to_any(properties[0]), Direction::kOut}; return RTAny::from_edge(edge_record); } else { LOG(FATAL) << "Not implemented."; std::vector<Any> props; for (auto& prop : properties) { props.push_back(property_to_any(prop)); } Any any; any.set_record(props); return RTAny::from_edge( EdgeRecord{label_triplet, src_vid, dst_vid, any, Direction::kOut}); } } // namespace runtime RTAny graph_path_to_rt_any(const results::GraphPath& path) { LOG(FATAL) << "Not implemented."; } RTAny element_to_rt_any(const results::Element& element) { if (element.inner_case() == results::Element::kVertex) { return vertex_to_rt_any(element.vertex()); } else if (element.inner_case() == results::Element::kEdge) { return edge_to_rt_any(element.edge()); } else if (element.inner_case() == results::Element::kObject) { return object_to_rt_any(element.object()); } else if (element.inner_case() == results::Element::kGraphPath) { return graph_path_to_rt_any(element.graph_path()); } else { LOG(FATAL) << "Unsupported element type: " << element.inner_case(); } } RTAny collection_to_rt_any(const results::Collection& collection) { std::vector<RTAny> values; for (const auto& element : collection.collection()) { values.push_back(element_to_rt_any(element)); } LOG(FATAL) << "Not implemented."; return RTAny(); } RTAny column_to_rt_any(const results::Column& column) { auto& entry = column.entry(); if (entry.has_element()) { return element_to_rt_any(entry.element()); } else if (entry.has_collection()) { return collection_to_rt_any(entry.collection()); } else { LOG(FATAL) << "Unsupported column entry type: " << entry.inner_case(); } } std::vector<RTAny> result_to_rt_any(const results::Results& result) { auto& record = result.record(); if (record.columns_size() == 0) { LOG(WARNING) << "Empty result."; return {}; } else { std::vector<RTAny> tuple; for (int32_t i = 0; i < record.columns_size(); ++i) { tuple.push_back(column_to_rt_any(record.columns(i))); } return tuple; } } std::pair<std::vector<std::shared_ptr<IContextColumn>>, std::vector<size_t>> collective_result_vec_to_column( int32_t expect_col_num, const std::vector<results::CollectiveResults>& collective_results_vec) { std::vector<size_t> offsets; offsets.push_back(0); size_t record_cnt = 0; for (size_t i = 0; i < collective_results_vec.size(); ++i) { record_cnt += collective_results_vec[i].results_size(); offsets.push_back(record_cnt); } std::vector<std::vector<RTAny>> any_vec(expect_col_num); for (size_t i = 0; i < collective_results_vec.size(); ++i) { for (int32_t j = 0; j < collective_results_vec[i].results_size(); ++j) { auto tuple = result_to_rt_any(collective_results_vec[i].results(j)); CHECK(tuple.size() == (size_t) expect_col_num) << "Inconsistent column number."; for (int32_t k = 0; k < expect_col_num; ++k) { any_vec[k].push_back(tuple[k]); } } } std::vector<std::shared_ptr<IContextColumn>> columns; for (int32_t i = 0; i < expect_col_num; ++i) { columns.push_back(any_vec_to_column(any_vec[i])); } return std::make_pair(columns, offsets); } bl::result<procedure::Query> fill_in_query(const procedure::Query& query, const Context& ctx, size_t idx) { procedure::Query real_query; real_query.mutable_query_name()->CopyFrom(query.query_name()); for (auto& param : query.arguments()) { auto argument = real_query.add_arguments(); if (param.value_case() == procedure::Argument::kVar) { auto& var = param.var(); auto tag = var.tag().id(); auto col = ctx.get(tag); if (col == nullptr) { LOG(ERROR) << "Tag not found: " << tag; continue; } auto val = col->get_elem(idx); auto const_value = argument->mutable_const_(); if (val.type() == gs::runtime::RTAnyType::kVertex) { RETURN_BAD_REQUEST_ERROR("The input param should not be a vertex"); } else if (val.type() == gs::runtime::RTAnyType::kEdge) { RETURN_BAD_REQUEST_ERROR("The input param should not be an edge"); } else if (val.type() == gs::runtime::RTAnyType::kI64Value) { const_value->set_i64(val.as_int64()); } else if (val.type() == gs::runtime::RTAnyType::kI32Value) { const_value->set_i32(val.as_int32()); } else if (val.type() == gs::runtime::RTAnyType::kStringValue) { const_value->set_str(std::string(val.as_string())); } else if (val.type() == gs::runtime::RTAnyType::kF64Value) { const_value->set_f64(val.as_double()); } else if (val.type() == gs::runtime::RTAnyType::kBoolValue) { const_value->set_boolean(val.as_bool()); } else if (val.type() == gs::runtime::RTAnyType::kDate32) { const_value->set_i64(val.as_timestamp().milli_second); } else { LOG(ERROR) << "Unsupported type: " << static_cast<int32_t>(val.type()); } } else { argument->CopyFrom(param); } } return real_query; } /** * @brief Evaluate the ProcedureCall operator. * The ProcedureCall operator is used to call a stored procedure, which is * already registered in the system. The return value of the stored procedure * is a result::CollectiveResults object, we need to convert it to a Column, * and append to the current context. * * * @param opr The ProcedureCall operator. * @param txn The read transaction. * @param ctx The input context. * * @return bl::result<Context> The output context. * * */ bl::result<Context> eval_procedure_call(const std::vector<int32_t>& aliases, const physical::ProcedureCall& opr, const GraphReadInterface& txn, Context&& ctx) { auto& query = opr.query(); auto& proc_name = query.query_name(); if (proc_name.item_case() == common::NameOrId::kName) { const auto& sess = txn.GetSession(); // cast off const, to get the app pointer. // Why do we need to cast off const? Because current GetApp method is not // const. // TODO(zhanglei): Refactor the GetApp method to be const(maybe create the // app once initialize, not on need). GraphDBSession& sess_cast = const_cast<GraphDBSession&>(sess); AppBase* app = const_cast<AppBase*>(sess_cast.GetApp(proc_name.name())); if (!app) { RETURN_BAD_REQUEST_ERROR("Stored procedure not found: " + proc_name.name()); } ReadAppBase* read_app = dynamic_cast<ReadAppBase*>(app); if (!app) { RETURN_BAD_REQUEST_ERROR("Stored procedure is not a read procedure: " + proc_name.name()); } std::vector<results::CollectiveResults> results; // Iterate over current context. for (size_t i = 0; i < ctx.row_num(); ++i) { // Call the procedure. // Use real values from the context to replace the placeholders in the // query. BOOST_LEAF_AUTO(real_query, fill_in_query(query, ctx, i)); // We need to serialize the protobuf-based arguments to the input format // that a cypher procedure can accept. auto query_str = real_query.SerializeAsString(); // append CYPHER_PROTO as the last byte as input_format query_str.push_back(static_cast<char>( GraphDBSession::InputFormat::kCypherProtoProcedure)); std::vector<char> buffer; Encoder encoder(buffer); Decoder decoder(query_str.data(), query_str.size()); if (!read_app->Query(sess, decoder, encoder)) { RETURN_CALL_PROCEDURE_ERROR("Failed to call procedure: "); } // Decode the result from the encoder. Decoder result_decoder(buffer.data(), buffer.size()); if (result_decoder.size() < 4) { LOG(ERROR) << "Unexpected result size: " << result_decoder.size(); RETURN_CALL_PROCEDURE_ERROR("Unexpected result size"); } std::string collective_results_str(result_decoder.get_string()); results::CollectiveResults collective_results; if (!collective_results.ParseFromString(collective_results_str)) { LOG(ERROR) << "Failed to parse CollectiveResults"; RETURN_CALL_PROCEDURE_ERROR("Failed to parse procedure's result"); } results.push_back(collective_results); } auto column_and_offsets = collective_result_vec_to_column(aliases.size(), results); auto& columns = column_and_offsets.first; auto& offsets = column_and_offsets.second; if (columns.size() != aliases.size()) { LOG(ERROR) << "Column size mismatch: " << columns.size() << " vs " << aliases.size(); RETURN_CALL_PROCEDURE_ERROR("Column size mismatch"); } if (columns.size() >= 1) { ctx.set_with_reshuffle(aliases[0], columns[0], offsets); } for (size_t i = 1; i < columns.size(); ++i) { ctx.set(aliases[i], columns[i]); } return std::move(ctx); } else { LOG(ERROR) << "Currently only support calling stored procedure by name"; RETURN_UNSUPPORTED_ERROR( "Currently only support calling stored procedure by name"); } } class ProcedureCallOpr : public IReadOperator { public: ProcedureCallOpr(const std::vector<int32_t>& aliases, const physical::ProcedureCall& opr) : aliases_(aliases), opr_(opr) {} std::string get_operator_name() const override { return "ProcedureCallOpr"; } bl::result<Context> Eval(const GraphReadInterface& txn, const std::map<std::string, std::string>&, Context&& ctx, OprTimer&) override { auto ret = eval_procedure_call(aliases_, opr_, txn, std::move(ctx)); return ret; } private: std::vector<int32_t> aliases_; physical::ProcedureCall opr_; }; bl::result<ReadOpBuildResultT> ProcedureCallOprBuilder::Build( const gs::Schema& schema, const ContextMeta& ctx_meta, const physical::PhysicalPlan& plan, int op_idx) { auto& opr = plan.plan(op_idx); std::vector<int32_t> aliases; ContextMeta ret_meta; for (int32_t i = 0; i < opr.meta_data_size(); ++i) { aliases.push_back(opr.meta_data(i).alias()); ret_meta.set(opr.meta_data(i).alias()); } return std::make_pair( std::make_unique<ProcedureCallOpr>(aliases, opr.opr().procedure_call()), ret_meta); } } // namespace ops } // namespace runtime } // namespace gs