std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> ToCSR()

in graphlearn_torch/v6d/vineyard_utils.cc [63:143]


std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> ToCSR(
  const std::string& ipc_socket, const std::string& object_id_str,
  const std::string& v_label_name, const std::string& e_label_name,
  const std::string& edge_dir, bool has_eid) {

  if (edge_dir != "in" && edge_dir != "out") {
    throw std::runtime_error("Invalid edge_dir value. edge_dir must be 'in' or 'out'.");
  }

  auto vineyard_graph = GetGraphFromVineyard(ipc_socket, object_id_str);

  auto v_label_id =  vineyard_graph->schema().GetVertexLabelId(v_label_name);
  if (v_label_id < 0) {
    throw std::runtime_error("v_label_name not exist");
  }
  auto e_label_id =  vineyard_graph->schema().GetEdgeLabelId(e_label_name);
  if (e_label_id < 0) {
    throw std::runtime_error("e_label_name not exist");
  }

  int64_t* offsets;
  int64_t offset_len;

  if (edge_dir == "out") {
    offsets = const_cast<int64_t*>(
      vineyard_graph->GetOutgoingOffsetArray(v_label_id, e_label_id));
    offset_len = vineyard_graph->GetOutgoingOffsetLength(v_label_id, e_label_id);
  } else {
    offsets = const_cast<int64_t*>(
      vineyard_graph->GetIncomingOffsetArray(v_label_id, e_label_id));
    offset_len = vineyard_graph->GetIncomingOffsetLength(v_label_id, e_label_id);
  }

  auto iv = vineyard_graph->InnerVertices(v_label_id);
  int64_t indice_len = 0;

  for (auto v: iv) {
    if (edge_dir == "out") {
      auto oe = vineyard_graph->GetOutgoingRawAdjList(v, e_label_id);
      indice_len += oe.Size();
    } else {
      auto oe = vineyard_graph->GetIncomingRawAdjList(v, e_label_id);
      indice_len += oe.Size();
    }
  }

  int64_t* cols = new int64_t[indice_len];
  int64_t* eids = new int64_t[indice_len];

  int64_t i = 0;

  for (auto v : iv) {
    if (edge_dir == "out") {
      auto oe = vineyard_graph->GetOutgoingAdjList(v, e_label_id);
      for (auto& e : oe) {
        cols[i] = vineyard_graph->Vertex2Gid(e.get_neighbor());
        if (has_eid) {
          eids[i++] = e.edge_id();
        } else {
          ++i;
        }
      }
    } else {
      auto oe = vineyard_graph->GetIncomingAdjList(v, e_label_id);
      for (auto& e : oe) {
        cols[i] = vineyard_graph->Vertex2Gid(e.get_neighbor());
        if (has_eid) {
          eids[i++] = e.edge_id();
        } else {
          ++i;
        }
      }
    }
  }

  auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU);
  torch::Tensor indptr = torch::from_blob(offsets, offset_len, options);
  torch::Tensor indices = torch::from_blob(cols, indice_len, customDeleter<int64_t>, options);
  torch::Tensor edge_ids = torch::from_blob(eids, indice_len, customDeleter<int64_t>, options);
  return {indptr, indices, edge_ids};
}