HeteroCOO CPUHeteroInducer::InduceNext()

in graphlearn_torch/csrc/cpu/inducer.cc [138:181]


// Induces the COO edges (in local ids) produced by one hop of heterogeneous
// neighbor sampling, plus the node ids that were newly collected per type.
//
// `nbrs` maps an edge type to a tuple of three int64 tensors:
//   <0> source node ids (global),
//   <1> neighbor ids (global), flattened and grouped by source in order,
//   <2> per-source neighbor counts (entry i = number of neighbors of source i).
// std::get<0>/<2> of the edge-type key are used as the source/destination
// node types (presumably a (src, relation, dst) triple — confirm with caller).
//
// Returns (nodes_dict, rows_dict, cols_dict):
//   nodes_dict: node type -> ids gathered into `out_nodes` by InsertGlob2Local
//               (presumably the globally-new nodes of this hop; verify).
//   rows_dict / cols_dict: edge type -> local source / destination index of
//               every sampled edge, looked up through glob2local_.
//
// Throws (std::out_of_range) if an id with edges has no local mapping, and
// (c10::Error) if the neighbor counts exceed the number of neighbor ids.
HeteroCOO CPUHeteroInducer::InduceNext(const HeteroNbr& nbrs) {
  std::unordered_map<std::string, std::vector<int64_t>> out_nodes;
  InsertGlob2Local(nbrs, out_nodes);

  // Output tensors inherit dtype/device from the first source tensor. An empty
  // `nbrs` has no begin() element to read (dereferencing it is UB), so fall
  // back to plain int64 options in that case.
  const auto tensor_option =
      nbrs.empty() ? torch::TensorOptions().dtype(torch::kInt64)
                   : std::get<0>(nbrs.begin()->second).options();

  TensorEdgeMap rows_dict;
  TensorEdgeMap cols_dict;
  TensorMap nodes_dict;

  for (const auto& iter : out_nodes) {
    const auto& node_type = iter.first;
    const auto& ids = iter.second;
    torch::Tensor nodes =
        torch::empty({static_cast<int64_t>(ids.size())}, tensor_option);
    std::copy(ids.begin(), ids.end(), nodes.data_ptr<int64_t>());
    nodes_dict.emplace(node_type, std::move(nodes));
  }

  for (const auto& iter : nbrs) {
    const auto& edge_type = iter.first;
    const auto& src = std::get<0>(iter.second);
    const auto& nbr = std::get<1>(iter.second);
    const auto& nbr_num = std::get<2>(iter.second);

    const int64_t* src_ptr = src.data_ptr<int64_t>();
    const int64_t* nbrs_ptr = nbr.data_ptr<int64_t>();
    const int64_t* nbrs_num_ptr = nbr_num.data_ptr<int64_t>();
    const int64_t src_size = src.size(0);
    const int64_t edge_size = nbr.size(0);
    // NOTE: operator[] is kept from the original; for an unseen node type it
    // default-inserts an (empty) mapping rather than throwing here.
    const auto& src_glob2local = glob2local_[std::get<0>(edge_type)];
    const auto& dst_glob2local = glob2local_[std::get<2>(edge_type)];

    torch::Tensor rows = torch::empty({edge_size}, tensor_option);
    torch::Tensor cols = torch::empty({edge_size}, tensor_option);
    int64_t* rows_ptr = rows.data_ptr<int64_t>();
    int64_t* cols_ptr = cols.data_ptr<int64_t>();

    // 64-bit offsets: sampled edge counts can exceed INT32_MAX on large graphs.
    int64_t cnt = 0;
    for (int64_t i = 0; i < src_size; ++i) {
      const int64_t num = nbrs_num_ptr[i];
      if (num <= 0) {
        // No edges for this source: skip without looking it up, so an
        // unmapped zero-degree source does not throw.
        continue;
      }
      // Refuse inconsistent input instead of writing past rows/cols.
      TORCH_CHECK(num <= edge_size - cnt,
                  "InduceNext: neighbor counts exceed the number of neighbor "
                  "ids for an edge type");
      const int64_t row = src_glob2local.at(src_ptr[i]);
      for (int64_t j = 0; j < num; ++j, ++cnt) {
        rows_ptr[cnt] = row;
        cols_ptr[cnt] = dst_glob2local.at(nbrs_ptr[cnt]);
      }
    }
    rows_dict.emplace(edge_type, std::move(rows));
    cols_dict.emplace(edge_type, std::move(cols));
  }
  return std::make_tuple(std::move(nodes_dict), std::move(rows_dict),
                         std::move(cols_dict));
}