// fbgemm_gpu/codegen/embedding_forward_split_cpu.h
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

#include "fbgemm/Utils.h"

at::Tensor split_embedding_codegen_forward_cpu(
    at::Tensor weights,
    at::Tensor weights_offsets,
    at::Tensor D_offsets,
    int64_t total_D,
    at::Tensor hash_size_cumsum,
    at::Tensor indices,
    at::Tensor offsets,
    int64_t pooling_mode,
    at::Tensor indice_weights,
    int64_t output_dtype = 0 /* SparseType.FP32 */);
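
// Example: calling the forward op from C++. A minimal sketch, assuming a
// single table with 4 rows of dimension 2 and one bag of 3 lookups; every
// shape, value, and dtype below (int32 for D_offsets, int64 elsewhere) is an
// assumption for illustration, and the snippet belongs inside a function
// body, not at namespace scope. pooling_mode 0 is sum pooling; an undefined
// indice_weights tensor means unweighted pooling.
//
//   at::Tensor weights = at::randn({4 * 2});             // flattened [E x D]
//   at::Tensor weights_offsets = at::tensor({0L});       // table 0 starts at 0
//   at::Tensor D_offsets = at::tensor({0, 2}, at::kInt); // dim prefix sums
//   at::Tensor hash_size_cumsum = at::tensor({0L, 4L});
//   at::Tensor indices = at::tensor({1L, 3L, 0L});       // rows to gather
//   at::Tensor offsets = at::tensor({0L, 3L});           // one bag of 3
//   at::Tensor indice_weights;                           // undefined tensor
//   at::Tensor out = split_embedding_codegen_forward_cpu(
//       weights, weights_offsets, D_offsets, /*total_D=*/2,
//       hash_size_cumsum, indices, offsets, /*pooling_mode=*/0,
//       indice_weights);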

at::Tensor split_embedding_codegen_grad_indice_weights_cpu(
    at::Tensor grad_output,
    at::Tensor weights,
    at::Tensor weights_offsets,
    at::Tensor D_offsets,
    at::Tensor indices,
    at::Tensor offsets,
    at::Tensor feature_requires_grad);
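
// Sketch of the companion gradient op: given the upstream grad_output of the
// forward pass above (shape [B, total_D], here [1, 2]), it returns the
// gradient with respect to each entry of indice_weights. Passing an undefined
// feature_requires_grad tensor is assumed here to mean "compute gradients for
// every feature"; values are illustrative only.
//
//   at::Tensor grad_output = at::randn({1, 2});
//   at::Tensor feature_requires_grad; // undefined tensor
//   at::Tensor grad_indice_weights =
//       split_embedding_codegen_grad_indice_weights_cpu(
//           grad_output, weights, weights_offsets, D_offsets,
//           indices, offsets, feature_requires_grad);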

namespace internal {

// A batch of sparse matrices in hyper-compressed sparse column (CSC) format:
// each sparse matrix is hyper sparse, meaning there can be many columns
// without any non-zeros, so only columns that do have non-zeros are stored.
struct BatchedHyperCompressedSparseColumn {
  int num_tables; // # of matrices (or tables)
  // pointers to the beginning of each table in column_segment_ptr
  // (length T + 1, where T == num_tables)
  int* table_ptr = nullptr;
  // pointers to the beginning of each column segment in row_indices
  // (length table_ptr[T] + 1)
  // For a shared table, a column can have multiple segments, each for a
  // feature sharing the table. In this case, the segments will have the
  // same column_segment_indices but different column_segment_ids.
  int* column_segment_ptr = nullptr;
  int* column_segment_indices = nullptr; // length table_ptr[T]
  int* column_segment_ids = nullptr; // length table_ptr[T]
  int* row_indices = nullptr; // length column_segment_ptr[table_ptr[T]]
  float* weights = nullptr; // length column_segment_ptr[table_ptr[T]]

  ~BatchedHyperCompressedSparseColumn() {
    if (table_ptr) {
      fbgemm::fbgemmAlignedFree(table_ptr);
    }
    if (column_segment_ptr) {
      fbgemm::fbgemmAlignedFree(column_segment_ptr);
      fbgemm::fbgemmAlignedFree(column_segment_indices);
      fbgemm::fbgemmAlignedFree(column_segment_ids);
      fbgemm::fbgemmAlignedFree(row_indices);
    }
    if (weights) {
      fbgemm::fbgemmAlignedFree(weights);
    }
  }
};
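
// How the nested indexing above is meant to be traversed (an illustrative
// sketch; `csc` is assumed to be an instance already populated by
// batched_csr2csc below):
//
//   for (int t = 0; t < csc.num_tables; ++t) {
//     for (int c = csc.table_ptr[t]; c < csc.table_ptr[t + 1]; ++c) {
//       int col = csc.column_segment_indices[c]; // embedding row index
//       int feature = csc.column_segment_ids[c]; // feature sharing the table
//       for (int r = csc.column_segment_ptr[c];
//            r < csc.column_segment_ptr[c + 1];
//            ++r) {
//         int sample = csc.row_indices[r];              // position in batch
//         float w = csc.weights ? csc.weights[r] : 1.f; // per-entry weight
//       }
//     }
//   }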

template <typename scalar_t>
void batched_csr2csc(
    BatchedHyperCompressedSparseColumn& batched_csc,
    int B,
    const at::TensorAccessor<int64_t, 1>& batched_csr_offsets,
    const at::TensorAccessor<int64_t, 1>& batched_csr_indices,
    const at::TensorAccessor<scalar_t, 1>& batched_csr_weights,
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings);
} // namespace internal
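
// Illustrative call of the transposition helper, reusing the hypothetical
// single-table setup from the forward example above. It is assumed here that
// table_to_feature_offset maps each table to its first feature (one table
// covering feature 0 only) and that batched_csr_weights is only meaningful
// for weighted pooling, so a ones tensor is passed as a stand-in.
//
//   internal::BatchedHyperCompressedSparseColumn csc;
//   const int table_to_feature_offset[] = {0, 1};
//   at::Tensor per_sample_weights = at::ones({3});
//   internal::batched_csr2csc<float>(
//       csc,
//       /*B=*/1,
//       offsets.accessor<int64_t, 1>(),
//       indices.accessor<int64_t, 1>(),
//       per_sample_weights.accessor<float, 1>(),
//       /*pooling_mode=*/0,
//       table_to_feature_offset,
//       /*num_embeddings=*/4);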