void batched_csr2csc()

in fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp [322:564]


template <typename scalar_t>
void batched_csr2csc(
    BatchedHyperCompressedSparseColumn& batched_csc,
    int B,
    // TODO: use accessor for the following 3 parameters
    const at::TensorAccessor<int64_t, 1>& batched_csr_offsets,
    const at::TensorAccessor<int64_t, 1>& batched_csr_indices,
    const at::TensorAccessor<scalar_t, 1>& batched_csr_weights,
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings) {
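  // This implementation processes a single table per call (num_tables == 1);
  // table_ptr holds one offset per table plus a terminating entry.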
  int num_tables = 1;
  batched_csc.num_tables = num_tables;
  batched_csc.table_ptr = static_cast<int*>(
      fbgemm::fbgemmAlignedAlloc(64, (num_tables + 1) * sizeof(int)));
  batched_csc.table_ptr[0] = 0;
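  // Total number of non-zeros spanned by this table's features over all B
  // samples; an empty table short-circuits below.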
  int64_t nnz = batched_csr_offsets[table_to_feature_offset[num_tables] * B] -
      batched_csr_offsets[table_to_feature_offset[0] * B];
  if (nnz == 0) {
    batched_csc.table_ptr[1] = 0;
    return;
  }
  batched_csc.row_indices =
      static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int)));
  bool has_weights = batched_csr_weights.data() != nullptr;
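  // A weights array is needed when explicit per-index weights are present, or
  // when MEAN pooling folds the 1/L scaling into the CSC values.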
  if (has_weights ||
      static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) {
    batched_csc.weights = static_cast<float*>(
        fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(float)));
  }

  int column_ptr_curr = 0;
  int t = 0;
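  // A table is shared when more than one feature maps onto it; sort values
  // then encode (feature, sample) pairs that must be decoded afterwards.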
  bool is_shared_table =
      table_to_feature_offset[t + 1] > table_to_feature_offset[t] + 1;
  auto NS = batched_csr_offsets[table_to_feature_offset[t + 1] * B] -
      batched_csr_offsets[table_to_feature_offset[t] * B];
  int num_non_empty_segments = 0;
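  // Unweighted path: sort (column, row) pairs with a parallel radix sort and
  // compact runs of equal columns into segments.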
  if (!batched_csc.weights) {
    batched_csc.column_segment_ids =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int)));

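    // Ping-pong buffers for the radix sort: keys are embedding (column)
    // indices, values encode (feature, sample) as (feature - first) * B + b.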
    int* tmpBufKeys =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    int* tmpBufValues =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    int* tmpBuf1Keys =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    int* tmpBuf1Values =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    const auto FBo = batched_csr_offsets[table_to_feature_offset[t] * B];
    for (int feature = table_to_feature_offset[t];
         feature < table_to_feature_offset[t + 1];
         ++feature) {
      const auto FBs = (feature - table_to_feature_offset[t]) * B;
#pragma omp parallel for
      for (int b = 0; b < B; ++b) {
        const auto FBb = feature * B + b;
        int64_t pool_begin = batched_csr_offsets[FBb];
        int64_t pool_end = batched_csr_offsets[FBb + 1];
        for (int64_t p = pool_begin; p < pool_end; ++p) {
          tmpBufKeys[p - FBo] = batched_csr_indices[p];
          tmpBufValues[p - FBo] = FBs + b;
        }
      }
    }

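    // The sort returns pointers into whichever ping-pong pair ended up
    // holding the sorted result.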
    int* sorted_col_row_index_keys = nullptr;
    int* sorted_col_row_index_values = nullptr;
    std::tie(sorted_col_row_index_keys, sorted_col_row_index_values) =
        fbgemm_gpu::radix_sort_parallel(
            tmpBufKeys,
            tmpBufValues,
            tmpBuf1Keys,
            tmpBuf1Values,
            NS,
            num_embeddings);

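    // Count key boundaries per thread, then prefix-sum the counts so every
    // thread knows where its unique columns start; the [64] stride keeps the
    // per-thread counters on separate cache lines.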
    int max_thds = omp_get_max_threads();
    int num_uniq[max_thds][64];
    int U = 0;
    if (at::get_num_threads() > 1) {
      // This block is not needed for the single-thread case
#pragma omp parallel
      {
        int tid = omp_get_thread_num();
        num_uniq[tid][0] = 0;
#pragma omp for schedule(static)
        for (int i = 1; i < NS; i++) {
          if (sorted_col_row_index_keys[i] !=
              sorted_col_row_index_keys[i - 1]) {
            num_uniq[tid][0]++;
          }
        }
      }
      num_uniq[0][0] += 1;
      for (int i = 1; i < max_thds; i++)
        num_uniq[i][0] += num_uniq[i - 1][0];
      U = num_uniq[max_thds - 1][0];
    }

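    // Worst case: every non-zero is a distinct column, hence NS segments.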
    batched_csc.column_segment_ptr = static_cast<int*>(
        fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int)));
    batched_csc.column_segment_indices =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));

    batched_csc.column_segment_ptr[0] = 0;
    batched_csc.row_indices[0] = sorted_col_row_index_values[0] % B;
    batched_csc.column_segment_indices[0] = sorted_col_row_index_keys[0];
    batched_csc.column_segment_ids[0] = sorted_col_row_index_values[0] / B;
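    // Fill the remaining segment indices/ids and row indices in parallel;
    // each thread writes into the slice computed from num_uniq above.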
#pragma omp parallel
    {
      int tid = omp_get_thread_num();
      int* tstart =
          (tid == 0
               ? batched_csc.column_segment_indices + 1
               : batched_csc.column_segment_indices + num_uniq[tid - 1][0]);

      int* t_offs =
          (tid == 0 ? batched_csc.column_segment_ptr + 1
                    : batched_csc.column_segment_ptr + num_uniq[tid - 1][0]);

      if (!is_shared_table) {
        // For a non-shared table every value is already a row index in
        // [0, B), so no modulo is needed.
        // As an optimization, swap pointers instead of copying.
#pragma omp master
        std::swap(
            batched_csc.row_indices,
            sorted_col_row_index_values == tmpBufValues ? tmpBufValues
                                                        : tmpBuf1Values);
      } else {
#ifdef FBCODE_CAFFE2
        libdivide::divider<int> divisor(B);
#endif

#pragma omp for schedule(static)
        for (int i = 1; i < NS; ++i) {
          int v = sorted_col_row_index_values[i];
#ifdef FBCODE_CAFFE2
          int q = v / divisor;
#else
          int q = v / B;
#endif
          batched_csc.column_segment_ids[i] = q;
          batched_csc.row_indices[i] = v - q * B;
        }
      }

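      // Emit a new column segment at every boundary between distinct keys.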
#pragma omp for schedule(static)
      for (int i = 1; i < NS; ++i) {
        if (sorted_col_row_index_keys[i] != sorted_col_row_index_keys[i - 1]) {
          *tstart = sorted_col_row_index_keys[i];
          *t_offs = i;
          tstart++;
          t_offs++;
        }
      }

      if (at::get_num_threads() == 1 && tid == 0) {
        // Special handling of single thread case
        U = t_offs - batched_csc.column_segment_ptr;
      }
    } // omp parallel
    batched_csc.table_ptr[t + 1] = batched_csc.table_ptr[t] + U;
    batched_csc.column_segment_ptr[U] = NS;
    column_ptr_curr += NS;
    fbgemm::fbgemmAlignedFree(tmpBufKeys);
    fbgemm::fbgemmAlignedFree(tmpBufValues);
    fbgemm::fbgemmAlignedFree(tmpBuf1Keys);
    fbgemm::fbgemmAlignedFree(tmpBuf1Values);
  } else {
    // Weighted path (batched_csc.weights != nullptr): bucket (row, weight)
    // pairs per column in a hash map, then emit segments in map order.
#ifdef FBCODE_CAFFE2
    folly::F14FastMap<
#else
    std::unordered_map<
#endif
        int64_t,
        std::vector<std::vector<std::pair<int, scalar_t>>>>
        non_empty_columns;
    int f_begin = table_to_feature_offset[t];
    int f_end = table_to_feature_offset[t + 1];
    for (int feature = f_begin; feature < f_end; ++feature) {
      for (int b = 0; b < B; ++b) {
        int64_t pool_begin = batched_csr_offsets[feature * B + b];
        int64_t pool_end = batched_csr_offsets[feature * B + b + 1];
        int64_t L = pool_end - pool_begin;
        // MEAN pooling will not work with indice_weights!
        double scale_factor =
            (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
             !has_weights && L > 0)
            ? 1.0 / L
            : 1.0;
        for (int64_t p = pool_begin; p < pool_end; ++p) {
          auto itr = non_empty_columns.find(batched_csr_indices[p]);
          if (itr == non_empty_columns.end()) {
            itr = non_empty_columns
                      .emplace(
                          batched_csr_indices[p],
                          std::vector<std::vector<std::pair<int, scalar_t>>>(
                              f_end - f_begin))
                      .first;
          }
          if (itr->second[feature - f_begin].empty()) {
            ++num_non_empty_segments;
          }
          itr->second[feature - f_begin].emplace_back(
              b, scale_factor * (has_weights ? batched_csr_weights[p] : 1.0f));
        }
      }
    } // for each feature

    batched_csc.table_ptr[t + 1] =
        batched_csc.table_ptr[t] + num_non_empty_segments;
    batched_csc.column_segment_ptr = static_cast<int*>(
        fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int)));
    batched_csc.column_segment_ptr[0] = 0;
    batched_csc.column_segment_indices =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    batched_csc.column_segment_ids =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
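    // Flatten the per-column buckets into the CSC arrays: k walks the
    // segment index while column_ptr_curr walks the non-zeros.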
    int k = 1;
    for (auto const& column : non_empty_columns) {
      int feature = f_begin;
      for (auto const& column_segment : column.second) {
        if (!column_segment.empty()) {
          batched_csc.column_segment_ptr[k] =
              column_ptr_curr + column_segment.size();
          batched_csc.column_segment_indices[k - 1] = column.first;
          batched_csc.column_segment_ids[k - 1] = feature - f_begin;
          k++;
          for (auto const& non_zero : column_segment) {
            batched_csc.row_indices[column_ptr_curr] = non_zero.first;
            batched_csc.weights[column_ptr_curr] = non_zero.second;
            ++column_ptr_curr;
          }
        }
        ++feature;
      } // for each column segment
    } // for each column
  } // batched_csc.weights != nullptr

  assert(column_ptr_curr == nnz);
}
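
A minimal call-site sketch for orientation. This is an assumption, not code
from this file: the tensor names (offsets, indices, indice_weights), the
PoolingMode::SUM value, and the empty-accessor fallback are illustrative;
only the parameter order and types follow the signature above.

// Hypothetical caller (names are illustrative).
BatchedHyperCompressedSparseColumn batched_csc;
batched_csr2csc<float>(
    batched_csc,
    B,
    offsets.accessor<int64_t, 1>(),
    indices.accessor<int64_t, 1>(),
    // An accessor whose data() is nullptr signals "no per-index weights",
    // matching the has_weights check inside batched_csr2csc.
    indice_weights.defined()
        ? indice_weights.accessor<float, 1>()
        : at::TensorAccessor<float, 1>(nullptr, nullptr, nullptr),
    static_cast<int64_t>(PoolingMode::SUM), // assumed SUM pooling
    table_to_feature_offset, // num_tables + 1 feature offsets
    num_embeddings);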