fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh (139 lines of code) (raw):

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <ATen/ATen.h>
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"

// Warp bitonic K/V sorting code from @jhj

/// Default key ordering used by the warp bitonic network: plain `<` / `>`.
template <typename T>
struct Comparator {
  __device__ static inline bool lt(T a, T b) {
    return a < b;
  }
  __device__ static inline bool gt(T a, T b) {
    return a > b;
  }
};

/// Select-style conditional assignment: sets `x = y` when `cond` is true and
/// leaves `x` unchanged otherwise (written as a ternary rather than an `if`).
template <typename T>
inline __device__ void assign(bool cond, T& x, T y) {
  x = cond ? y : x;
}

/**
 * One bitonic merge over groups of 2 * L consecutive lanes. Each lane holds a
 * single key/value pair (k, v); partners are reached with
 * fbgemm_gpu::shfl_xor. The largest XOR mask used is 2 * L - 1, so data never
 * moves outside its group of 2 * L lanes.
 *
 * Template parameters:
 *   K, V      - key and value types (values travel with their keys).
 *   L         - half of the group size; must be <= kWarpSize / 2.
 *   Dir       - true: within each group, lower lanes end up with the lesser
 *               keys according to Comp (ascending); false: descending.
 *   Comp      - provides static lt(a, b) / gt(a, b), e.g. Comparator<K>.
 *   IsBitonic - true: the 2 * L keys of a group already form a bitonic
 *               sequence, so merging starts at stride L. false: the two
 *               halves of length L are each already ordered (as produced by
 *               the previous, smaller merge), and a first "mirror" stage
 *               (lane i <-> lane i ^ (2 * L - 1)) is run before the strides.
 *
 * NOTE(review): every lane of the warp presumably has to execute this call
 * together because it shuffles across lanes; confirm against the mask used
 * by fbgemm_gpu::shfl_xor.
 */
template <
    typename K,
    typename V,
    int32_t L,
    bool Dir,
    typename Comp,
    bool IsBitonic>
inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
  static_assert(
      L <= fbgemm_gpu::kWarpSize / 2,
      "merge list size must be <= kWarpSize / 2");

  // Only the bits L and `stride` (all below kWarpSize) of the lane id are
  // tested, so threadIdx.x is used in place of the lane id.
  // NOTE(review): this assumes warps are made of consecutive threadIdx.x
  // values, i.e. a 1-D block or blockDim.x a multiple of kWarpSize.
  int32_t laneId = threadIdx.x;

  if (!IsBitonic) {
    // Reverse the first comparison stage.
    // For example, merging two ordered halves of size 8 (L == 8) has the
    // exchanges: 0 <-> 15, 1 <-> 14, ...
    K otherK = fbgemm_gpu::shfl_xor(k, 2 * L - 1);
    V otherV = fbgemm_gpu::shfl_xor(v, 2 * L - 1);

    // Whether we are the lesser thread in the exchange
    bool small = !(laneId & L);

    if (Dir) {
      // The comparison is chosen with a select on `small` instead of
      // branching. The original authors reported that doing both comparisons
      // this way wins out over the alternatives in practice.
      bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
      assign(s, k, otherK);
      assign(s, v, otherV);
    } else {
      bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
      assign(s, k, otherK);
      assign(s, v, otherV);
    }
  }

  // Standard bitonic half-cleaner stages: compare with the lane `stride`
  // away, halving the stride each time until it reaches zero.
#pragma unroll
  for (int32_t stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
    K otherK = fbgemm_gpu::shfl_xor(k, stride);
    V otherV = fbgemm_gpu::shfl_xor(v, stride);

    // Whether we are the lesser thread in the exchange
    bool small = !(laneId & stride);

    if (Dir) {
      bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
      assign(s, k, otherK);
      assign(s, v, otherV);
    } else {
      bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
      assign(s, k, otherK);
      assign(s, v, otherV);
    }
  }
}

/**
 * Sorts one key/value pair per lane by running bitonic merges of increasing
 * size (L = 1, 2, 4, 8, 16), which orders every group of 32 consecutive lanes.
 * Dir == true gives ascending order by Comp with increasing lane id.
 *
 * NOTE(review): on HIP (kWarpSize == 64) no L == 32 merge is issued, so the
 * two 32-lane halves of a wavefront are sorted independently rather than as a
 * single 64-element list. Confirm that callers expect this.
 */
template <typename K, typename V, bool Dir, typename Comp>
struct BitonicSort {
  static inline __device__ void sort(K k[1], V v[1]) {
#ifdef __HIP_PLATFORM_HCC__
    static_assert(fbgemm_gpu::kWarpSize == 64, "unexpected warp size");
#else
    static_assert(fbgemm_gpu::kWarpSize == 32, "unexpected warp size");
#endif
    warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(k[0], v[0]);
    warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(k[0], v[0]);
    warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(k[0], v[0]);
    warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(k[0], v[0]);
    warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(k[0], v[0]);
  }
};

// Host-side entry points below are only declared here; their definitions
// live in other translation units.

std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
get_unique_indices_cuda(
    at::Tensor linear_indices,
    int64_t max_indices,
    bool compute_count);

std::pair<at::Tensor, at::Tensor> lru_cache_find_uncached_cuda(
    at::Tensor unique_indices,
    at::Tensor unique_indices_length,
    int64_t max_indices,
    at::Tensor lxu_cache_state,
    int64_t time_stamp,
    at::Tensor lru_state);

/**
 * "Transpose" embedding inputs by sorting indices by their values.
 * Logically this transposes the compressed sparse row (CSR) representation
 * stored in indices and offsets into a compressed sparse column (CSC) one.
 */
std::tuple<
    at::Tensor /*linear_indices*/,
    at::Tensor /*linear_indices_sorted*/,
    at::Tensor /*infos_sorted*/,
    at::Tensor /*sorted_linear_indices_run*/,
    at::Tensor /*sorted_linear_indices_run_lengths*/,
    at::Tensor /*sorted_linear_indices_num_runs*/,
    at::Tensor /*sorted_linear_indices_cumulative_run_lengths*/>
transpose_embedding_input(
    at::Tensor hash_size_cumsum,
    int64_t total_hash_size_bits,
    at::Tensor indices,
    at::Tensor offsets,
    bool nobag = false);

// Use these functions instead of directly calling cub functions
// to reduce code size and compilation time.
// Parameters mirror cub::DeviceRadixSort::SortPairs one-to-one.
// The macro below stamps out one radix_sort_pairs overload per supported
// (key, value) element-type pair; every overload uses int64_t keys.
#define DECL_RADIX_SORT_PAIRS_FN(key_type, value_type) \
  cudaError_t radix_sort_pairs(                         \
      void* d_temp_storage,                             \
      size_t& temp_storage_bytes,                       \
      const key_type* d_keys_in,                        \
      key_type* d_keys_out,                             \
      const value_type* d_values_in,                    \
      value_type* d_values_out,                         \
      int num_items,                                    \
      int begin_bit = 0,                                \
      int end_bit = sizeof(key_type) * 8,               \
      cudaStream_t stream = 0,                          \
      bool debug_synchronous = false)

DECL_RADIX_SORT_PAIRS_FN(int64_t, float);
DECL_RADIX_SORT_PAIRS_FN(int64_t, double);
DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t);

// The helper macro is local to this header.
#undef DECL_RADIX_SORT_PAIRS_FN