// fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "codegen/embedding_forward_split_cpu.h"

#include "fbgemm/FbgemmEmbedding.h"
#include "fbgemm/Types.h"
#include "fbgemm/Utils.h"
#include "fbgemm_gpu/cpu_utils.h"
#include "fbgemm_gpu/embedding_common.h"

#ifdef FBCODE_CAFFE2
#include <libdivide.h>
#include "folly/container/F14Map.h"
#else
#include <omp.h>
#endif

#include <ATen/AccumulateType.h>

#include <cassert> // assert() in batched_csr2csc
#include <cstring> // memset() in the reference forward path

using Tensor = at::Tensor;

template <typename weights_t, typename ind_weights_t, typename output_t>
void split_embedding_forward_cpu_kernel(
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t total_D,
    Tensor hash_size_cumsum,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    Tensor output) {
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK(T > 0);
  // offsets = [T x B + 1]
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B >= 0);

  TORCH_CHECK(weights.is_contiguous());
  indices = indices.contiguous();
  offsets = offsets.contiguous();
  if (indice_weights.defined()) {
    indice_weights = indice_weights.contiguous();
  }

  const auto D_offsets_data = D_offsets.accessor<int, 1>();
  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
  const auto indices_data = indices.data_ptr<int64_t>();
  const auto offsets_data = offsets.data_ptr<int64_t>();
  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();

  const auto weights_data = weights.data_ptr<weights_t>();
  // If indice_weights is not defined, this accessor won't be used.
  // The else branch is just to make the compiler happy.
  const auto indice_weights_data = indice_weights.defined()
      ? indice_weights.data_ptr<ind_weights_t>()
      : nullptr;

  auto output_data = output.data_ptr<output_t>();
  auto output_stride = output.size(1);

  constexpr bool use_fbgemm =
      (std::is_same<weights_t, float>::value ||
       std::is_same<weights_t, at::Half>::value ||
       std::is_same<weights_t, uint8_t>::value) &&
      std::is_same<output_t, float>::value &&
      std::is_same<ind_weights_t, float>::value;

  at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
    for (int t = 0; t < T; ++t) {
      const auto D_begin = D_offsets_data[t];
      const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
      const auto table_begin = weights_offsets_data[t];

      int64_t hash_size;
      int t_temp = t + 1;
      do {
        hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
        ++t_temp;
      } while (hash_size == 0);

      bool success = true;
      if (use_fbgemm) {
        using fbgemm_weight_t = typename std::conditional<
            std::is_same<weights_t, at::Half>::value,
            fbgemm::float16,
            weights_t>::type;
        auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
            fbgemm_weight_t,
            /*IndexType=*/int64_t,
            /*OffsetType=*/int64_t>(
            D,
            indice_weights.defined(),
            static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN,
            /*prefetch=*/16,
            /*is_weight_positional=*/false,
            /*use_offsets=*/true,
            output_stride);
        auto offsets_begin_ptr = offsets_data + t * B + b_begin;
        auto indices_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
        success = kernel(
            b_end - b_begin,
            indices_size,
            hash_size,
            reinterpret_cast<const fbgemm_weight_t*>(
                weights_data + table_begin),
            indices_data + *offsets_begin_ptr,
            offsets_begin_ptr,
            indice_weights.defined()
                ? reinterpret_cast<const float*>(
                      indice_weights_data + *offsets_begin_ptr)
                : nullptr,
            reinterpret_cast<float*>(
                output_data + b_begin * output_stride + D_begin));
      } else {
        at::acc_type<output_t, true> output_buf[D];
        for (int b = b_begin; b < b_end; ++b) {
          const auto pool_begin = offsets_data[t * B + b];
          const auto pool_end = offsets_data[t * B + b + 1];
          const auto L = pool_end - pool_begin;
          memset(output_buf, 0, D * sizeof(at::acc_type<output_t, true>));
          for (auto p = pool_begin; p < pool_end; ++p) {
            int64_t idx = indices_data[p];
            if (idx < 0 || idx >= hash_size) {
              success = false;
              break;
            }
            const int64_t embedding_begin = table_begin + idx * D;
            for (int64_t d = 0; d < D; ++d) {
              output_buf[d] +=
                  (indice_weights.defined()
                       ? static_cast<at::acc_type<output_t, true>>(
                             weights_data[embedding_begin + d]) *
                           static_cast<at::acc_type<output_t, true>>(
                               indice_weights_data[p])
                       : static_cast<at::acc_type<output_t, true>>(
                             weights_data[embedding_begin + d]));
            }
          }
          const double scale_factor =
              // NOTE: MEAN pooling will not work with indice_weights!
              (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
               !indice_weights.defined() && L > 0)
              ? 1.0 / L
              : 1.0;
          for (int d = 0; d < D; ++d) {
            output_data[b * output_stride + D_begin + d] =
                scale_factor * output_buf[d];
          }
          if (!success) {
            break;
          }
        } // for each b
      } // !use_fbgemm

      if (!success) {
        fbgemm_gpu::report_embedding_error(
            t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
      } // !success
    } // for each t
  }); // parallel for
}
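// Illustration only: the reference (!use_fbgemm) path above amounts to the
// per-bag pooling shown in this minimal standalone sketch. The function is
// hypothetical (not part of this file's API) and assumes float weights,
// SUM pooling, and a single table; the real kernel additionally handles
// per-sample weights, MEAN rescaling, and out-of-range index reporting.
[[maybe_unused]] static void pool_one_bag_reference_sketch(
    const float* table, // [hash_size x D] embedding rows, row-major
    const int64_t* bag_indices, // rows to gather for this bag
    int64_t L, // bag length (pool_end - pool_begin)
    int64_t D, // embedding dimension
    float* out) { // [D] pooled output
  for (int64_t d = 0; d < D; ++d) {
    out[d] = 0.0f;
  }
  for (int64_t p = 0; p < L; ++p) {
    const float* row = table + bag_indices[p] * D;
    for (int64_t d = 0; d < D; ++d) {
      out[d] += row[d]; // SUM pooling; MEAN would divide by L afterwards
    }
  }
}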
Tensor split_embedding_codegen_forward_cpu(
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t total_D,
    Tensor hash_size_cumsum,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    int64_t output_dtype) {
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK(T > 0);
  // offsets = [T x B + 1]
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B >= 0);

  Tensor output;
  if (output_dtype == static_cast<int64_t>(SparseType::FP32)) {
    output = at::empty({B, total_D}, weights.options().dtype(at::kFloat));
  } else if (output_dtype == static_cast<int64_t>(SparseType::FP16)) {
    output = at::empty({B, total_D}, weights.options().dtype(at::kHalf));
  } else if (output_dtype == static_cast<int64_t>(SparseType::BF16)) {
    output = at::empty({B, total_D}, weights.options().dtype(at::kBFloat16));
  } else {
    output = at::empty({B, total_D}, weights.options());
  }

  // It is assumed that indice_weights will always be float.
  TORCH_CHECK(
      !indice_weights.defined() || indice_weights.scalar_type() != at::kHalf);

  AT_DISPATCH_FLOATING_TYPES(
      output.scalar_type(), "split_embedding_cpu_forward", [&] {
        using output_t = scalar_t;
        AT_DISPATCH_FLOATING_TYPES_AND2(
            at::ScalarType::Half,
            at::ScalarType::Byte,
            weights.scalar_type(),
            "split_embedding_cpu_forward",
            [&] {
              using ind_weights_t = std::conditional<
                  std::is_same<scalar_t, double>::value,
                  double,
                  float>::type;
              split_embedding_forward_cpu_kernel<
                  scalar_t,
                  ind_weights_t,
                  output_t>(
                  weights,
                  weights_offsets,
                  D_offsets,
                  total_D,
                  hash_size_cumsum,
                  indices,
                  offsets,
                  pooling_mode,
                  indice_weights,
                  output);
            });
      });

  return output;
}
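// Layout note (worked example): indices/offsets use the flattened
// [T x B + 1] convention checked above. With T = 2 tables and batch size
// B = 2, offsets = {0, 2, 3, 3, 5} means bag (t, b) spans
// indices[offsets[t * B + b] : offsets[t * B + b + 1]]: table 0 sees bags
// indices[0:2] and indices[2:3]; table 1 sees an empty bag (indices[3:3])
// and indices[3:5]. The example values are illustrative only.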
template <typename weights_t, typename grad_t>
void split_embedding_grad_indice_weights_cpu_kernel(
    Tensor grad_output,
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    Tensor indices,
    Tensor offsets,
    Tensor feature_requires_grad,
    Tensor grad_indice_weights) {
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK(T > 0);
  // offsets = [T x B + 1]
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B >= 0);

  const auto D_offsets_data = D_offsets.accessor<int, 1>();
  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
  const auto offsets_data = offsets.accessor<int64_t, 1>();
  const auto indices_data = indices.accessor<int64_t, 1>();
  const auto weights_data = weights.accessor<weights_t, 1>();
  const auto grad_output_data = grad_output.accessor<grad_t, 2>();
  auto grad_indice_weights_data =
      grad_indice_weights.accessor<at::acc_type<grad_t, true>, 1>();

  at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
    for (int64_t t = 0; t < T; ++t) {
      if (feature_requires_grad.defined() &&
          !feature_requires_grad[t].is_nonzero()) {
        // NOTE: skip if the table does not require gradient computation!
        continue;
      }
      const auto D_begin = D_offsets_data[t];
      const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
      const auto table_begin = weights_offsets_data[t];
      for (int64_t b = b_begin; b < b_end; ++b) {
        const auto pool_begin = offsets_data[t * B + b];
        const auto pool_end = offsets_data[t * B + b + 1];
        for (auto p = pool_begin; p < pool_end; ++p) {
          const int64_t embedding_begin = table_begin + indices_data[p] * D;
          for (int64_t d = 0; d < D; ++d) {
            grad_indice_weights_data[p] +=
                static_cast<at::acc_type<weights_t, true>>(
                    grad_output_data[b][D_begin + d]) *
                weights_data[embedding_begin + d];
          }
        }
      }
    } // for each t
  }); // parallel for
}

Tensor split_embedding_codegen_grad_indice_weights_cpu(
    Tensor grad_output,
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    Tensor indices,
    Tensor offsets,
    Tensor feature_requires_grad) {
  auto grad_indice_weights = zeros_like(
      indices,
      indices.options().dtype(
          at::toAccumulateType(grad_output.scalar_type(), true)));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(),
      "split_embedding_grad_indice_weights_cpu_outer",
      [&] {
        using grad_t = scalar_t;
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            weights.scalar_type(),
            "split_embedding_grad_indice_weights_cpu",
            [&] {
              using weights_t = scalar_t;
              split_embedding_grad_indice_weights_cpu_kernel<weights_t, grad_t>(
                  grad_output,
                  weights,
                  weights_offsets,
                  D_offsets,
                  indices,
                  offsets,
                  feature_requires_grad,
                  grad_indice_weights);
            });
      });
  return grad_indice_weights;
}
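// Math note: for SUM pooling with per-sample weights, the forward pass
// computes output[b, D_begin + d] = sum_p w_p * E[idx_p][d], so the gradient
// with respect to w_p is the dot product accumulated above:
// grad_indice_weights[p] = <grad_output[b, D_begin:D_begin + D], E[idx_p]>.
// This is also why MEAN pooling (which rescales by 1/L) is never combined
// with indice_weights in these kernels.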
namespace internal {

template <typename scalar_t>
void batched_csr2csc(
    BatchedHyperCompressedSparseColumn& batched_csc,
    int B,
    // TODO: use accessor for the following 3 parameters
    const at::TensorAccessor<int64_t, 1>& batched_csr_offsets,
    const at::TensorAccessor<int64_t, 1>& batched_csr_indices,
    const at::TensorAccessor<scalar_t, 1>& batched_csr_weights,
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings) {
  int num_tables = 1;
  batched_csc.num_tables = num_tables;
  batched_csc.table_ptr = static_cast<int*>(
      fbgemm::fbgemmAlignedAlloc(64, (num_tables + 1) * sizeof(int)));
  batched_csc.table_ptr[0] = 0;
  int64_t nnz = batched_csr_offsets[table_to_feature_offset[num_tables] * B] -
      batched_csr_offsets[table_to_feature_offset[0] * B];
  if (nnz == 0) {
    batched_csc.table_ptr[1] = 0;
    return;
  }
  batched_csc.row_indices =
      static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int)));
  bool has_weights = batched_csr_weights.data() != nullptr;
  if (has_weights ||
      static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) {
    batched_csc.weights = static_cast<float*>(
        fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(float)));
  }

  int column_ptr_curr = 0;
  int t = 0;
  bool is_shared_table =
      table_to_feature_offset[t + 1] > table_to_feature_offset[t] + 1;
  auto NS = batched_csr_offsets[table_to_feature_offset[t + 1] * B] -
      batched_csr_offsets[table_to_feature_offset[t] * B];

  int num_non_empty_segments = 0;
  if (!batched_csc.weights) {
    batched_csc.column_segment_ids =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int)));
    int* tmpBufKeys =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    int* tmpBufValues =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    int* tmpBuf1Keys =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    int* tmpBuf1Values =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    const auto FBo = batched_csr_offsets[table_to_feature_offset[t] * B];
    for (int feature = table_to_feature_offset[t];
         feature < table_to_feature_offset[t + 1];
         ++feature) {
      const auto FBs = (feature - table_to_feature_offset[t]) * B;
#pragma omp parallel for
      for (int b = 0; b < B; ++b) {
        const auto FBb = feature * B + b;
        int64_t pool_begin = batched_csr_offsets[FBb];
        int64_t pool_end = batched_csr_offsets[FBb + 1];
        for (int64_t p = pool_begin; p < pool_end; ++p) {
          tmpBufKeys[p - FBo] = batched_csr_indices[p];
          tmpBufValues[p - FBo] = FBs + b;
        }
      }
    }

    int* sorted_col_row_index_keys = nullptr;
    int* sorted_col_row_index_values = nullptr;
    std::tie(sorted_col_row_index_keys, sorted_col_row_index_values) =
        fbgemm_gpu::radix_sort_parallel(
            tmpBufKeys,
            tmpBufValues,
            tmpBuf1Keys,
            tmpBuf1Values,
            NS,
            num_embeddings);
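    // After the parallel radix sort, the (key, value) pairs are ordered by
    // embedding row id: key = column (embedding row), value = feature * B + b,
    // so value / B recovers the feature segment and value % B the batch row.
    // Duplicate keys are now adjacent, and the block below counts unique keys
    // per thread so each thread knows where to write its slice of the column
    // segment arrays. num_uniq is padded to [64] ints per thread so each
    // counter sits on its own cache line, avoiding false sharing.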
    int max_thds = omp_get_max_threads();
    int num_uniq[max_thds][64];
    int U = 0;
    if (at::get_num_threads() > 1) {
      // This block is not needed for the single-thread case
#pragma omp parallel
      {
        int tid = omp_get_thread_num();
        num_uniq[tid][0] = 0;
#pragma omp for schedule(static)
        for (int i = 1; i < NS; i++) {
          if (sorted_col_row_index_keys[i] !=
              sorted_col_row_index_keys[i - 1]) {
            num_uniq[tid][0]++;
          }
        }
      }
      num_uniq[0][0] += 1;
      for (int i = 1; i < max_thds; i++) {
        num_uniq[i][0] += num_uniq[i - 1][0];
      }
      U = num_uniq[max_thds - 1][0];
    }

    batched_csc.column_segment_ptr = static_cast<int*>(
        fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int)));
    batched_csc.column_segment_indices =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    batched_csc.column_segment_ptr[0] = 0;
    batched_csc.row_indices[0] = sorted_col_row_index_values[0] % B;
    batched_csc.column_segment_indices[0] = sorted_col_row_index_keys[0];
    batched_csc.column_segment_ids[0] = sorted_col_row_index_values[0] / B;
#pragma omp parallel
    {
      int tid = omp_get_thread_num();
      int* tstart =
          (tid == 0
               ? batched_csc.column_segment_indices + 1
               : batched_csc.column_segment_indices + num_uniq[tid - 1][0]);
      int* t_offs =
          (tid == 0 ? batched_csc.column_segment_ptr + 1
                    : batched_csc.column_segment_ptr + num_uniq[tid - 1][0]);

      if (!is_shared_table) {
        // For a non-shared table, there is no need to compute the modulo.
        // As an optimization, swap pointers instead of copying.
#pragma omp master
        std::swap(
            batched_csc.row_indices,
            sorted_col_row_index_values == tmpBufValues ? tmpBufValues
                                                        : tmpBuf1Values);
      } else {
#ifdef FBCODE_CAFFE2
        libdivide::divider<int> divisor(B);
#endif
#pragma omp for schedule(static)
        for (int i = 1; i < NS; ++i) {
          int v = sorted_col_row_index_values[i];
#ifdef FBCODE_CAFFE2
          int q = v / divisor;
#else
          int q = v / B;
#endif
          batched_csc.column_segment_ids[i] = q;
          batched_csc.row_indices[i] = v - q * B;
        }
      }

#pragma omp for schedule(static)
      for (int i = 1; i < NS; ++i) {
        if (sorted_col_row_index_keys[i] != sorted_col_row_index_keys[i - 1]) {
          *tstart = sorted_col_row_index_keys[i];
          *t_offs = i;
          tstart++;
          t_offs++;
        }
      }

      if (at::get_num_threads() == 1 && tid == 0) {
        // Special handling of the single-thread case
        U = t_offs - batched_csc.column_segment_ptr;
      }
    } // omp parallel

    batched_csc.table_ptr[t + 1] = batched_csc.table_ptr[t] + U;
    batched_csc.column_segment_ptr[U] = NS;
    column_ptr_curr += NS;

    fbgemm::fbgemmAlignedFree(tmpBufKeys);
    fbgemm::fbgemmAlignedFree(tmpBufValues);
    fbgemm::fbgemmAlignedFree(tmpBuf1Keys);
    fbgemm::fbgemmAlignedFree(tmpBuf1Values);
  } else { // batched_csc.weights
#ifdef FBCODE_CAFFE2
    folly::F14FastMap<
#else
    std::unordered_map<
#endif
        int64_t,
        std::vector<std::vector<std::pair<int, scalar_t>>>>
        non_empty_columns;
    int f_begin = table_to_feature_offset[t];
    int f_end = table_to_feature_offset[t + 1];
    for (int feature = f_begin; feature < f_end; ++feature) {
      for (int b = 0; b < B; ++b) {
        int64_t pool_begin = batched_csr_offsets[feature * B + b];
        int64_t pool_end = batched_csr_offsets[feature * B + b + 1];
        int64_t L = pool_end - pool_begin;
        // MEAN pooling will not work with indice_weights!
        double scale_factor =
            (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
             !has_weights && L > 0)
            ? 1.0 / L
            : 1.0;
        for (int64_t p = pool_begin; p < pool_end; ++p) {
          auto itr = non_empty_columns.find(batched_csr_indices[p]);
          if (itr == non_empty_columns.end()) {
            itr = non_empty_columns
                      .emplace(
                          batched_csr_indices[p],
                          std::vector<std::vector<std::pair<int, scalar_t>>>(
                              f_end - f_begin))
                      .first;
          }
          if (itr->second[feature - f_begin].empty()) {
            ++num_non_empty_segments;
          }
          itr->second[feature - f_begin].emplace_back(
              b, scale_factor * (has_weights ? batched_csr_weights[p] : 1.0f));
        }
      }
    } // for each feature
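    // Flatten the per-column hash map into the CSC arrays: each non-empty
    // (column, feature) pair becomes one column segment, and its
    // (batch row, weight) entries are written out contiguously, with
    // column_segment_ptr recording the running end offset of each segment.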
    batched_csc.table_ptr[t + 1] =
        batched_csc.table_ptr[t] + num_non_empty_segments;
    batched_csc.column_segment_ptr = static_cast<int*>(
        fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int)));
    batched_csc.column_segment_ptr[0] = 0;
    batched_csc.column_segment_indices =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    batched_csc.column_segment_ids =
        static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
    int k = 1;
    for (auto const& column : non_empty_columns) {
      int feature = f_begin;
      for (auto const& column_segment : column.second) {
        if (!column_segment.empty()) {
          batched_csc.column_segment_ptr[k] =
              column_ptr_curr + column_segment.size();
          batched_csc.column_segment_indices[k - 1] = column.first;
          batched_csc.column_segment_ids[k - 1] = feature - f_begin;
          k++;
          for (auto const& non_zero : column_segment) {
            batched_csc.row_indices[column_ptr_curr] = non_zero.first;
            batched_csc.weights[column_ptr_curr] = non_zero.second;
            ++column_ptr_curr;
          }
        }
        ++feature;
      } // for each column segment
    } // for each column
  } // !batched_csc.weights.empty()

  assert(column_ptr_curr == nnz);
}

template void batched_csr2csc<float>(
    BatchedHyperCompressedSparseColumn& batched_csc,
    int B,
    const at::TensorAccessor<int64_t, 1>& batched_csr_offsets,
    const at::TensorAccessor<int64_t, 1>& batched_csr_indices,
    const at::TensorAccessor<float, 1>& batched_csr_weights,
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings);

template void batched_csr2csc<double>(
    BatchedHyperCompressedSparseColumn& batched_csc,
    int B,
    const at::TensorAccessor<int64_t, 1>& batched_csr_offsets,
    const at::TensorAccessor<int64_t, 1>& batched_csr_indices,
    const at::TensorAccessor<double, 1>& batched_csr_weights,
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings);

} // namespace internal
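// Usage sketch (informal, values illustrative): for a single table with
// B = 2, indices = {5, 7, 5} and offsets = {0, 2, 3} (CSR over batch rows),
// batch row 0 holds bags {5, 7} and row 1 holds {5}. The CSC output groups
// by embedding row instead: column 5 -> rows {0, 1}, column 7 -> row {0}.
// This transposed view lets a consumer (presumably the CPU backward pass)
// visit all occurrences of one embedding row contiguously, so its gradient
// can be accumulated by a single thread without atomics.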