void Graph::InitCPUGraphFromCSR()

in graphlearn_torch/csrc/cpu/graph.cc [21:46]


/// Initializes the CPU graph topology from CSR-formatted tensors.
///
/// No data is copied: only raw pointers obtained via `data_ptr()` are stored
/// in the graph's members.
/// NOTE(review): unless the class retains the tensors somewhere else, callers
/// must keep all four tensors alive for as long as this Graph is used —
/// confirm against the class definition.
///
/// @param indptr       1-D int64 row-pointer tensor; the number of rows is
///                     taken as `indptr.size(0) - 1`.
/// @param indices      1-D int64 column-index tensor, one entry per edge.
/// @param edge_ids     Optional 1-D int64 edge ids with exactly one entry per
///                     edge; pass an empty tensor to leave edge ids unset.
/// @param edge_weights Optional 1-D float32 edge weights with exactly one
///                     entry per edge; pass an empty tensor to leave weights
///                     unset.
///
/// Every tensor that is used must be contiguous. `data_ptr()` yields the
/// address of a view's first element but no stride information, so a strided
/// view (e.g. a slice with step > 1) would otherwise be read as if it were
/// dense, silently producing wrong topology. All shape and layout
/// requirements are validated with `CheckEq`.
void Graph::InitCPUGraphFromCSR(
    const torch::Tensor& indptr,
    const torch::Tensor& indices,
    const torch::Tensor& edge_ids,
    const torch::Tensor& edge_weights) {
  CheckEq<int64_t>(indptr.dim(), 1);
  CheckEq<int64_t>(indices.dim(), 1);
  // Raw pointers are stored below, so a non-dense layout must be rejected
  // before any pointer is taken.
  CheckEq<bool>(indptr.is_contiguous(), true);
  CheckEq<bool>(indices.is_contiguous(), true);

  row_ptr_ = indptr.data_ptr<int64_t>();
  col_idx_ = indices.data_ptr<int64_t>();
  // CSR keeps one more row pointer than there are rows.
  row_count_ = indptr.size(0) - 1;
  edge_count_ = indices.size(0);
  // Number of distinct column ids referenced by at least one edge.
  col_count_ = std::get<0>(at::_unique(indices)).size(0);

  // An empty tensor means "no edge ids supplied".
  if (edge_ids.numel()) {
    CheckEq<int64_t>(edge_ids.dim(), 1);
    CheckEq<int64_t>(edge_ids.numel(), indices.numel());
    CheckEq<bool>(edge_ids.is_contiguous(), true);
    edge_id_ = edge_ids.data_ptr<int64_t>();
  }

  // An empty tensor means "no edge weights supplied".
  if (edge_weights.numel()) {
    CheckEq<int64_t>(edge_weights.dim(), 1);
    CheckEq<int64_t>(edge_weights.numel(), indices.numel());
    CheckEq<bool>(edge_weights.is_contiguous(), true);
    edge_weight_ = edge_weights.data_ptr<float>();
  }
}