fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp (180 lines of code) (raw):

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

#include <algorithm>

#include "c10/core/ScalarType.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;

// Forward declarations of the quantized (int n-bit) embedding forward
// implementations. They are only declared in this file; their definitions
// live in another translation unit. Signatures must stay in sync with those
// definitions.

// Pooled forward without per-sample weights.
Tensor int_nbit_split_embedding_codegen_forward_unweighted_cuda(
    Tensor dev_weights,
    Tensor uvm_weights,
    Tensor weights_placements,
    Tensor weights_offsets,
    Tensor weights_tys,
    Tensor D_offsets,
    int64_t total_D,
    int64_t max_int2_D,
    int64_t max_int4_D,
    int64_t max_int8_D,
    int64_t max_float16_D,
    int64_t max_float32_D,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    int64_t row_alignment,
    int64_t output_dtype,
    Tensor lxu_cache_weights,
    Tensor lxu_cache_locations,
    int64_t unused);

// Pooled forward with per-sample weights (`indice_weights`).
Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda(
    Tensor dev_weights,
    Tensor uvm_weights,
    Tensor weights_placements,
    Tensor weights_offsets,
    Tensor weights_tys,
    Tensor D_offsets,
    int64_t total_D,
    int64_t max_int2_D,
    int64_t max_int4_D,
    int64_t max_int8_D,
    int64_t max_float16_D,
    int64_t max_float32_D,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    int64_t row_alignment,
    Tensor indice_weights,
    int64_t output_dtype,
    Tensor lxu_cache_weights,
    Tensor lxu_cache_locations,
    int64_t unused);

// Non-pooled ("nobag") forward; takes a single embedding dimension `D`
// instead of `D_offsets` / `total_D`.
Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda(
    Tensor dev_weights,
    Tensor uvm_weights,
    Tensor weights_placements,
    Tensor weights_offsets,
    Tensor weights_tys,
    int64_t D,
    int64_t max_int2_D,
    int64_t max_int4_D,
    int64_t max_int8_D,
    int64_t max_float16_D,
    int64_t max_float32_D,
    Tensor indices,
    Tensor offsets,
    int64_t row_alignment,
    int64_t output_dtype,
    Tensor lxu_cache_weights,
    Tensor lxu_cache_locations,
    int64_t unused);

/// @brief Entry point registered as `fbgemm::int_nbit_split_embedding_codegen_lookup_function`.
///
/// Selects one of the three forward implementations declared above:
///  - `pooling_mode == PoolingMode::NONE` -> nobag unweighted forward, with
///    `D` set to the largest of the per-type `max_*_D` values;
///  - otherwise, no `indice_weights`  -> pooled unweighted forward;
///  - otherwise                       -> pooled weighted forward.
///
/// Unset `lxu_cache_weights` / `lxu_cache_locations` are replaced by empty
/// placeholder tensors (shape {0, 0} kByte and {0} kInt respectively); the
/// trailing `unused` argument is always passed as 0.
///
/// NOTE(review): in the NONE pooling branch `indice_weights`, `D_offsets`,
/// `total_D` and `pooling_mode` are not forwarded, so any supplied weights are
/// silently ignored - confirm this is intended.
///
/// @return whatever tensor the selected forward implementation returns.
Tensor int_nbit_split_embedding_codegen_lookup_function(
    Tensor dev_weights,
    Tensor uvm_weights,
    Tensor weights_placements,
    Tensor weights_offsets,
    Tensor weights_tys,
    Tensor D_offsets,
    int64_t total_D,
    int64_t max_int2_D,
    int64_t max_int4_D,
    int64_t max_int8_D,
    int64_t max_float16_D,
    int64_t max_float32_D,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    c10::optional<Tensor> indice_weights,
    int64_t output_dtype,
    c10::optional<Tensor> lxu_cache_weights,
    c10::optional<Tensor> lxu_cache_locations,
    int64_t row_alignment) {
  // Resolve the optional cache tensors once. Unlike `value_or`, which always
  // evaluates its argument, the empty placeholders are only allocated when
  // the caller did not supply a tensor.
  const Tensor cache_weights = lxu_cache_weights.has_value()
      ? *lxu_cache_weights
      : at::empty({0, 0}, at::kByte);
  const Tensor cache_locations = lxu_cache_locations.has_value()
      ? *lxu_cache_locations
      : at::empty({0}, at::kInt);

  if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
    // The nobag kernel takes one dimension; use the largest across all
    // weight types (no heap allocation needed to compute a max of five).
    const int64_t max_D = std::max(
        {max_int2_D, max_int4_D, max_int8_D, max_float16_D, max_float32_D});
    return int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda(
        dev_weights,
        uvm_weights,
        weights_placements,
        weights_offsets,
        weights_tys,
        max_D,
        max_int2_D,
        max_int4_D,
        max_int8_D,
        max_float16_D,
        max_float32_D,
        indices,
        offsets,
        row_alignment,
        output_dtype,
        cache_weights,
        cache_locations,
        0);
  }

  if (!indice_weights) {
    return int_nbit_split_embedding_codegen_forward_unweighted_cuda(
        dev_weights,
        uvm_weights,
        weights_placements,
        weights_offsets,
        weights_tys,
        D_offsets,
        total_D,
        max_int2_D,
        max_int4_D,
        max_int8_D,
        max_float16_D,
        max_float32_D,
        indices,
        offsets,
        pooling_mode,
        row_alignment,
        output_dtype,
        cache_weights,
        cache_locations,
        0);
  }

  return int_nbit_split_embedding_codegen_forward_weighted_cuda(
      dev_weights,
      uvm_weights,
      weights_placements,
      weights_offsets,
      weights_tys,
      D_offsets,
      total_D,
      max_int2_D,
      max_int4_D,
      max_int8_D,
      max_float16_D,
      max_float32_D,
      indices,
      offsets,
      pooling_mode,
      row_alignment,
      *indice_weights,
      output_dtype,
      cache_weights,
      cache_locations,
      0);
}

// Pruning lookups; declared here, defined in another translation unit.
Tensor pruned_hashmap_lookup_unweighted_cuda(
    Tensor indices,
    Tensor offsets,
    Tensor hash_table,
    Tensor hash_table_offsets);

Tensor pruned_array_lookup_cuda(
    Tensor indices,
    Tensor offsets,
    Tensor index_remappings,
    Tensor index_remappings_offsets);

// Bind the CUDA implementations of the `fbgemm` ops defined in this file via
// the project's DISPATCH_TO_CUDA macro.
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  DISPATCH_TO_CUDA(
      "int_nbit_split_embedding_codegen_lookup_function",
      int_nbit_split_embedding_codegen_lookup_function);
  DISPATCH_TO_CUDA(
      "pruned_hashmap_lookup", pruned_hashmap_lookup_unweighted_cuda);
  DISPATCH_TO_CUDA("pruned_array_lookup", pruned_array_lookup_cuda);
}