analytical_engine/apps/sampling_path/sampling_path.h (95 lines of code) (raw):

/** Copyright 2020 Alibaba Group Holding Limited. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ANALYTICAL_ENGINE_APPS_SAMPLING_PATH_SAMPLING_PATH_H_ #define ANALYTICAL_ENGINE_APPS_SAMPLING_PATH_SAMPLING_PATH_H_ #include <queue> #include <utility> #include <vector> #include "apps/sampling_path/sampling_path_context.h" namespace gs { /** * @brief Sampling paths obey source label-edge label-destination label pattern. * @tparam FRAG_T */ template <typename FRAG_T> class SamplingPath : public PropertyAppBase<FRAG_T, SamplingPathContext<FRAG_T>>, public grape::Communicator { public: INSTALL_DEFAULT_PROPERTY_WORKER(SamplingPath<FRAG_T>, SamplingPathContext<FRAG_T>, FRAG_T) using vid_t = typename FRAG_T::vid_t; using vertex_t = typename FRAG_T::vertex_t; using path_t = typename context_t::path_t; using layer_t = int; void BFS(const fragment_t& frag, context_t& ctx, message_manager_t& messages, std::queue<std::pair<layer_t, path_t>>& paths) { auto& path_pattern = ctx.path_pattern; auto& path_result = ctx.path_result; while (!paths.empty()) { auto& pair = paths.front(); auto level = pair.first; auto& path = pair.second; if ((size_t)(level + 2) < ctx.path_pattern.size()) { vertex_t u; auto curr_e_label = ctx.path_pattern[level + 1]; auto curr_v_label = ctx.path_pattern[level + 2]; CHECK_GT(path.size(), 0); CHECK(frag.Gid2Vertex(path[path.size() - 1], u)); auto oes = frag.GetOutgoingAdjList(u, curr_e_label); for (auto& e : oes) { auto v = e.neighbor(); if (frag.vertex_label(v) == curr_v_label) { std::vector<vid_t> new_path(path); new_path.push_back(frag.Vertex2Gid(v)); // |pattern| = k, the result should have "k / 2 + 1" vertices // e.g. pattern = "v0-e0-v1-e1-v2", path = "v0 v1 v2" if (new_path.size() == path_pattern.size() / 2 + 1) { path_result.push_back(new_path); } else { auto new_pair = std::make_pair(level + 2, new_path); if (frag.IsInnerVertex(v)) { paths.push(new_pair); } else { messages.SendToFragment(frag.GetFragId(v), new_pair); } } } } } paths.pop(); } } void PEval(const fragment_t& frag, context_t& ctx, message_manager_t& messages) { auto curr_u_label = ctx.path_pattern[0]; auto inner_vertices = frag.InnerVertices(curr_u_label); std::queue<std::pair<layer_t, path_t>> paths; for (auto u : inner_vertices) { std::pair<layer_t, path_t> pair(0, {frag.Vertex2Gid(u)}); paths.push(pair); } BFS(frag, ctx, messages, paths); } void IncEval(const fragment_t& frag, context_t& ctx, message_manager_t& messages) { std::queue<std::pair<layer_t, path_t>> paths; { std::pair<layer_t, path_t> msg; while (messages.GetMessage(msg)) { paths.push(msg); } } // A rough implementation to limit path count uint32_t total_path_count; Sum((uint32_t) ctx.path_result.size(), total_path_count); if (total_path_count >= ctx.total_path_limit) { auto& path_result = ctx.path_result; std::vector<size_t> shape{path_result.size(), ctx.path_pattern.size() / 2 + 1}; ctx.set_shape(shape); auto* data = ctx.tensor().data(); size_t idx = 0; for (auto& path : path_result) { for (auto gid : path) { data[idx++] = frag.Gid2Oid(gid); } } return; } BFS(frag, ctx, messages, paths); } }; } // namespace gs #endif // ANALYTICAL_ENGINE_APPS_SAMPLING_PATH_SAMPLING_PATH_H_