in graphlearn_torch/csrc/cpu/graph.cc [21:46]
// Initializes this CPU graph from a CSR (compressed sparse row) adjacency.
//
// Only raw data pointers into the given tensors are stored here (row_ptr_,
// col_idx_, edge_id_, edge_weight_), so the tensors' storage must stay alive
// and unchanged for as long as this Graph reads it.
// NOTE(review): confirm whether the owning tensors are retained elsewhere.
//
// @param indptr        1-D int64 row-pointer array; its length is
//                      row_count + 1, so it must contain at least one element.
// @param indices       1-D int64 column indices, one entry per edge.
// @param edge_ids      Optional 1-D int64 edge ids, one per edge. Pass an
//                      empty tensor to skip.
// @param edge_weights  Optional 1-D float32 edge weights, one per edge. Pass
//                      an empty tensor to skip.
//
// Fails (via CheckEq / TORCH_CHECK, or data_ptr<T>()'s dtype check) on:
//   * wrong rank;
//   * an empty indptr;
//   * non-contiguous inputs;
//   * mismatched edge-array lengths;
//   * wrong dtypes.
void Graph::InitCPUGraphFromCSR(
    const torch::Tensor& indptr,
    const torch::Tensor& indices,
    const torch::Tensor& edge_ids,
    const torch::Tensor& edge_weights) {
  CheckEq<int64_t>(indptr.dim(), 1);
  CheckEq<int64_t>(indices.dim(), 1);
  // indptr holds row_count + 1 offsets. An empty indptr would otherwise
  // silently produce row_count_ == -1.
  TORCH_CHECK(indptr.size(0) >= 1,
              "InitCPUGraphFromCSR: indptr must contain at least one element");
  // The stored raw pointers are indexed densely. A strided view would be
  // read incorrectly, so reject it. Calling .contiguous() instead is unsafe:
  // it may return a temporary whose storage is freed when this function
  // returns, leaving a dangling pointer behind.
  TORCH_CHECK(indptr.is_contiguous(),
              "InitCPUGraphFromCSR: indptr must be contiguous");
  TORCH_CHECK(indices.is_contiguous(),
              "InitCPUGraphFromCSR: indices must be contiguous");

  row_ptr_ = indptr.data_ptr<int64_t>();
  col_idx_ = indices.data_ptr<int64_t>();
  row_count_ = indptr.size(0) - 1;
  edge_count_ = indices.size(0);
  // Number of *distinct* column ids referenced by the edges (not max id + 1).
  col_count_ = std::get<0>(at::_unique(indices)).size(0);

  // Optional per-edge arrays: an empty tensor means "not provided".
  if (edge_ids.numel()) {
    CheckEq<int64_t>(edge_ids.dim(), 1);
    CheckEq<int64_t>(edge_ids.numel(), indices.numel());
    TORCH_CHECK(edge_ids.is_contiguous(),
                "InitCPUGraphFromCSR: edge_ids must be contiguous");
    edge_id_ = edge_ids.data_ptr<int64_t>();
  }
  if (edge_weights.numel()) {
    CheckEq<int64_t>(edge_weights.dim(), 1);
    CheckEq<int64_t>(edge_weights.numel(), indices.numel());
    TORCH_CHECK(edge_weights.is_contiguous(),
                "InitCPUGraphFromCSR: edge_weights must be contiguous");
    edge_weight_ = edge_weights.data_ptr<float>();
  }
}