include/fbgemm/FbgemmEmbedding.h (234 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include <cstdint> #include <functional> #include "fbgemm/FbgemmBuild.h" namespace fbgemm { template < typename InType, typename IndexType, typename OffsetType = std::int32_t, typename OutType = float> class EmbeddingSpMDMKernelSignature { public: /** * Behavior is as the follow pseudocode * (when use_offsets == true, lengths[i] == offsets[i + 1] - offsets[i]) * (when is_weight_positional == true, use weights[j - offsets[i]] instead of * weights[j]) * * for i in range(output_size): * out[i * block_size : (i + 1) * block_size] = 0 * for j in range(offsets[i], offsets[i + 1]): * for k in range(block_size): * out[i * block_size + k] += input[indices[j] * block_size + k] * * weights ? weights[j] : 1; * if normalize_weights and lengths[i] > 0: * out[i * block_size : (i + 1) * block_size] /= lengths[i] * * @param data_size the number of rows in embedding table */ using Type = std::function<bool( std::int64_t output_size, std::int64_t index_size, std::int64_t data_size, const InType* input, const IndexType* indices, const OffsetType* offsets_or_lengths, const float* weights, // optional, can be null for non-weighted sum OutType* out)>; }; /** * @tparam InType can be float, float16, or uint8_t * @tparam IndexType can be int32_t or int64_t * @tparam IndexType can be int32_t or int64_t * * @param use_offsets If true, the generated code assumes we will pass offsets * instead of lengths that confirms PyTorch EmbeddingBag * interface. In this case, the length of offsets array * should be output_size + 1 and offsets[output_size] should * be index_size. * If false, the generate code assumes we will pass lengths * that confirms Caffe2 SparseLengthsSum interface. */ template < typename InType, typename IndexType, typename OffsetType = std::int32_t, typename OutType = float> FBGEMM_API typename EmbeddingSpMDMKernelSignature< InType, IndexType, OffsetType, OutType>::Type GenerateEmbeddingSpMDM( const std::int64_t block_size, bool has_weight, bool normalize_by_lengths, int prefetch = 16, bool is_weight_positional = false, bool use_offsets = true); /** * @param output_stride If -1, output_stride is same as block_size * @param input_stride If -1, input_stride is same as block_size * @param scale_bias_last if false, scale and bias appear at the beginning * of each row and are in fp16 for table batched embedding (TBE) * in FBGEMM_GPU. If false, it can also take -1 indices (output from * pruned embedding id mapping) */ template < typename InType, typename IndexType, typename OffsetType = std::int32_t, typename OutType = float, bool THREAD_LOCAL = false> FBGEMM_API typename EmbeddingSpMDMKernelSignature< InType, IndexType, OffsetType, OutType>::Type GenerateEmbeddingSpMDMWithStrides( const std::int64_t block_size, bool has_weight, bool normalize_by_lengths, int prefetch = 16, bool is_weight_positional = false, bool use_offsets = true, std::int64_t output_stride = -1, std::int64_t input_stride = -1, bool scale_bias_last = true); /** * @tparam IndexType can be int32_t or int64_t * @tparam OffsetType can be int32_t or int64_t * @param bit_rate can be 2 or 4 */ template < typename IndexType, typename OffsetType = std::int32_t, typename OutType = float> FBGEMM_API typename EmbeddingSpMDMKernelSignature< std::uint8_t, IndexType, OffsetType, OutType>::Type GenerateEmbeddingSpMDMNBit( int bit_rate, const std::int64_t block_size, bool has_weight, bool normalize_by_lengths, int prefetch = 16, bool is_weight_positional = false, bool use_offsets = true); /** * @param output_stride If -1, output_stride is same as block_size * @param input_stride in Bytes. If -1, input_stride is same as * block_size / num_elem_per_byte + 2 * sizeof(float16) * @param scale_bias_last if false, scale and bias appear at the beginning * of each row and are in fp16 for table batched embedding (TBE) * in FBGEMM_GPU. If false, it can also take -1 indices (output from * pruned embedding id mapping) */ template < typename IndexType, typename OffsetType = std::int32_t, typename OutType = float, bool THREAD_LOCAL = false> FBGEMM_API typename EmbeddingSpMDMKernelSignature< std::uint8_t, IndexType, OffsetType, OutType>::Type GenerateEmbeddingSpMDMNBitWithStrides( int bit_rate, const std::int64_t block_size, bool has_weight, bool normalize_by_lengths, int prefetch = 16, bool is_weight_positional = false, bool use_offsets = true, std::int64_t output_stride = -1, std::int64_t input_stride = -1, bool scale_bias_last = true); template < typename InType, typename IndexType, typename OffsetType = std::int32_t> class EmbeddingSpMDMRowWiseSparseKernelSignature { public: using Type = std::function<bool( std::int64_t output_size, std::int64_t index_size, std::int64_t uncompressed_data_size, // TODO: add compressed_data_size and check array bound const InType* input, const IndexType* indices, const OffsetType* offsets_or_lengths, const float* weights, // optional, can be null for non-weighted sum float* out, const std::int32_t* compressed_indices_table)>; }; /** * @tparam InType can be float, float16, or uint8_t * @tparam IndexType can be int32_t or int64_t * @tparam OffsetType can be int32_t or int64_t */ template < typename InType, typename IndexType, typename OffsetType = std::int32_t> FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< InType, IndexType, OffsetType>::Type GenerateEmbeddingSpMDMRowWiseSparse( const std::int64_t block_size, bool has_weight, bool normalize_by_lengths, int prefetch = 16, bool is_weight_positional = false, bool use_offsets = true); /** * @tparam IndexType can be int32_t or int64_t * @tparam OffsetType can be int32_t or int64_t * @param bit_rate can be 2 or 4 */ template <typename IndexType, typename OffsetType = std::int32_t> FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< std::uint8_t, IndexType, OffsetType>::Type GenerateEmbeddingSpMDMNBitRowWiseSparse( int bit_rate, const std::int64_t block_size, bool has_weight, bool normalize_by_lengths, int prefetch = 16, bool is_weight_positional = false, bool use_offsets = true); /** * @return The number of rows processed. If smaller than num_rows, an error * must have happened at the last row processed. */ template <typename IndexType> class SparseAdaGradSignature { public: using Type = std::function<int( int num_rows, // number of rows reading std::uint64_t param_size, // total number of parameters float* w, // input/output parameters const float* g, // input gradients float* h, // input/output momentums const IndexType* indices, // indices of each row float epsilon, float lr, float weight_decay, const double* counter, // used for weight_decay adjusted for frequency // nullptr when frequency adjustment is not used. // ignored when the kernel is generated with // use_weight_decay = false. std::int64_t counter_halflife)>; // frequency adjust happens only after }; template <typename IndexType> FBGEMM_API typename SparseAdaGradSignature<IndexType>::Type GenerateSparseAdaGrad( int block_size, // number of parameters per row bool rowwise = false, int prefetch = 16, bool use_weight_decay = false); // RowWiseSparseAdaGrad fused with SLS gradient // Weights can be either float or float16 template < typename IndexType, typename OffsetType = std::int32_t, typename DataType = float> class RowWiseSparseAdaGradFusedSignature { public: using Type = std::function<bool( std::int64_t output_size, std::int64_t index_size, std::int64_t data_size, // number of rows in w DataType* w, // input/output parameters const float* g, // input gradients float* h, // input/output momentums const IndexType* indices, // indices of each row const OffsetType* offsets_or_lengths, float epsilon, float lr)>; }; /** * @param grad_stride If -1, grad_stride is same as block size */ template < typename IndexType, typename OffsetType = std::int32_t, typename DataType = float> FBGEMM_API typename RowWiseSparseAdaGradFusedSignature< IndexType, OffsetType, DataType>::Type GenerateRowWiseSparseAdaGradFused( int block_size, // number of parameters per row int prefetch = 16, bool use_offsets = true, bool use_stochastic_rounding = true, int grad_stride = -1); namespace internal { // Specialization for block size 1 internally called by GenerateEmbeddingSpMDM template <typename InType, typename IndexType, typename OffsetType> FBGEMM_API bool EmbeddingSpMDMBlockSize1_( const std::int64_t output_size, const std::int64_t index_size, const std::int64_t data_size, // the number of rows in input const InType* input, const IndexType* indices, const OffsetType* offsets_or_lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, float* out, bool is_weight_positional = false, bool use_offsets = true); template <typename IndexType, bool HAS_WEIGHTS> void compressed_indices_remap_avx512( std::int32_t offsets_numel, const IndexType* indices, const int32_t* compressed_indices_mapping, const IndexType* offsets, const float* weights, // optional, can be null, IndexType* out_indices, IndexType* out_offsets, float* out_weights); } // namespace internal template <typename IndexType> FBGEMM_API void compressed_indices_remap( std::int32_t offsets_numel, const IndexType* indices, const int32_t* compressed_indices_mapping, const IndexType* offsets, const float* weights, // optional, can be null, IndexType* out_indices, IndexType* out_offsets, float* out_weights); } // namespace fbgemm