torch::Tensor LoadEdgeFeatures()

in graphlearn_torch/v6d/vineyard_utils.cc [218:260]


/// Loads the property columns of one edge label from a graph fragment and
/// converts them into a tensor via `ArrowArray2Tensor`.
///
/// @param ipc_socket     Forwarded to `GetGraphFromVineyard` (presumably the
///                       vineyard IPC socket path -- confirm against caller).
/// @param object_id_str  Forwarded to `GetGraphFromVineyard` (presumably the
///                       fragment's object id as a string -- confirm).
/// @param e_label_name   Edge label whose data table is read.
/// @param ecols          Column names to load. In/out: when empty on entry it
///                       is overwritten with every column name of the edge
///                       data table, so the caller's vector is modified.
/// @return `ArrowArray2Tensor(values, ecols.size())`, where `values` is the
///         first chunk of the single requested column, or the flattened values
///         of the consolidated "emerged" column when two or more columns are
///         requested.
/// @throws std::runtime_error if the edge label is unknown, if there is no
///         column to load, if a requested column is missing or has no chunk,
///         or if the columns cannot be consolidated.
///
/// NOTE(review): only chunk 0 of the selected column is read; if a column can
/// hold more than one chunk, the remaining rows are not loaded -- confirm that
/// edge tables are always single-chunk here.
torch::Tensor LoadEdgeFeatures(
  const std::string& ipc_socket, const std::string& object_id_str,
  const std::string& e_label_name, std::vector<std::string>& ecols) {

  auto frag = GetGraphFromVineyard(ipc_socket, object_id_str);
  auto e_label_id = frag->schema().GetEdgeLabelId(e_label_name);
  if (e_label_id < 0) {
    throw std::runtime_error("e_label_name not exist");
  }

  // By default merge all cols when `ecols` is empty.
  if (ecols.empty()) {
    ecols = frag->edge_data_table(e_label_id)->ColumnNames();
  }
  // Still empty: the edge table has no columns. Fail loudly instead of
  // handing a null array to ArrowArray2Tensor.
  if (ecols.empty()) {
    throw std::runtime_error("ERROR: No edge feature columns to load!");
  }

  std::shared_ptr<arrow::Array> fscol;

  if (ecols.size() >= 2) {
    // Consolidate the given columns into one fixed-size-list column named
    // "emerged", then take its flattened child values.
    try {
      auto efrag_id =
        frag->ConsolidateEdgeColumns(vyclient, e_label_id, ecols, "emerged").value();
      auto efrag =
        std::dynamic_pointer_cast<GraphType>(vyclient.GetObject(efrag_id));
      if (efrag) {
        auto merged =
          efrag->edge_data_table(e_label_id)->GetColumnByName("emerged");
        // GetColumnByName yields null for an unknown column, and chunk(0) is
        // only valid when at least one chunk exists.
        if (merged && merged->num_chunks() > 0) {
          auto merged_list = std::dynamic_pointer_cast<arrow::FixedSizeListArray>(
            merged->chunk(0));
          if (merged_list) {
            fscol = merged_list->values();
          }
        }
      }
    } catch (...) {
      // Any failure is reported uniformly by the null check below.
      fscol.reset();
    }
    if (!fscol) {
      LOG(ERROR) << "Possibly different column types OR wrong column names.\n";
      throw std::runtime_error(
        "ERROR: Unable to merge!");
    }
  } else {
    // Exactly one column: use it directly, no consolidation needed.
    try {
      auto column = frag->edge_data_table(e_label_id)->GetColumnByName(ecols[0]);
      // A missing column comes back as null rather than as an exception, so
      // it must be checked explicitly; a column without chunks is reported
      // with the same error.
      if (column && column->num_chunks() > 0) {
        fscol = column->chunk(0);
      }
    } catch (...) {
      fscol.reset();
    }
    if (!fscol) {
      throw std::runtime_error("ERROR: Column name not exists!");
    }
  }

  return ArrowArray2Tensor(fscol, ecols.size());
}