fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp (164 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>
#include "codegen/embedding_forward_split_cpu.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
using Tensor = at::Tensor;
// Dense CPU backward kernel. It is only declared here; no definition is
// visible in this file (presumably it lives in another, possibly generated,
// translation unit — confirm). SplitLookupFunction_Dense_Op::backward() uses
// its return value as the gradient for host_weights and always passes
// `unused` as 0.0.
Tensor split_embedding_backward_codegen_dense_cpu(
Tensor grad_output,
Tensor host_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor indice_weights,
double unused);
namespace {
/// Custom autograd function for the dense CPU embedding lookup.
///
/// forward() calls split_embedding_codegen_forward_cpu. backward() calls
/// split_embedding_backward_codegen_dense_cpu for the weight gradient. When
/// per-index weights were supplied, it also calls
/// split_embedding_codegen_grad_indice_weights_cpu for their gradient.
class SplitLookupFunction_Dense_Op
    : public torch::autograd::Function<SplitLookupFunction_Dense_Op> {
 public:
  /// Stashes everything backward() needs, then runs the CPU forward kernel.
  ///
  /// @param ctx autograd context that receives the saved tensors and the
  ///        integer arguments.
  /// @param indice_weights optional per-index weights; saved as an undefined
  ///        tensor when absent.
  /// @param feature_requires_grad optional tensor that is only forwarded to
  ///        the indice-weight gradient kernel in backward(); saved as an
  ///        undefined tensor when absent.
  /// @return a one-element list holding the forward kernel's output.
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      Tensor host_weights,
      Tensor weights_offsets,
      Tensor D_offsets,
      int64_t total_D,
      int64_t max_D,
      Tensor hash_size_cumsum,
      int64_t total_hash_size_bits,
      Tensor indices,
      Tensor offsets,
      int64_t pooling_mode,
      c10::optional<Tensor> indice_weights,
      c10::optional<Tensor> feature_requires_grad) {
    // Absent optionals are replaced by undefined tensors. The saved list then
    // always has eight entries in a fixed order, and backward() reads it by
    // position.
    Tensor saved_indice_weights = indice_weights.value_or(Tensor());
    Tensor saved_feature_requires_grad =
        feature_requires_grad.value_or(Tensor());

    ctx->save_for_backward({
        host_weights,
        weights_offsets,
        D_offsets,
        hash_size_cumsum,
        indices,
        offsets,
        saved_indice_weights,
        saved_feature_requires_grad,
    });

    // Integer arguments go into saved_data. Note that total_D is stored here
    // but never read back by backward().
    auto& scalars = ctx->saved_data;
    scalars["total_D"] = total_D;
    scalars["max_D"] = max_D;
    scalars["total_hash_size_bits"] = total_hash_size_bits;
    scalars["pooling_mode"] = pooling_mode;

    // Half and Byte weights request an FP32 output. Every other weight dtype
    // passes the sentinel -1, which presumably means "keep host_weights'
    // dtype" (e.g. double) — confirm against split_embedding_codegen_forward_cpu.
    const auto weights_type = host_weights.scalar_type();
    const bool wants_fp32_output =
        weights_type == at::kHalf || weights_type == at::ScalarType::Byte;
    const int64_t output_dtype =
        wants_fp32_output ? static_cast<int64_t>(SparseType::FP32) : -1;

    torch::autograd::variable_list outputs;
    outputs.push_back(split_embedding_codegen_forward_cpu(
        host_weights,
        weights_offsets,
        D_offsets,
        total_D,
        hash_size_cumsum,
        indices,
        offsets,
        pooling_mode,
        saved_indice_weights,
        output_dtype));
    return outputs;
  }

  /// Produces one gradient slot per forward() argument (ctx excluded).
  ///
  /// Only host_weights and indice_weights can receive a gradient; the
  /// indice_weights slot is filled only when indice_weights was supplied.
  /// Every other slot is an undefined tensor.
  ///
  /// @param grad_outputs must contain exactly one tensor: the gradient of
  ///        forward()'s single output.
  /// @return twelve entries, in forward()'s argument order.
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    using torch::autograd::Variable;

    // Positions mirror the save_for_backward() call in forward().
    const auto saved = ctx->get_saved_variables();
    auto host_weights = saved[0];
    auto weights_offsets = saved[1];
    auto D_offsets = saved[2];
    auto hash_size_cumsum = saved[3];
    auto indices = saved[4];
    auto offsets = saved[5];
    auto indice_weights = saved[6];
    auto feature_requires_grad = saved[7];

    auto& scalars = ctx->saved_data;
    const auto max_D = scalars["max_D"].toInt();
    const auto total_hash_size_bits = scalars["total_hash_size_bits"].toInt();
    const auto pooling_mode = scalars["pooling_mode"].toInt();

    TORCH_CHECK(grad_outputs.size() == 1);
    Tensor& grad_output = grad_outputs.front();

    auto grad_host_weights = split_embedding_backward_codegen_dense_cpu(
        grad_output,
        host_weights,
        weights_offsets,
        D_offsets,
        max_D,
        hash_size_cumsum,
        total_hash_size_bits,
        indices,
        offsets,
        pooling_mode,
        indice_weights,
        /*unused=*/0.0);

    // NOTE: MEAN pooling will not work with indice_weights! Nothing in this
    // class enforces that restriction.
    Variable grad_indice_weights;
    if (indice_weights.defined()) {
      grad_indice_weights = split_embedding_codegen_grad_indice_weights_cpu(
          grad_output,
          host_weights,
          weights_offsets,
          D_offsets,
          indices,
          offsets,
          feature_requires_grad);
    }

    const Variable undefined_grad;
    return {
        grad_host_weights,
        undefined_grad, // weights_offsets
        undefined_grad, // D_offsets
        undefined_grad, // total_D
        undefined_grad, // max_D
        undefined_grad, // hash_size_cumsum
        undefined_grad, // total_hash_size_bits
        undefined_grad, // indices
        undefined_grad, // offsets
        undefined_grad, // pooling_mode
        grad_indice_weights,
        undefined_grad, // feature_requires_grad
    };
  }
};
/// CPU implementation of the dense lookup op. It runs
/// SplitLookupFunction_Dense_Op through autograd, so backward() is recorded
/// when gradients are required, and returns that function's single output.
///
/// All arguments are passed through to SplitLookupFunction_Dense_Op::forward()
/// unchanged.
Tensor split_embedding_codegen_lookup_dense_function(
    Tensor host_weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t total_D,
    int64_t max_D,
    Tensor hash_size_cumsum,
    int64_t total_hash_size_bits,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    c10::optional<Tensor> indice_weights,
    c10::optional<Tensor> feature_requires_grad) {
  const auto outputs = SplitLookupFunction_Dense_Op::apply(
      host_weights,
      weights_offsets,
      D_offsets,
      total_D,
      max_D,
      hash_size_cumsum,
      total_hash_size_bits,
      indices,
      offsets,
      pooling_mode,
      indice_weights,
      feature_requires_grad);
  // forward() always returns exactly one tensor.
  return outputs.front();
}
// Adds the dense lookup op to the `fbgemm` operator library. FRAGMENT allows
// other translation units to contribute ops to the same library.
// The schema names the first argument `dev_weights`, while the C++
// implementation calls it host_weights. Callers that pass it by keyword must
// use the schema name, so the schema text must not be changed casually.
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad) -> Tensor");
// Presumably binds split_embedding_codegen_lookup_dense_function as the CPU
// kernel for the op defined above. DISPATCH_TO_CPU is defined outside this
// file — confirm.
DISPATCH_TO_CPU(
"dense_embedding_codegen_lookup_function",
split_embedding_codegen_lookup_dense_function);
}
} // namespace