fbgemm_gpu/codegen/embedding_backward_split_template.cu (1,085 lines of code) (raw):

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// clang-format off
{% set wdesc = "weighted" if weighted else "unweighted" %}
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/split_embeddings_utils.cuh"

{% if not dense %}
constexpr int32_t kCacheLocationMissing = -1;
{% endif %}

constexpr size_t kBackwardMaxThreads = 512;

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

// Collects the ids of all runs whose length is at least
// max_segment_length_per_warp.
//
// The run count is read from sorted_linear_indices_num_runs[0]; every thread
// strides over the run ids by the total number of launched threads. For each
// qualifying run a slot is reserved with gpuAtomicIncrement on
// num_long_run_ids[0] and the run id is written to long_run_ids at that slot,
// so the order of the ids in long_run_ids follows atomic arrival order.
// NOTE(review): assumes gpuAtomicIncrement returns the pre-increment value
// and that num_long_run_ids[0] is zero on entry -- confirm against the helper
// and the caller.
__global__ void split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_find_long_segments(
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_num_runs,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run_lengths,
    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> num_long_run_ids,
    int32_t max_segment_length_per_warp) {
  const int32_t total_runs = sorted_linear_indices_num_runs[0];
  const auto step = blockDim.x * gridDim.x;
  for (auto r = blockIdx.x * blockDim.x + threadIdx.x; r < total_runs; r += step) {
    // Short runs are not recorded.
    if (sorted_linear_indices_run_lengths[r] < max_segment_length_per_warp) {
      continue;
    }
    const auto slot = gpuAtomicIncrement(&num_long_run_ids[0]);
    long_run_ids[slot] = r;
  }
}

// Rescales each bag's slice of grad_output by 1 / L, where L is the bag
// length taken from two consecutive entries of offsets, and writes the result
// to grad_output_mean. Slices that belong to empty bags (L == 0) are copied
// through unscaled.
//
// One threadIdx.y row handles one (table t, sample b) pair, identified by
// b_t = blockIdx.x * blockDim.y + threadIdx.y; the threads along x each move
// four consecutive elements per step through Vec4T.
// NOTE(review): there is no tail handling for the Vec4T access, so each
// table's D is presumably a multiple of 4 -- confirm with the caller.
template <typename grad_t>
__global__ void __launch_bounds__(kMaxThreads) grad_mean_kernel(
    const at::PackedTensorAccessor32<grad_t, 2, at::RestrictPtrTraits> grad_output,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
    at::PackedTensorAccessor32<grad_t, 2, at::RestrictPtrTraits> grad_output_mean) {
  const int32_t B = grad_output.size(0);
  const int32_t T = D_offsets.size(0) - 1;
  const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
  // Rows past the last (table, sample) pair have nothing to do.
  if (b_t >= B * T) {
    return;
  }
  const int32_t b = b_t % B;
  const int32_t t = b_t / B;

  // Column range of table t inside the concatenated gradient row.
  const int32_t D_start = D_offsets[t];
  const int32_t D = D_offsets[t + 1] - D_start;

  // Bag length for this (t, b) pair.
  const int64_t bag_begin = offsets[t * B + b];
  const int64_t bag_end = offsets[t * B + b + 1];
  const int32_t L = bag_end - bag_begin;
  const bool rescale = (L != 0);

  for (int32_t d = threadIdx.x; d * 4 < D; d += blockDim.x) {
    const int32_t col = D_start + d * 4;
    Vec4T<grad_t> grad_out_vec(&grad_output[b][col]);
    if (rescale) {
      grad_out_vec.mul_(1.0 / L);
    }
    grad_out_vec.store(&grad_output_mean[b][col]);
  }
}

{% for nobag in [True, False] %}
{% if not nobag or not weighted %}
template <
    typename emb_t,
    typename grad_t,
    typename cache_t,
    size_t kMaxVecsPerThread>
__global__ void
__launch_bounds__(kMaxThreads)
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1(
    const at::PackedTensorAccessor32<grad_t, 2, at::RestrictPtrTraits> grad_output,
    at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
    {% if not dense %}
    at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
    at::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
    {% endif %}
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
    {% if not nobag %}
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
    {% else %}
    int32_t B,
    int64_t D,
    {% endif %}
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
    const at::PackedTensorAccessor32<int32_t, 1,
at::RestrictPtrTraits> long_run_ids,
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> num_long_run_ids,
    {% if not nobag %}
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
    {% else %}
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_infos,
    {% endif %}
    {% if not dense %}
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_lxu_cache_locations,
    {% endif %}
    {% if weighted %}
    const at::PackedTensorAccessor32<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
    {% endif %}
    {% if not dense %}
    bool stochastic_rounding,
    at::PhiloxCudaState stochastic_rounding_philox_args,
    {% else %}
    at::PackedTensorAccessor64<cache_t, 1, at::RestrictPtrTraits> grad_dev_weights,
    {% endif %}
    {% if not nobag %}
    FixedDivisor fd,
    {% endif %}
    {{ args.split_kernel_args | join(", ") }}) {
  // CTA-per-row backward kernel for the "long" segments.
  //
  // Every thread block walks the run ids stored in long_run_ids, starting at
  // blockIdx.x and stepping by gridDim.x, for num_long_run_ids[0] entries.
  // For one run the warps of the block (threadIdx.y) split the SL entries of
  // the segment between them, accumulate partial gradient sums in registers,
  // combine them through dynamic shared memory, and warp 0 then applies the
  // combined gradient to the embedding row (or, in the dense variant, stores
  // it to grad_dev_weights).
  //
  // The launch in this file uses blockDim = (kWarpSize, BT_block_size) and
  // BT_block_size * sizeof(acc_type) * 4 * kWarpSize * kMaxVecsPerThread bytes
  // of dynamic shared memory. The tree reduction below covers power-of-two
  // blockDim.y values up to 32.
  {% if not nobag %}
  int32_t T = D_offsets.size(0) - 1;
  const int32_t B = grad_output.size(0);
  {% else %}
  int32_t T = weights_offsets.size(0);
  {% endif %}
  const int32_t num_long_runs = num_long_run_ids[0];
  for (int32_t long_run_id = blockIdx.x; long_run_id < num_long_runs; long_run_id += gridDim.x) {
    int32_t current_run_id = long_run_ids[long_run_id];
    const int64_t linear_index = sorted_linear_indices_run[current_run_id];
    const int32_t segment_start = sorted_linear_indices_cumulative_run_lengths[current_run_id];
    const int32_t segment_end = sorted_linear_indices_cumulative_run_lengths[current_run_id + 1];
    const int32_t SL = segment_end - segment_start;
    const int32_t warp_id = threadIdx.y;
    const int32_t lane_id = threadIdx.x;

    // Note that with shared embedding tables we can have multiple tables
    // (i.e. different values of `t` sharing the same segment).
    //
    const auto info_0 = sorted_infos[segment_start];
    {% if not nobag %}
    int32_t t_0 = fd.Div(info_0); //info_0 / B;
    {% else %}
    int32_t t_0 = info_0 % T;
    {% endif %}
    int64_t hash_size = hash_size_cumsum[t_0];
    {% if not nobag %}
    int32_t D = D_offsets[t_0 + 1] - D_offsets[t_0];
    {% endif %}
    // Row index inside table t_0.
    int64_t idx = linear_index - hash_size;

    // Each warp takes a contiguous slice [sl_start, sl_end) of the segment.
    const int32_t SL_per_warp = div_round_up(SL, blockDim.y);
    const int32_t sl_start = SL_per_warp * warp_id;
    const int32_t sl_end = min(SL_per_warp * (warp_id + 1), SL);
    Vec4T<at::acc_type<cache_t, true>> grad_sum[kMaxVecsPerThread];
    for (int32_t sl = sl_start; sl < sl_end; sl += kWarpSize) {
      // Each lane fetches the bookkeeping of one segment entry; the values
      // are then broadcast lane by lane with shfl_sync in the loop below.
      int32_t sl_j = sl + threadIdx.x;
      {% if not nobag %}
      int32_t b_t = sl_j < sl_end ? sorted_infos[segment_start + sl_j] : 0;
      int32_t b; //= b_t % B;
      int32_t t; //= b_t / B;
      fd.DivMod(b_t, &t, &b);
      int32_t D_start = sl_j < sl_end ? D_offsets[t] : 0;
      {% else %}
      int64_t l_t = sl_j < sl_end ? sorted_infos[segment_start + sl_j] : 0;
      int32_t l = l_t / T;
      {% endif %}
      {% if weighted %}
      at::acc_type<cache_t, true> idx_weight = sl_j < sl_end ? sorted_indice_weights[segment_start + sl_j] : 0.0;
      {% endif %}
      for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) {
        {% if not nobag %}
        int32_t b_j = shfl_sync(b, j);
        int32_t D_start_j = shfl_sync(D_start, j);
        {% else %}
        int32_t l_j = shfl_sync(l, j);
        {% endif %}
        {% if weighted %}
        at::acc_type<cache_t, true> idx_weight_j = shfl_sync(idx_weight, j);
        {% endif %}
        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
          int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
          {% if not nobag %}
          Vec4T<at::acc_type<grad_t, true>> grad_out_vec(
              &grad_output[b_j][0] + D_start_j + d);
          {% else %}
          Vec4T<at::acc_type<grad_t, true>> grad_out_vec(&grad_output[l_j][d]);
          {% endif %}
          {% if weighted %}
          grad_sum[i].fma_(grad_out_vec, idx_weight_j);
          {% else %}
          grad_sum[i].add_(grad_out_vec);
          {% endif %}
        }
      }
    }

    // do shared memory reduction only if we used multiple warps.
    // SL and SL_per_warp are identical for every thread of the block, so the
    // barriers inside this branch are reached by all threads or by none.
    if (SL > SL_per_warp) {
      struct SharedMemory<Vec4T<at::acc_type<cache_t, true>>> smem;
      Vec4T<at::acc_type<cache_t, true>>* shared_grad_sums = smem.getPointer();

      // Publish this warp's partial sums in its own shared-memory slot.
      #pragma unroll kMaxVecsPerThread
      for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
        shared_grad_sums[lane_id + i * kWarpSize + warp_id * kMaxVecsPerThread * kWarpSize] = grad_sum[i];
      }
      __syncthreads();

      // Tree reduction over warps: 32 -> 16 -> 8 -> 4 -> 2 live slots. A
      // stage only runs when the block actually has 2 * half warps.
      for (int32_t half = 16; half >= 2; half /= 2) {
        if (blockDim.y >= 2 * half) {
          if (warp_id < half) {
            #pragma unroll kMaxVecsPerThread
            for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
              shared_grad_sums[lane_id + i * kWarpSize + warp_id * kMaxVecsPerThread * kWarpSize] = vec4_acc(
                  shared_grad_sums[lane_id + i * kWarpSize + warp_id * kMaxVecsPerThread * kWarpSize],
                  shared_grad_sums[lane_id + i * kWarpSize + (warp_id + half) * kMaxVecsPerThread * kWarpSize]);
            }
          }
          __syncthreads();
        }
      }

      // Last stage: warp 0 folds slot 1 into its registers.
      if (warp_id == 0) {
        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
          grad_sum[i] = vec4_acc(
              shared_grad_sums[lane_id + i * kWarpSize + warp_id * kMaxVecsPerThread * kWarpSize],
              shared_grad_sums[lane_id + i * kWarpSize + (warp_id + 1) * kMaxVecsPerThread * kWarpSize]);
        }
      }
    }

    if (warp_id == 0) {
      int64_t weights_offset = weights_offsets[t_0];
      {% if not dense %}
      emb_t* __restrict__ weights{nullptr};
      cache_t* __restrict__ cache_weights{nullptr};
      int32_t D_emb = D;
      if (std::is_same<emb_t, uint8_t>::value) {
        D_emb += kINT8QparamsBytes;
      }
      const auto weights_placement = static_cast<PlacementType>(weights_placements[t_0]);
      if (weights_placement == PlacementType::DEVICE) {
        weights = &dev_weights[weights_offset + idx * D_emb];
      } else {
        weights = &uvm_weights[weights_offset + idx * D_emb];
      }
      if (weights_placement == PlacementType::MANAGED_CACHING) {
        int32_t cache_idx = sorted_lxu_cache_locations[segment_start];
        if (cache_idx != kCacheLocationMissing) {
          cache_weights = &lxu_cache_weights[cache_idx][0];
        }
      }
      {% for tensor in args.split_tensors %}
      at::acc_type<cache_t, true>* __restrict__ {{ tensor }};
      const auto {{ tensor }}_placement = static_cast<PlacementType>({{ tensor }}_placements[t_0]);
      int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t_0];
      if ({{ tensor }}_placement == PlacementType::DEVICE) {
        {{ tensor }} = &{{ tensor }}_dev[{{ tensor }}_offset];
      } else {
        {{ tensor }} = &{{ tensor }}_uvm[{{ tensor }}_offset];
      }
      {% endfor %}

      struct SharedMemory<Vec4T<at::acc_type<cache_t, true>>> weight_update_buffer;
      Vec4T<at::acc_type<cache_t, true>>* shared_weight_update_row = weight_update_buffer.getPointer();

      // The RNG state must outlive every store() issued through
      // weight_row_template below, because the row object is handed its
      // address via set_stoc_state(). It is therefore declared in this scope
      // and not inside the conditional that initialises it.
      StochasticRoundingRNGState stochastic_rounding_state;
      auto weight_row_template = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
      if (!std::is_same<emb_t, float>::value && stochastic_rounding) {
        // different for every *run* and every *thread*.
        auto stochastic_rounding_seeds =
            at::cuda::philox::unpack(stochastic_rounding_philox_args);
        stochastic_rounding_init(
            std::get<0>(stochastic_rounding_seeds) ^
                std::get<1>(stochastic_rounding_seeds),
            threadIdx.x + current_run_id * blockDim.x,
            &stochastic_rounding_state);
        weight_row_template.set_stoc_state(&stochastic_rounding_state);
      }

      float2 qparams_template;
      if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
        qparams_template = weight_row_template.load_qparams();
      }

      {{ split_precomputation }}

      float2 qparams_new;
      #pragma unroll kMaxVecsPerThread
      for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
        int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
        Vec4T<at::acc_type<cache_t, true>> weight_new = weight_row_template.load(d, qparams_template);
        auto& grad = grad_sum[i];
        {{ split_weight_update }}
        if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
          // int8 rows are staged in shared memory first: the new qparams can
          // only be computed once the whole updated row is known.
          shared_weight_update_row[lane_id + i * kWarpSize] = weight_new;
        } else {
          weight_row_template.store(weight_new, d, qparams_new); // qparams_new not used if embedding is not int8
        }
      }
      if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
        // calculate qparams from updated weight row
        qparams_new = thrust_find_qparams<at::acc_type<cache_t, true>>(shared_weight_update_row, D);
        weight_row_template.store_qparams(qparams_new);

        #pragma unroll kMaxVecsPerThread
        for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
          int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
          weight_row_template.store(shared_weight_update_row[lane_id + i * kWarpSize], d, qparams_new);
        }
      }
      {% else %}
      #pragma unroll kMaxVecsPerThread
      for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
        int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
        auto& grad = grad_sum[i];
        grad.store(&grad_dev_weights[weights_offset + idx * D + d]);
      }
      {% endif %}
    }

    // Shared memory is reused by the next run handled by this block. Without
    // this barrier a warp that is already done with the current run could
    // overwrite its slot while warp 0 is still reading it in the last
    // reduction stage. The loop bounds depend only on blockIdx.x/gridDim.x,
    // so every thread of the block reaches this barrier the same number of
    // times.
    __syncthreads();
  }
}

template <
    typename emb_t,
    typename grad_t,
    typename cache_t,
    size_t kMaxVecsPerThread>
__global__
__launch_bounds__(kBackwardMaxThreads)
void
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1(
    const at::PackedTensorAccessor32<grad_t, 2, at::RestrictPtrTraits> grad_output,
    at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
    {% if not dense %}
    at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
    at::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
    {% endif %}
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
    {% if not nobag %}
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
    {% else %}
    int32_t B,
    int64_t D,
    {% endif %}
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
    {% if not nobag %}
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
    {% else %}
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_infos,
    {% endif %}
    {% if not dense %}
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_lxu_cache_locations,
    {% endif %}
    {% if weighted %}
    const at::PackedTensorAccessor32<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
    {% endif %}
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_num_runs,
    int32_t max_segment_length_per_warp,
    {% if not dense %}
    bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args,
    {% else %}
    at::PackedTensorAccessor64<cache_t, 1, at::RestrictPtrTraits> grad_dev_weights,
    {% endif %}
    {% if not nobag %}
    FixedDivisor fd,
    {% endif %}
    {{ args.split_kernel_args | join(", ") }}) {
  // Warp-per-row backward kernel for the "short" segments.
  //
  // Each threadIdx.y row owns run_id = blockIdx.x * blockDim.y + threadIdx.y,
  // sums the gradient rows of that run's segment in registers and applies the
  // update itself (or, in the dense variant, stores the sum to
  // grad_dev_weights). Runs with SL >= max_segment_length_per_warp are
  // skipped here.
  //
  // All early returns below depend only on run_id and SL, which are shared by
  // the lanes of one threadIdx.y row; the launch in this file uses
  // blockDim.x == kWarpSize, so a whole warp leaves together before any
  // shfl_sync is issued.
  {% if not nobag %}
  int32_t T = D_offsets.size(0) - 1;
  const int32_t B = grad_output.size(0);
  {% else %}
  int32_t T = weights_offsets.size(0);
  {% endif %}
  const int32_t run_id = blockIdx.x * blockDim.y + threadIdx.y;

  if (run_id >= sorted_linear_indices_run.size(0)) {
    return;
  }
  if (run_id >= sorted_linear_indices_num_runs[0]) {
    return;
  }

  const int64_t linear_index = sorted_linear_indices_run[run_id];
  const int32_t segment_start = sorted_linear_indices_cumulative_run_lengths[run_id];
  const int32_t segment_end = sorted_linear_indices_cumulative_run_lengths[run_id + 1];
  const int32_t SL = segment_end - segment_start;

  if (SL >= max_segment_length_per_warp) {
    return;
  }

  // now, each segment corresponds to exactly one table `t` and row in
  // that table (`idx`). Thus, we can hoist out some of the book-keeping.
  const auto info_0 = sorted_infos[segment_start];
  {% if not nobag %}
  int32_t t_0 = fd.Div(info_0); // info_0 / B;
  {% else %}
  int32_t t_0 = info_0 % T;
  {% endif %}
  int64_t hash_size = hash_size_cumsum[t_0];
  {% if not nobag %}
  int32_t D = D_offsets[t_0 + 1] - D_offsets[t_0];
  {% endif %}
  // Row index inside table t_0.
  int64_t idx = linear_index - hash_size;

  const int32_t SL_per_warp = div_round_up(SL, blockDim.y);
  // A single warp covers the whole segment.
  const int32_t sl_start = 0;
  const int32_t sl_end = SL;
  Vec4T<at::acc_type<cache_t, true>> grad_sum[kMaxVecsPerThread];
  for (int32_t sl = sl_start; sl < sl_end; sl += kWarpSize) {
    // Each lane fetches the bookkeeping of one segment entry; the values are
    // then broadcast lane by lane with shfl_sync in the loop below.
    int32_t sl_j = sl + threadIdx.x;
    {% if not nobag %}
    int32_t b_t = sl_j < sl_end ? sorted_infos[segment_start + sl_j] : 0;
    int32_t b; //= b_t % B;
    int32_t t; //= b_t / B;
    fd.DivMod(b_t, &t, &b);
    int32_t D_start = D_offsets[t];
    {% else %}
    int64_t l_t = sl_j < sl_end ? sorted_infos[segment_start + sl_j] : 0;
    int32_t l = l_t / T;
    {% endif %}
    {% if weighted %}
    at::acc_type<cache_t, true> idx_weight = sl_j < sl_end ? sorted_indice_weights[segment_start + sl_j] : 0.0;
    {% endif %}
    for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) {
      {% if not nobag %}
      int32_t b_j = shfl_sync(b, j);
      int32_t D_start_j = shfl_sync(D_start, j);
      {% else %}
      int32_t l_j = shfl_sync(l, j);
      {% endif %}
      {% if weighted %}
      at::acc_type<cache_t, true> idx_weight_j = shfl_sync(idx_weight, j);
      {% endif %}
      #pragma unroll kMaxVecsPerThread
      for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
        int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
        {% if not nobag %}
        Vec4T<at::acc_type<grad_t, true>> grad_out_vec(
            &grad_output[b_j][0] + D_start_j + d);
        {% else %}
        Vec4T<at::acc_type<grad_t, true>> grad_out_vec(&grad_output[l_j][d]);
        {% endif %}
        {% if weighted %}
        grad_sum[i].fma_(grad_out_vec, idx_weight_j);
        {% else %}
        grad_sum[i].add_(grad_out_vec);
        {% endif %}
      }
    }
  }

  int64_t weights_offset = weights_offsets[t_0];
  {% if not dense %}
  emb_t* __restrict__ weights{nullptr};
  cache_t* __restrict__ cache_weights{nullptr};
  int32_t D_emb = D;
  if (std::is_same<emb_t, uint8_t>::value) {
    D_emb += kINT8QparamsBytes;
  }
  const auto weights_placement = static_cast<PlacementType>(weights_placements[t_0]);
  if (weights_placement == PlacementType::DEVICE) {
    weights = &dev_weights[weights_offset + idx * D_emb];
  } else {
    weights = &uvm_weights[weights_offset + idx * D_emb];
  }
  if (weights_placement == PlacementType::MANAGED_CACHING) {
    int32_t cache_idx = sorted_lxu_cache_locations[segment_start];
    if (cache_idx != kCacheLocationMissing) {
      cache_weights = &lxu_cache_weights[cache_idx][0];
    }
  }
  {% for tensor in args.split_tensors %}
  at::acc_type<cache_t, true>* __restrict__ {{ tensor }};
  const auto {{ tensor }}_placement = static_cast<PlacementType>({{ tensor }}_placements[t_0]);
  int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t_0];
  if ({{ tensor }}_placement == PlacementType::DEVICE) {
    {{ tensor }} = &{{ tensor }}_dev[{{ tensor }}_offset];
  } else {
    {{ tensor }} = &{{ tensor }}_uvm[{{ tensor }}_offset];
  }
  {% endfor %}

  struct SharedMemory<Vec4T<at::acc_type<cache_t, true>>> weight_update_buffer;
  Vec4T<at::acc_type<cache_t, true>>* shared_weight_update_row = weight_update_buffer.getPointer();

  // The RNG state must outlive every store() issued through
  // weight_row_template below, because the row object is handed its address
  // via set_stoc_state(). It is therefore declared in this scope and not
  // inside the conditional that initialises it.
  StochasticRoundingRNGState stochastic_rounding_state;
  auto weight_row_template = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
  if (!std::is_same<emb_t, float>::value && stochastic_rounding) {
    // different for every *run* and every *thread*.
    auto stochastic_rounding_seeds =
        at::cuda::philox::unpack(stochastic_rounding_philox_args);
    stochastic_rounding_init(
        std::get<0>(stochastic_rounding_seeds) ^
            std::get<1>(stochastic_rounding_seeds),
        threadIdx.x + run_id * blockDim.x,
        &stochastic_rounding_state);
    weight_row_template.set_stoc_state(&stochastic_rounding_state);
  }

  float2 qparams_template;
  if (std::is_same<emb_t, uint8_t>::value && !cache_weights){
    qparams_template = weight_row_template.load_qparams();
  }

  {{ split_precomputation }}

  float2 qparams_new;
  #pragma unroll kMaxVecsPerThread
  for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
    int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
    Vec4T<at::acc_type<cache_t, true>> weight_new = weight_row_template.load(d, qparams_template);
    auto& grad = grad_sum[i];
    {{ split_weight_update }}
    if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
      // int8 rows are staged in this warp's shared-memory slot (indexed by
      // threadIdx.y) first: the new qparams can only be computed once the
      // whole updated row is known.
      shared_weight_update_row[threadIdx.x + i * kWarpSize + threadIdx.y * kMaxVecsPerThread * kWarpSize] = weight_new;
    } else {
      weight_row_template.store(weight_new, d, qparams_new); // qparams_new not used if type is not int8
    }
  }
  if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
    // calculate new qparams after row update
    qparams_new = thrust_find_qparams<at::acc_type<cache_t, true>>(&shared_weight_update_row[threadIdx.y * kMaxVecsPerThread * kWarpSize], D);
    weight_row_template.store_qparams(qparams_new);

    // fetch cached updated row from shared mem and quantize on-the-fly when saving to lowp embedding
    #pragma unroll kMaxVecsPerThread
    for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
      int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
      weight_row_template.store(shared_weight_update_row[threadIdx.x + i * kWarpSize + threadIdx.y * kMaxVecsPerThread * kWarpSize], d, qparams_new);
    }
  }
  {% else %}
  #pragma unroll kMaxVecsPerThread
  for (int32_t i = 0; i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; ++i) {
    int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
    auto& grad = grad_sum[i];
    grad.store(&grad_dev_weights[weights_offset + idx * D + d]);
  }
  {% endif %}
}

{{ "void" if not dense else "Tensor" }} split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact_cuda(
    Tensor grad_output,
    Tensor dev_weights,
    {% if not dense %}
    Tensor uvm_weights,
    Tensor lxu_cache_weights,
    Tensor weights_placements,
    {% endif %}
    Tensor weights_offsets,
    {% if not nobag %}
    Tensor D_offsets,
    int64_t max_D,
    {% else %}
    int64_t D,
    {% endif %}
    Tensor hash_size_cumsum,
    int64_t total_hash_size_bits,
    Tensor indices,
    Tensor offsets,
    {% if not nobag %}
    int64_t pooling_mode,
    {% endif %}
    {% if weighted %}
    Tensor indice_weights,
    {% endif %}
    {% if not dense %}
    Tensor lxu_cache_locations,
    {% endif %}
    int64_t unused_,
    int64_t max_segment_length_per_warp,
    {% if not dense %}
    bool stochastic_rounding,
    {% endif %}
    {{ args.split_function_args | join(", ") }}) {
  TENSOR_ON_CUDA_GPU(grad_output);
  TENSOR_ON_CUDA_GPU(dev_weights);
  {% if not dense %}
  TENSOR_ON_CUDA_GPU(uvm_weights);
  TENSOR_ON_CUDA_GPU(lxu_cache_weights);
  TENSOR_ON_CUDA_GPU(weights_placements);
  {% endif %}
  TENSOR_ON_CUDA_GPU(weights_offsets);
  {% if not nobag %}
  TENSOR_ON_CUDA_GPU(D_offsets);
  {% endif %}
  TENSOR_ON_CUDA_GPU(hash_size_cumsum);
  TENSOR_ON_CUDA_GPU(indices);
  TENSOR_ON_CUDA_GPU(offsets);
  {% if weighted %}
  TENSOR_ON_CUDA_GPU(indice_weights);
  {% endif %}
  {% if not
dense %} TENSOR_ON_CUDA_GPU(lxu_cache_locations); {% endif %} at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(dev_weights.get_device()); {% if dense %} auto grad_dev_weights = zeros_like(dev_weights); {% endif %} // short-circuit if there are zero indices. if (indices.numel() == 0) { return {{ "grad_dev_weights" if dense else "" }}; } {% if not nobag %} int32_t T = D_offsets.numel() - 1; {% else %} int32_t T = weights_offsets.numel(); {% endif %} TORCH_CHECK(T > 0); // offsets = [B x T + 1] const auto B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); auto BT_block_size = kMaxThreads / kWarpSize; TORCH_CHECK(BT_block_size * kWarpSize <= kMaxThreads); {% if not nobag %} TORCH_CHECK(max_D <= {{ max_embedding_dim }}); {% else %} TORCH_CHECK(D <= {{ max_embedding_dim }}); {% endif %} // V100: 96 KB; A100: 160 KB. int max_shared_bytes = 0; #ifndef __HIP_PLATFORM_HCC__ cudaDeviceGetAttribute(&max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_weights.get_device()); #else // MI100 has 64 KB local memory (shared memory) per workgroup max_shared_bytes = 64 << 10; #endif C10_CUDA_KERNEL_LAUNCH_CHECK(); int shared_kb = max_shared_bytes >> 10; // V100: 64 KB; A100: 96 KB. // Use 2/3 of the available GPU shared mem; leave rooms for L1$. 
int used_shared_kb = round_down(shared_kb * 2 / 3, 16); TORCH_CHECK(used_shared_kb > 0); int used_shared_bytes = used_shared_kb << 10; Tensor linear_indices, linear_indices_sorted; Tensor infos_sorted; Tensor sorted_linear_indices_run, sorted_linear_indices_run_lengths, sorted_linear_indices_num_runs, sorted_linear_indices_cumulative_run_lengths; std::tie( linear_indices, linear_indices_sorted, infos_sorted, sorted_linear_indices_run, sorted_linear_indices_run_lengths, sorted_linear_indices_num_runs, sorted_linear_indices_cumulative_run_lengths) = transpose_embedding_input( hash_size_cumsum, total_hash_size_bits, indices, offsets, {{"true" if nobag else "false"}}); {% if not dense %} auto lxu_cache_locations_sorted = at::empty_like(lxu_cache_locations); if (lxu_cache_locations.size(0) > 0) { size_t temp_storage_bytes = 0; AT_CUDA_CHECK(radix_sort_pairs( nullptr, temp_storage_bytes, linear_indices.data_ptr<int64_t>(), linear_indices_sorted.data_ptr<int64_t>(), lxu_cache_locations.data_ptr<int32_t>(), lxu_cache_locations_sorted.data_ptr<int32_t>(), linear_indices.numel(), 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream(), false)); auto temp_storage = at::empty( {static_cast<int64_t>(temp_storage_bytes)}, indices.options().dtype(at::kByte)); AT_CUDA_CHECK(radix_sort_pairs( temp_storage.data_ptr(), temp_storage_bytes, linear_indices.data_ptr<int64_t>(), linear_indices_sorted.data_ptr<int64_t>(), lxu_cache_locations.data_ptr<int32_t>(), lxu_cache_locations_sorted.data_ptr<int32_t>(), linear_indices.numel(), 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream(), false)); } {% endif %} {% if not dense %} DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.type(), grad_output.type(), lxu_cache_weights.type(), {% else %} AT_DISPATCH_FLOATING_TYPES_AND_HALF( dev_weights.type(), {% endif %} "split_embedding_backward_{{ optimizer }}_exact_kernel", [&] { {% if weighted %} auto indice_weights_sorted = at::empty_like(indice_weights); { size_t temp_storage_bytes = 0; 
AT_CUDA_CHECK(radix_sort_pairs( nullptr, temp_storage_bytes, linear_indices.data_ptr<int64_t>(), linear_indices_sorted.data_ptr<int64_t>(), {% if not dense %} indice_weights.data_ptr<at::acc_type<cache_t, true>>(), indice_weights_sorted.data_ptr<at::acc_type<cache_t, true>>(), {% else %} indice_weights.data_ptr<at::acc_type<scalar_t, true>>(), indice_weights_sorted.data_ptr<at::acc_type<scalar_t, true>>(), {% endif %} linear_indices.numel(), 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream(), false)); auto temp_storage = at::empty( {static_cast<int64_t>(temp_storage_bytes)}, indices.options().dtype(at::kByte)); AT_CUDA_CHECK(radix_sort_pairs( temp_storage.data_ptr(), temp_storage_bytes, linear_indices.data_ptr<int64_t>(), linear_indices_sorted.data_ptr<int64_t>(), {% if not dense %} indice_weights.data_ptr<at::acc_type<cache_t, true>>(), indice_weights_sorted.data_ptr<at::acc_type<cache_t, true>>(), {% else %} indice_weights.data_ptr<at::acc_type<scalar_t, true>>(), indice_weights_sorted.data_ptr<at::acc_type<scalar_t, true>>(), {% endif %} linear_indices.numel(), 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream(), false)); } {% endif %} // early memory release linear_indices.reset(); linear_indices_sorted.reset(); auto grad_output_accessor = grad_output.packed_accessor32< {{ "at::acc_type<scalar_t, true>" if dense else "grad_t" }}, 2, at::RestrictPtrTraits>(); {% if not nobag %} Tensor grad_output_mean; if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) { grad_output_mean = at::empty_like(grad_output); grad_mean_kernel<{{ "at::acc_type<scalar_t, true>" if dense else "grad_t" }}> <<<div_round_up((B * T), kMaxThreads / kWarpSize), dim3(kWarpSize, kMaxThreads / kWarpSize), 0, at::cuda::getCurrentCUDAStream()>>>( grad_output_accessor, D_offsets .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), offsets .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), grad_output_mean.packed_accessor32< {{ "at::acc_type<scalar_t, true>" if 
dense else "grad_t" }}, 2, at::RestrictPtrTraits>()); C10_CUDA_KERNEL_LAUNCH_CHECK(); grad_output_accessor = grad_output_mean.packed_accessor32< {{ "at::acc_type<scalar_t, true>" if dense else "grad_t" }}, 2, at::RestrictPtrTraits>(); } {% endif %} {% if not dense %} at::PhiloxCudaState rng_engine_inputs; if (stochastic_rounding && !std::is_same<emb_t, float>::value) { auto gen = at::cuda::detail::getDefaultCUDAGenerator(); std::lock_guard<std::mutex> lock(gen.mutex()); rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen) ->philox_cuda_state(4); } {% endif %} {% for kMaxVecsPerThread in range(1, max_embedding_dim // 128 + 1) %} {% if not nobag %} if (max_D <= {{ 128 * kMaxVecsPerThread }}) { {% else %} if (D <= {{ 128 * kMaxVecsPerThread }}) { {% endif %} // Stay under used_shared_kb of shared memory (V100: 64 KB; A100: 96 KB), BT_block_size must be a power of two. while (BT_block_size * sizeof(at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>) * 4 * kWarpSize * {{ kMaxVecsPerThread }} >= used_shared_bytes) { BT_block_size /= 2; } TORCH_CHECK(BT_block_size >= 1); if (std::is_same<{{ "scalar_t" if dense else "emb_t" }}, double>::value) { // Otherwise we see CUDA kernel launch failures despite the above checks. 
BT_block_size = 1; } auto long_run_ids = at::empty_like(sorted_linear_indices_run_lengths); auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kLong)); split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_find_long_segments<<< div_round_up(sorted_linear_indices_run_lengths.numel(), kMaxThreads), kMaxThreads, 0, at::cuda::getCurrentCUDAStream() >>>( sorted_linear_indices_num_runs.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), sorted_linear_indices_run_lengths.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), long_run_ids.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), num_long_run_ids.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), max_segment_length_per_warp); C10_CUDA_KERNEL_LAUNCH_CHECK(); // Check https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x // "Compute capability 7.x devices allow a single thread block to // address the full capacity of shared memory: 96 KB on Volta, // 64 KB on Turing. Kernels relying on shared memory allocations // over 48 KB per block are architecture-specific, as such they // must use dynamic shared memory (rather than statically sized // arrays) and require an explicit opt-in using cudaFuncSetAttribute()". #ifndef __HIP_PLATFORM_HCC__ cudaFuncSetAttribute( split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1< {% if not dense %} emb_t, grad_t, cache_t, {% else %} scalar_t, at::acc_type<scalar_t, true>, scalar_t, {% endif %} {{ kMaxVecsPerThread }}>, cudaFuncAttributeMaxDynamicSharedMemorySize, used_shared_bytes); // V100: 64 KB; A100: 96 KB. 
#endif C10_CUDA_KERNEL_LAUNCH_CHECK(); // dividing by kMaxThreads is a heuristic to avoid num of blocks far exceeding num_long_run_ids[0] split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1< {% if not dense %} emb_t, grad_t, cache_t, {% else %} scalar_t, at::acc_type<scalar_t, true>, scalar_t, {% endif %} {{ kMaxVecsPerThread }}> <<<div_round_up(long_run_ids.numel(), kMaxThreads), dim3(kWarpSize, BT_block_size), BT_block_size * sizeof(at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>) * 4 * kWarpSize * {{ kMaxVecsPerThread }}, at::cuda::getCurrentCUDAStream()>>>( grad_output_accessor, {% if not dense %} dev_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(), uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(), lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(), weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), {% else %} dev_weights.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(), {% endif %} weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), {% if not nobag %} D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), {% else %} B, D, {% endif %} hash_size_cumsum.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), sorted_linear_indices_run .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), sorted_linear_indices_cumulative_run_lengths .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), long_run_ids.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), num_long_run_ids.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), {% if not nobag %} infos_sorted.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), {% else %} infos_sorted.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), {% endif %} {% if not dense %} lxu_cache_locations_sorted.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), {% endif %} {% if weighted %} 
indice_weights_sorted.packed_accessor32<at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>, 1, at::RestrictPtrTraits>(), {% endif %} {% if not dense %} stochastic_rounding, rng_engine_inputs, {% else %} grad_dev_weights.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(), {% endif %} {% if not nobag %} FixedDivisor(B), {% endif %} {{ args.split_kernel_arg_constructors | join(", ") }}); C10_CUDA_KERNEL_LAUNCH_CHECK(); #ifndef __HIP_PLATFORM_HCC__ cudaFuncSetAttribute( split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1< {% if not dense %} emb_t, grad_t, cache_t, {% else %} scalar_t, at::acc_type<scalar_t, true>, scalar_t, {% endif %} {{ kMaxVecsPerThread }}>, cudaFuncAttributeMaxDynamicSharedMemorySize, used_shared_bytes); // V100: 64 KB; A100: 96 KB. #endif C10_CUDA_KERNEL_LAUNCH_CHECK(); split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1< {% if not dense %} emb_t, grad_t, cache_t, {% else %} scalar_t, at::acc_type<scalar_t, true>, scalar_t, {% endif %} {{ kMaxVecsPerThread }}> <<<div_round_up(sorted_linear_indices_run.numel(), kBackwardMaxThreads / kWarpSize), dim3(kWarpSize, kBackwardMaxThreads / kWarpSize), BT_block_size * sizeof( at::acc_type< {% if not dense %} cache_t {% else %} scalar_t {% endif %}, true>) * 4 * kWarpSize * {{ kMaxVecsPerThread }}, at::cuda::getCurrentCUDAStream()>>>( grad_output_accessor, {% if not dense %} dev_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(), uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(), lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(), weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), {% else %} dev_weights.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(), {% endif %} weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), {% if not nobag %} D_offsets.packed_accessor32<int32_t, 1, 
at::RestrictPtrTraits>(), {% else %} B, D, {% endif %} hash_size_cumsum.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), sorted_linear_indices_run .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), sorted_linear_indices_cumulative_run_lengths .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), {% if not nobag %} infos_sorted.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), {% else %} infos_sorted.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), {% endif %} {% if not dense %} lxu_cache_locations_sorted.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), {% endif %} {% if weighted %} indice_weights_sorted.packed_accessor32<at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>, 1, at::RestrictPtrTraits>(), {% endif %} sorted_linear_indices_num_runs .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), max_segment_length_per_warp, {% if not dense %} stochastic_rounding, rng_engine_inputs, {% else %} grad_dev_weights.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(), {% endif %} {% if not nobag %} FixedDivisor(B), {% endif %} {{ args.split_kernel_arg_constructors | join(", ") }}); C10_CUDA_KERNEL_LAUNCH_CHECK(); return; } {% endfor %} }); return {{ "grad_dev_weights" if dense else "" }}; } {% endif %} {% endfor %} // clang-format on