void split_embedding_forward_cpu_kernel()

in fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp [25:164]


template <typename weights_t, typename ind_weights_t, typename output_t>
void split_embedding_forward_cpu_kernel(
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t total_D,
    Tensor hash_size_cumsum,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    Tensor output) {
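  // D_offsets stores T + 1 prefix sums of the per-table embedding dims.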
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK(T > 0);
  // offsets = [T x B + 1], laid out table-major.
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B >= 0);

  TORCH_CHECK(weights.is_contiguous());
  indices = indices.contiguous();
  offsets = offsets.contiguous();
  if (indice_weights.defined()) {
    indice_weights = indice_weights.contiguous();
  }

  const auto D_offsets_data = D_offsets.accessor<int, 1>();
  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
  const auto indices_data = indices.data_ptr<int64_t>();
  const auto offsets_data = offsets.data_ptr<int64_t>();
  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();

  const auto weights_data = weights.data_ptr<weights_t>();
  // If indice_weights is not defined, this pointer is never dereferenced;
  // the nullptr branch just keeps the compiler happy.
  const auto indice_weights_data = indice_weights.defined()
      ? indice_weights.data_ptr<ind_weights_t>()
      : nullptr;

  auto output_data = output.data_ptr<output_t>();
  auto output_stride = output.size(1);

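  // The JIT'd FBGEMM kernel handles fp32/fp16/uint8 weights with fp32 output
  // and fp32 per-sample weights; all other type combinations fall back to the
  // scalar loop below.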
  constexpr bool use_fbgemm = (std::is_same<weights_t, float>::value ||
                               std::is_same<weights_t, at::Half>::value ||
                               std::is_same<weights_t, uint8_t>::value) &&
      std::is_same<output_t, float>::value &&
      std::is_same<ind_weights_t, float>::value;

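  // Parallelize over the batch dimension; each worker processes batch rows
  // [b_begin, b_end) for every table t.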
  at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
    for (int t = 0; t < T; ++t) {
      const auto D_begin = D_offsets_data[t];
      const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
      const auto table_begin = weights_offsets_data[t];

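      // hash_size_cumsum may contain repeated values (zero-sized entries);
      // advance until the cumsum increases so hash_size is nonzero.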
      int64_t hash_size;
      int t_temp = t + 1;
      do {
        hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
        ++t_temp;
      } while (hash_size == 0);

      bool success = true;
      if (use_fbgemm) {
        using fbgemm_weight_t = typename std::conditional<
            std::is_same<weights_t, at::Half>::value,
            fbgemm::float16,
            weights_t>::type;
        auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
            fbgemm_weight_t,
            /*IndexType=*/int64_t,
            /*OffsetType=*/int64_t>(
            D,
            indice_weights.defined(),
            static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN,
            /*prefetch=*/16,
            /*is_weight_positional=*/false,
            /*use_offsets=*/true,
            output_stride);
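        // Offsets are table-major: entry (t, b) lives at index t * B + b.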
        auto offsets_begin_ptr = offsets_data + t * B + b_begin;
        auto indices_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
        success = kernel(
            b_end - b_begin,
            indices_size,
            hash_size,
            reinterpret_cast<const fbgemm_weight_t*>(
                weights_data + table_begin),
            indices_data + *offsets_begin_ptr,
            offsets_begin_ptr,
            indice_weights.defined()
                ? reinterpret_cast<const float*>(
                      indice_weights_data + *offsets_begin_ptr)
                : nullptr,
            reinterpret_cast<float*>(
                output_data + b_begin * output_stride + D_begin));
      } else {
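        // Scalar fallback: accumulate each pooled bag into a higher-precision
        // stack buffer (a variable-length array, a GCC/Clang extension).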
        at::acc_type<output_t, true> output_buf[D];
        for (int b = b_begin; b < b_end; ++b) {
          const auto pool_begin = offsets_data[t * B + b];
          const auto pool_end = offsets_data[t * B + b + 1];
          const auto L = pool_end - pool_begin;
          memset(output_buf, 0, D * sizeof(at::acc_type<output_t, true>));
          for (auto p = pool_begin; p < pool_end; ++p) {
            int64_t idx = indices_data[p];
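            // An out-of-range index aborts this bag; the failure is reported
            // below via report_embedding_error.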
            if (idx < 0 || idx >= hash_size) {
              success = false;
              break;
            }
            const int64_t embedding_begin = table_begin + idx * D;
            for (int64_t d = 0; d < D; ++d) {
              output_buf[d] +=
                  (indice_weights.defined()
                       ? static_cast<at::acc_type<output_t, true>>(
                             weights_data[embedding_begin + d]) *
                           static_cast<at::acc_type<output_t, true>>(
                               indice_weights_data[p])
                       : static_cast<at::acc_type<output_t, true>>(
                             weights_data[embedding_begin + d]));
            }
          }
          const double scale_factor =
              // NOTE: MEAN pooling will not work with indice_weights!
              (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
               !indice_weights.defined() && L > 0)
              ? 1.0 / L
              : 1.0;
          for (int d = 0; d < D; ++d) {
            output_data[b * output_stride + D_begin + d] =
                scale_factor * output_buf[d];
          }
          if (!success) {
            break;
          }
        } // for each b
      } // !use_fbgemm

      if (!success) {
        fbgemm_gpu::report_embedding_error(
            t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
      } // !success
    } // for each t
  }); // parallel for
}
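
The kernel is templated over weights_t, ind_weights_t, and output_t. A minimal
sketch of how a caller could instantiate it with ATen's dtype dispatch follows;
the wrapper name, the fp32 output dtype, and the fp32 ind_weights_t are
assumptions for illustration, not the file's actual wrapper.

// Hypothetical caller, for illustration only. AT_DISPATCH_FLOATING_TYPES_AND_HALF
// is a standard ATen macro; everything else here is an assumed sketch.
Tensor split_embedding_forward_cpu_sketch(
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t total_D,
    Tensor hash_size_cumsum,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights) {
  const int64_t T = D_offsets.numel() - 1;
  const int64_t B = (offsets.size(0) - 1) / T;
  // Allocate a [B, total_D] fp32 output; the kernel uses output.size(1) as its
  // row stride.
  Tensor output = at::zeros({B, total_D}, weights.options().dtype(at::kFloat));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      weights.scalar_type(), "split_embedding_forward_cpu", [&] {
        split_embedding_forward_cpu_kernel<scalar_t, float, float>(
            weights,
            weights_offsets,
            D_offsets,
            total_D,
            hash_size_cumsum,
            indices,
            offsets,
            pooling_mode,
            indice_weights,
            output);
      });
  return output;
}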