torch::Tensor LoadVertexFeatures()

in graphlearn_torch/v6d/vineyard_utils.cc [174:215]


/// Loads the feature columns of one vertex label from a vineyard graph
/// fragment and converts them into a torch tensor.
///
/// @param ipc_socket     Vineyard IPC socket, forwarded to GetGraphFromVineyard.
/// @param object_id_str  Vineyard object id (string form) of the fragment,
///                       forwarded to GetGraphFromVineyard.
/// @param v_label_name   Vertex label whose vertex data table is read.
/// @param vcols          Column names to load. If empty, it is filled IN PLACE
///                       with every column name of the label's vertex table,
///                       so the caller observes the chosen columns.
/// @return The tensor produced by ArrowArray2Tensor from the selected column
///         data and the number of selected columns.
/// @throws std::runtime_error if the label does not exist, the label has no
///         columns, a single requested column is missing or empty, or several
///         columns cannot be consolidated into one.
///
/// NOTE(review): only the first chunk of each Arrow column is read. This
/// assumes the vertex tables are single-chunk — TODO confirm, otherwise rows
/// in later chunks are silently dropped.
torch::Tensor LoadVertexFeatures(
  const std::string& ipc_socket, const std::string& object_id_str,
  const std::string& v_label_name, std::vector<std::string>& vcols) {

  auto frag = GetGraphFromVineyard(ipc_socket, object_id_str);
  auto v_label_id = frag->schema().GetVertexLabelId(v_label_name);

  if (v_label_id < 0) {
    throw std::runtime_error("v_label_name not exist");
  }

  // By default merge all cols when `vcols` is empty.
  if (vcols.empty()) {
    vcols = frag->vertex_data_table(v_label_id)->ColumnNames();
  }
  // Without this check a column-less table would pass a null array to
  // ArrowArray2Tensor below.
  if (vcols.empty()) {
    throw std::runtime_error("ERROR: Vertex label has no columns!");
  }

  std::shared_ptr<arrow::Array> fscol;
  if (vcols.size() >= 2) {
    // Consolidate the given columns into one fixed-size-list column named
    // "vmerged" on a new fragment, then take the list's child values.
    // `vyclient` is a client defined outside this function (presumably the
    // one connected by GetGraphFromVineyard — TODO confirm).
    try {
      auto vfrag_id =
        frag->ConsolidateVertexColumns(vyclient, v_label_id, vcols, "vmerged").value();
      auto vfrag =
        std::dynamic_pointer_cast<GraphType>(vyclient.GetObject(vfrag_id));
      // dynamic_pointer_cast and GetColumnByName report failure with nullptr,
      // not an exception, so each step is checked explicitly before use.
      if (!vfrag) {
        throw std::runtime_error("consolidated object is not the expected graph type");
      }
      auto merged_col =
        vfrag->vertex_data_table(v_label_id)->GetColumnByName("vmerged");
      if (!merged_col || merged_col->num_chunks() == 0) {
        throw std::runtime_error("consolidated column \"vmerged\" is missing or empty");
      }
      auto merged =
        std::dynamic_pointer_cast<arrow::FixedSizeListArray>(merged_col->chunk(0));
      if (!merged) {
        throw std::runtime_error("consolidated column is not a FixedSizeListArray");
      }
      fscol = merged->values();
    } catch (const std::exception& e) {
      LOG(ERROR) << "Possibly different column types OR wrong column names: "
                 << e.what() << "\n";
      throw std::runtime_error("ERROR: Unable to merge!");
    } catch (...) {
      LOG(ERROR) << "Possibly different column types OR wrong column names.\n";
      throw std::runtime_error("ERROR: Unable to merge!");
    }
  } else {
    // Exactly one column: use its data directly. GetColumnByName returns
    // nullptr (it does not throw) when the name is unknown.
    auto col = frag->vertex_data_table(v_label_id)->GetColumnByName(vcols[0]);
    if (!col) {
      throw std::runtime_error("ERROR: Column name not exists!");
    }
    if (col->num_chunks() == 0) {
      throw std::runtime_error("ERROR: Column has no data!");
    }
    fscol = col->chunk(0);
  }

  return ArrowArray2Tensor(fscol, vcols.size());
}