// fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>

#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;

// Forward declarations of the CUDA entry points produced by the embedding
// codegen; their definitions live in the generated CUDA sources.

// Forward pass for the pooled (bagged) dense lookup without per-sample
// weights.
Tensor dense_embedding_codegen_forward_unweighted_cuda(
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
int64_t BT_block_size);

// Forward pass for the pooled dense lookup with per-sample weights
// (indice_weights).
Tensor dense_embedding_codegen_forward_weighted_cuda(
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor indice_weights,
int64_t BT_block_size);

// Gradient of the pooled output with respect to the per-sample weights, for
// features flagged in feature_requires_grad.
Tensor dense_embedding_codegen_grad_indice_weights_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor indices,
Tensor offsets,
Tensor feature_requires_grad);

// Exact backward pass for the unweighted pooled lookup; returns the gradient
// with respect to the flattened dense weights.
Tensor split_embedding_backward_codegen_dense_unweighted_exact_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
int64_t BT_block_size,
int64_t max_segment_length_per_warp,
double unused);

// Exact backward pass for the weighted pooled lookup.
Tensor split_embedding_backward_codegen_dense_weighted_exact_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor indice_weights,
int64_t BT_block_size,
int64_t max_segment_length_per_warp,
double unused);

// Autograd wrapper for the pooled (bagged) dense embedding lookup. forward()
// saves the tensors and scalars needed later and dispatches to the weighted
// or unweighted CUDA kernel; backward() returns the gradient for dev_weights
// (plus grad_indice_weights in the weighted case) and empty Variables for
// every non-differentiable input.
class SplitLookupFunction_Dense_Op
: public torch::autograd::Function<SplitLookupFunction_Dense_Op> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
c10::optional<Tensor> indice_weights,
c10::optional<Tensor> feature_requires_grad) {
// Stash the tensors and scalar metadata needed by the backward pass.
ctx->save_for_backward({
dev_weights,
weights_offsets,
D_offsets,
hash_size_cumsum,
indices,
offsets,
indice_weights.value_or(Tensor()),
feature_requires_grad.value_or(Tensor()),
});
ctx->saved_data["total_D"] = total_D;
ctx->saved_data["max_D"] = max_D;
ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
ctx->saved_data["pooling_mode"] = pooling_mode;
// ROCm wavefronts are 64 lanes wide vs. 32-lane CUDA warps, so use a larger
// block tile on HIP.
#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
#else
constexpr int32_t BT_block_size = 32;
#endif
// Dispatch to the unweighted or weighted forward kernel depending on whether
// per-sample weights were supplied.
if (!indice_weights.has_value()) {
return {dense_embedding_codegen_forward_unweighted_cuda(
dev_weights,
weights_offsets,
D_offsets,
total_D,
max_D,
indices,
offsets,
pooling_mode,
BT_block_size)};
} else {
return {dense_embedding_codegen_forward_weighted_cuda(
dev_weights,
weights_offsets,
D_offsets,
total_D,
max_D,
indices,
offsets,
pooling_mode,
indice_weights.value(),
BT_block_size)};
}
}

// Backward: restore the saved tensors/attributes and launch the exact dense
// backward kernels. The returned list has one entry per forward() input.
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
const auto saved = ctx->get_saved_variables();
auto savedItr = std::begin(saved);
auto dev_weights = *savedItr++;
auto weights_offsets = *savedItr++;
auto D_offsets = *savedItr++;
auto hash_size_cumsum = *savedItr++;
auto indices = *savedItr++;
auto offsets = *savedItr++;
auto indice_weights = *savedItr++;
auto feature_requires_grad = *savedItr++;
auto total_D = ctx->saved_data["total_D"].toInt();
auto max_D = ctx->saved_data["max_D"].toInt();
auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
TORCH_CHECK(grad_outputs.size() == 1);
// ROCm wavefronts are 64 lanes wide vs. 32-lane CUDA warps, so scale the
// block tile and per-warp segment length accordingly on HIP.
#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
#endif
using torch::autograd::Variable;
auto grad_output = grad_outputs[0];
// The backward kernels expect a 16-byte-aligned grad_output whose innermost
// dimension is contiguous; otherwise fall back to a contiguous copy.
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
grad_output = grad_output.contiguous();
}
if (!indice_weights.defined()) {
auto grad_dev_weights =
split_embedding_backward_codegen_dense_unweighted_exact_cuda(
grad_output,
dev_weights,
weights_offsets,
D_offsets,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
pooling_mode,
BT_block_size,
max_segment_length_per_warp,
/* unused=*/0.0);
// One (possibly empty) gradient per forward() input, in declaration order.
return {
grad_dev_weights,
Variable(), // weights_offsets
Variable(), // D_offsets
Variable(), // total_D
Variable(), // max_D
Variable(), // hash_size_cumsum
Variable(), // total_hash_size_bits
Variable(), // indices
Variable(), // offsets
Variable(), // pooling_mode
Variable(), // indice_weights
Variable(), // feature_requires_grad
};
} else {
auto grad_indice_weights =
dense_embedding_codegen_grad_indice_weights_cuda(
grad_output,
dev_weights,
weights_offsets,
D_offsets,
max_D,
indices,
offsets,
feature_requires_grad);
auto grad_dev_weights =
split_embedding_backward_codegen_dense_weighted_exact_cuda(
grad_output,
dev_weights,
weights_offsets,
D_offsets,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
pooling_mode,
indice_weights,
BT_block_size,
max_segment_length_per_warp,
/* unused=*/0.0);
return {
grad_dev_weights,
Variable(), // weights_offsets
Variable(), // D_offsets
Variable(), // total_D
Variable(), // max_D
Variable(), // hash_size_cumsum
Variable(), // total_hash_size_bits
Variable(), // indices
Variable(), // offsets
Variable(), // pooling_mode
grad_indice_weights,
Variable(), // feature_requires_grad
};
}
}
};
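
// Illustrative gradient-flow sketch (an assumption-laden example, not code
// used by this file): only dev_weights, and indice_weights in the weighted
// case, receive gradients; every other input gets an empty Variable, matching
// the return lists in backward() above.
//
//   dev_weights.requires_grad_(true);
//   auto outputs = SplitLookupFunction_Dense_Op::apply(
//       dev_weights, weights_offsets, D_offsets, total_D, max_D,
//       hash_size_cumsum, total_hash_size_bits, indices, offsets,
//       pooling_mode, indice_weights, feature_requires_grad);
//   outputs[0].sum().backward();
//   // dev_weights.grad() now holds the dense gradient with the same
//   // flattened shape as dev_weights.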

/******** nobag ops ********/

// Forward pass for the nobag (sequence) lookup: every index produces one
// output row of dimension D, with no pooling.
Tensor dense_embedding_nobag_codegen_forward_unweighted_cuda(
Tensor dev_weights,
Tensor weights_offsets,
int64_t D,
Tensor indices,
Tensor offsets,
int64_t unused);

// Exact backward pass for the nobag lookup.
Tensor split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor weights_offsets,
int64_t D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t BT_block_size,
int64_t max_segment_length_per_warp,
double unused);

// Autograd wrapper for the nobag (PoolingMode::NONE) dense lookup. Mirrors
// SplitLookupFunction_Dense_Op, but with a single embedding dimension D and
// no per-sample weights.
class SplitNoBagLookupFunction_Dense_Op
: public torch::autograd::Function<SplitNoBagLookupFunction_Dense_Op> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
Tensor dev_weights,
Tensor weights_offsets,
int64_t D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets) {
ctx->save_for_backward({
dev_weights,
weights_offsets,
hash_size_cumsum,
indices,
offsets,
});
ctx->saved_data["D"] = D;
ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
return {dense_embedding_nobag_codegen_forward_unweighted_cuda(
dev_weights, weights_offsets, D, indices, offsets, /*unused=*/0)};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
const auto saved = ctx->get_saved_variables();
auto savedItr = std::begin(saved);
auto dev_weights = *savedItr++;
auto weights_offsets = *savedItr++;
auto hash_size_cumsum = *savedItr++;
auto indices = *savedItr++;
auto offsets = *savedItr++;
auto D = ctx->saved_data["D"].toInt();
auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
TORCH_CHECK(grad_outputs.size() == 1);
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
using torch::autograd::Variable;
auto grad_output = grad_outputs[0];
// Same alignment/contiguity requirement as in the pooled backward above.
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
grad_output = grad_output.contiguous();
}
auto grad_dev_weights =
split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda(
grad_output,
dev_weights,
weights_offsets,
D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
BT_block_size,
max_segment_length_per_warp,
/* unused=*/0.0);
return {
grad_dev_weights, // grad_dev_weights
Variable(), // weights_offsets
Variable(), // D
Variable(), // hash_size_cumsum
Variable(), // total_hash_size_bits
Variable(), // indices
Variable(), // offsets
};
}
};
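
// Illustrative sketch of the nobag path (shapes stated as assumptions
// following the usual FBGEMM conventions): with PoolingMode::NONE each
// looked-up index contributes one output row, so the result is expected to
// have shape [indices.numel(), D] rather than a pooled [B, total_D] tensor.
//
//   auto seq_out = SplitNoBagLookupFunction_Dense_Op::apply(
//       dev_weights, weights_offsets, /*D=*/max_D, hash_size_cumsum,
//       total_hash_size_bits, indices, offsets)[0];
//   // seq_out: one embedding row per entry in `indices`.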

// Host-side lookup entry point: dispatches to the nobag (sequence) op when
// pooling is disabled and to the pooled op otherwise.
Tensor split_embedding_codegen_lookup_dense_function(
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
c10::optional<Tensor> indice_weights,
c10::optional<Tensor> feature_requires_grad) {
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
return SplitNoBagLookupFunction_Dense_Op::apply(
dev_weights,
weights_offsets,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets)[0];
} else {
return SplitLookupFunction_Dense_Op::apply(
dev_weights,
weights_offsets,
D_offsets,
total_D,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
pooling_mode,
indice_weights,
feature_requires_grad)[0];
}
}
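
// Illustrative call of the entry point above (a sketch; the tensor layouts
// follow the usual FBGEMM table-batched-embedding conventions and are stated
// here as assumptions): weights_offsets holds one int64 start offset per
// table into the flattened dev_weights buffer, D_offsets is a prefix sum of
// per-table embedding dimensions, and offsets is a CSR-style tensor of
// length T * B + 1 delimiting the indices of each (table, sample) bag.
//
//   auto pooled = split_embedding_codegen_lookup_dense_function(
//       dev_weights, weights_offsets, D_offsets, total_D, max_D,
//       hash_size_cumsum, total_hash_size_bits, indices, offsets,
//       /*pooling_mode=*/static_cast<int64_t>(PoolingMode::SUM),
//       /*indice_weights=*/c10::nullopt,
//       /*feature_requires_grad=*/c10::nullopt);
//   // pooled: [B, total_D] when pooling each bag.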

// Register the CUDA implementation of the dense lookup op with the PyTorch
// dispatcher under the fbgemm namespace.
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
DISPATCH_TO_CUDA(
"dense_embedding_codegen_lookup_function",
split_embedding_codegen_lookup_dense_function);
}
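
// Once this fragment is linked in, the implementation is reachable through
// the PyTorch dispatcher (the op schema itself is assumed to be def()'d
// elsewhere in fbgemm_gpu). Illustrative access paths:
//
//   // From C++:
//   auto op = c10::Dispatcher::singleton().findSchemaOrThrow(
//       "fbgemm::dense_embedding_codegen_lookup_function", "");
//
//   From Python: torch.ops.fbgemm.dense_embedding_codegen_lookup_function(...)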