fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// clang-format off
#include <map>
#include <tuple>
#include <utility>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>

#include "codegen/embedding_forward_split_cpu.h"
#include "fbgemm/FbgemmEmbedding.h"
#include "fbgemm/Types.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/cpu_utils.h"

using Tensor = at::Tensor;

namespace internal {
template <typename T>
struct half2float16 {
  using type = T;
};

template <>
struct half2float16<at::Half> {
  using type = fbgemm::float16;
};
} // namespace internal

namespace {
template <typename scalar_t, typename grad_t>
void split_embedding_backward_exact_cpu_kernel(
    Tensor grad_output,
    Tensor host_weights,
    const at::TensorAccessor<int64_t, 1> weights_offsets_data,
    const at::TensorAccessor<int, 1> D_offsets_data,
    Tensor hash_size_cumsum,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    int num_tables,
    int B,
    const int* table_to_feature_offset,
    {% if "momentum1_offsets" in args.split_function_arg_names %}
    const at::TensorAccessor<int64_t, 1> momentum1_offsets_data,
    {% endif %}
    {% if "momentum2_offsets" in args.split_function_arg_names %}
    const at::TensorAccessor<int64_t, 1> momentum2_offsets_data,
    {% endif %}
    {{ args.split_cpu_kernel_args | join(", ") }}) {
  const grad_t* grad_output_data = grad_output.data_ptr<grad_t>();
  auto host_weights_data = host_weights.accessor<scalar_t, 1>();
  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();

  const bool has_weights = indice_weights.defined();
  auto grad_stride = grad_output.size(1);

  std::vector<::internal::BatchedHyperCompressedSparseColumn> batched_cscs(
      num_tables);

  auto get_hash_size = [&hash_size_cumsum_data](int feature_begin) {
    int64_t hash_size;
    int t_temp = feature_begin + 1;
    do {
      hash_size =
          hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[feature_begin];
      ++t_temp;
    } while (hash_size == 0);
    TORCH_CHECK(
        hash_size < ((1L << 31) - 1),
        "CPU exact rowwise adagrad currently doesn't support embedding tables "
        "with more than 2B rows");
    return hash_size;
  };

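  // Transpose the batch-major CSR representation (offsets/indices per
  // feature) into one compressed-sparse-column structure per physical table,
  // so that all gradient rows that touch the same embedding row land in a
  // single contiguous column segment.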
  for (int t = 0; t < num_tables; ++t) {
    int feature_begin = table_to_feature_offset[t];
    int64_t hash_size = get_hash_size(feature_begin);

    ::internal::batched_csr2csc(
        batched_cscs[t],
        B,
        offsets.accessor<int64_t, 1>(),
        indices.accessor<int64_t, 1>(),
        indice_weights.defined()
            ? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
            : at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(
                  nullptr, nullptr, nullptr),
        pooling_mode,
        table_to_feature_offset + t,
        hash_size);
  }

  // sort-based csr2csc handles segment_ids differently
  bool is_csr2csc_sort = batched_cscs[0].weights == nullptr;

  for (int t = 0; t < num_tables; ++t) {
    int feature_begin = table_to_feature_offset[t];

    int c_begin = batched_cscs[t].table_ptr[0];
    int c_end = batched_cscs[t].table_ptr[1];
    int* col_segment_ptr = batched_cscs[t].column_segment_ptr;
    int* col_segment_indices = batched_cscs[t].column_segment_indices;

    auto hash_size = get_hash_size(feature_begin);

    const auto D_begin = D_offsets_data[feature_begin];
    const auto D =
        D_offsets_data[feature_begin + 1] - D_offsets_data[feature_begin];
    const auto table_begin = weights_offsets_data[feature_begin];
    bool is_shared_table =
        table_to_feature_offset[t + 1] > table_to_feature_offset[t] + 1;

    {% if optimizer == "rowwise_adagrad" %}
    constexpr bool use_fbgemm = std::is_same<scalar_t, float>::value &&
        std::is_same<scalar_t, grad_t>::value;
    // || std::is_same<scalar_t, at::Half>::value;
    if (use_fbgemm && !is_shared_table) {
      // fbgemm handles common case of no shared table
      using fbgemm_weight_t = typename ::internal::half2float16<scalar_t>::type;
      auto spmdm_kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
          fbgemm_weight_t,
          /*IndexType=*/int32_t,
          /*OffsetType=*/int32_t>(
          D,
          batched_cscs[t].weights != nullptr,
          /*normalize_by_lengths=*/false,
          /*prefetch=*/16,
          /*is_weight_positional=*/false,
          /*use_offsets=*/true,
          /*output_stride=*/-1,
          /*input_stride=*/grad_stride);
      auto rowwise_adagrad_kernel =
          fbgemm::GenerateSparseAdaGrad</*IndexType=*/int>(D, /*rowwise=*/true);

      constexpr int C_BLOCK = 64;
      at::parallel_for(c_begin, c_end, C_BLOCK, [&](int64_t c0, int64_t c1) {
        grad_t grad_blocked_buffer[C_BLOCK * D];
        for (int64_t c = c0; c < c1; c += C_BLOCK) {
          const int* offsets_begin_ptr = col_segment_ptr + c;
          int64_t c_block_end = std::min(c + C_BLOCK, c1);
          bool success = spmdm_kernel(
              c_block_end - c,
              col_segment_ptr[c_block_end] - *offsets_begin_ptr,
              B,
              reinterpret_cast<const fbgemm_weight_t*>(
                  grad_output_data + D_begin),
              batched_cscs[t].row_indices + *offsets_begin_ptr,
              offsets_begin_ptr,
              batched_cscs[t].weights == nullptr
                  ? nullptr
                  : batched_cscs[t].weights + *offsets_begin_ptr,
              reinterpret_cast<float*>(grad_blocked_buffer));

          if (!success) {
            fbgemm_gpu::report_embedding_error(
                t,
                B,
                c,
                c_block_end,
                col_segment_ptr,
                batched_cscs[t].row_indices,
                hash_size,
                /*allow_minus_one=*/false);
          }

          int num_rows_processed = rowwise_adagrad_kernel(
              c_block_end - c,
              hash_size * D,
              reinterpret_cast<float*>(&host_weights_data[table_begin]),
              reinterpret_cast<const float*>(grad_blocked_buffer),
              reinterpret_cast<float*>(
                  &momentum1_host[momentum1_offsets_data[feature_begin]]),
              col_segment_indices + c,
              eps,
              -learning_rate,
              /*weight_decay=*/0,
              /*counter=*/nullptr,
              /*counter_halflife=*/0);

          TORCH_CHECK(
              num_rows_processed == c_block_end - c,
              "num of rows processed by adagrad: ",
              num_rows_processed,
              " does not match c_block size: ",
              c_block_end - c);
        } // for each c
      }); // parallel for
    } else
    {% endif %}
    {
      // no fbgemm
      // TODO: to parallelize, we should easily identify segments that belong
      // to the same column.
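      // Reference path: walk the column segments in order, accumulate the
      // pooled gradient for one embedding row into grad_buffer, and apply the
      // optimizer update ({{ split_weight_update_cpu }}) once the segment for
      // that row ends.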
      at::acc_type<grad_t, true> grad_buffer[D];
      for (int c = c_begin; c < c_end; ++c) {
        int64_t idx = col_segment_indices[c];
        if (c == c_begin || col_segment_indices[c - 1] != idx) {
          memset(grad_buffer, 0, D * sizeof(at::acc_type<grad_t, true>));
        }
        const int64_t embedding_begin = table_begin + idx * D;
        for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
          int D_offset = D_begin;
          if (is_shared_table) {
            D_offset +=
                batched_cscs[t].column_segment_ids[is_csr2csc_sort ? r : c] * D;
          }
          int b = batched_cscs[t].row_indices[r];
          for (int64_t d = 0; d < D; ++d) {
            if (batched_cscs[t].weights != nullptr) {
              grad_buffer[d] +=
                  grad_output_data[b * grad_stride + D_offset + d] *
                  batched_cscs[t].weights[r];
            } else {
              grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d];
            }
          }
        }
        if (c == c_end - 1 || col_segment_indices[c + 1] != idx) {
          {{ split_weight_update_cpu }}
        }
      } // for each c
    } // no fbgemm
  } // for each table
}

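// Dense variant: instead of applying an optimizer update in place, accumulate
// dL/dW into a dense gradient tensor laid out like host_weights.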
template <typename scalar_t>
void split_embedding_backward_exact_cpu_dense_kernel(
    Tensor grad,
    Tensor grad_output,
    const at::TensorAccessor<int64_t, 1> weights_offsets_data,
    const at::TensorAccessor<int, 1> D_offsets_data,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    int num_tables,
    int B,
    const int* table_to_feature_offset) {
  auto grad_data = grad.data_ptr<scalar_t>();

  auto grad_output_data = grad_output.accessor<scalar_t, 2>();

  const auto indices_data = indices.accessor<int64_t, 1>();
  const auto offsets_data = offsets.accessor<int64_t, 1>();
  const auto indice_weights_data = indice_weights.defined()
      ?
      // If indice_weights are not defined, then this accessor won't be
      // used
      indice_weights.accessor<scalar_t, 1>()
      : grad.accessor<scalar_t, 1>(); // this is just to make compiler
                                      // happy

  at::parallel_for(0, num_tables, 0, [&](int64_t t_begin, int64_t t_end) {
    for (int64_t t = table_to_feature_offset[t_begin];
         t < table_to_feature_offset[t_end];
         ++t) {
      const auto D_begin = D_offsets_data[t];
      const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
      const auto table_begin = weights_offsets_data[t];
      for (int64_t b = 0; b < B; ++b) {
        const auto pool_begin = offsets_data[t * B + b];
        const auto pool_end = offsets_data[t * B + b + 1];
        const auto L = pool_end - pool_begin;
        const scalar_t scale_factor =
            // NOTE: MEAN pooling will not work with indice_weights!
            (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
             !indice_weights.defined() && L > 0)
            ? 1.0 / L
            : 1.0;
        for (auto p = pool_begin; p < pool_end; ++p) {
          const int64_t embedding_begin = table_begin + indices_data[p] * D;
          const scalar_t v = indice_weights.defined()
              ? (indice_weights_data[p] * scale_factor)
              : scale_factor;
          for (int64_t d = 0; d < D; ++d) {
            grad_data[embedding_begin + d] +=
                grad_output_data[b][D_begin + d] * v;
          }
        }
      }
    }
  }); // parallel_for
}

} // namespace

// The template for exact optimizers
{{ "void" if not dense else "Tensor" }} split_embedding_backward_codegen_{{ optimizer }}_cpu(
    Tensor grad_output,
    Tensor host_weights,
    {% if not dense %}
    Tensor weights_placements,
    {% endif %}
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t max_D,
    Tensor hash_size_cumsum,
    int64_t total_hash_size_bits,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    {% if not dense %}
    bool stochastic_rounding,
    {{ args.split_function_args | join(", ") }},
    int64_t output_dtype
    {% else %}
    {{ args.split_function_args | join(", ") }}
    {% endif %}
) {
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK(T > 0);

  // offsets = [T x B + 1]
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B >= 0);

  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
  const auto D_offsets_data = D_offsets.accessor<int, 1>();
  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();

  int num_tables = 0; // # of physical tables
  int table_to_feature_offset[T + 1];
  table_to_feature_offset[0] = 0;
  for (int feature = 0; feature < T - 1; ++feature) {
    if (hash_size_cumsum_data[feature + 1] != hash_size_cumsum_data[feature]) {
      ++num_tables;
      table_to_feature_offset[num_tables] = feature + 1;
    }
  }
  ++num_tables;
  table_to_feature_offset[num_tables] = T;

  TORCH_CHECK(host_weights.dim() == 1);

  {% if not dense %}
  {% if "momentum1_offsets" in args.split_function_arg_names %}
  const auto momentum1_offsets_data = momentum1_offsets.accessor<int64_t, 1>();
  {% endif %}
  {% if "momentum2_offsets" in args.split_function_arg_names %}
  const auto momentum2_offsets_data = momentum2_offsets.accessor<int64_t, 1>();
  {% endif %}

  grad_output = grad_output.contiguous();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
        // TODO: respect output_dtype
        using grad_t = float;
        split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
            grad_output,
            host_weights,
            weights_offsets_data,
            D_offsets_data,
            hash_size_cumsum,
            indices,
            offsets,
            pooling_mode,
            indice_weights,
            num_tables,
            B,
            table_to_feature_offset,
            {% if "momentum1_offsets" in args.split_function_arg_names %}
            momentum1_offsets_data,
            {% endif %}
            {% if "momentum2_offsets" in args.split_function_arg_names %}
            momentum2_offsets_data,
            {% endif %}
            {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
      });

  return;

  {% else %}

  // When input is dense enough, avoid sorting and just treat as dense.
  auto grad = zeros_like(host_weights, grad_output.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
        split_embedding_backward_exact_cpu_dense_kernel<scalar_t>(
            grad,
            grad_output,
            weights_offsets_data,
            D_offsets_data,
            indices,
            offsets,
            pooling_mode,
            indice_weights,
            num_tables,
            B,
            table_to_feature_offset);
      }); // dispatch grad_output.scalar_type()

  return grad;
  {% endif %}
}
// clang-format on