/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// clang-format off
#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
using Tensor = at::Tensor;
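// NOTE: this file is a Jinja2 template; {{ optimizer }} and the args.* expressions
// are expanded by the codegen step, once per optimizer.
//
// Forward declarations of the generated CUDA forward/backward kernels that the
// autograd wrappers below dispatch to.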
Tensor split_embedding_codegen_forward_unweighted_cuda(
Tensor dev_weights,
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor lxu_cache_locations,
int64_t output_dtype,
int64_t BT_block_size);
Tensor split_embedding_codegen_forward_weighted_cuda(
Tensor dev_weights,
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor indice_weights,
Tensor lxu_cache_locations,
int64_t output_dtype,
int64_t BT_block_size);
Tensor split_embedding_codegen_grad_indice_weights_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor indices,
Tensor offsets,
Tensor lxu_cache_locations,
Tensor feature_requires_grad);
void split_embedding_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor lxu_cache_locations,
int64_t BT_block_size,
int64_t max_segment_length_per_warp,
bool stochastic_rounding,
{{ args.split_function_args | join(", ") }});
void split_embedding_backward_codegen_{{ optimizer }}_weighted_exact_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor indice_weights,
Tensor lxu_cache_locations,
int64_t BT_block_size,
int64_t max_segment_length_per_warp,
bool stochastic_rounding,
{{ args.split_function_args | join(", ") }});
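// "No bag" (sequence embedding) variants: a single embedding dimension D,
// no pooling and no per-sample weights.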
Tensor split_embedding_nobag_codegen_forward_unweighted_cuda(
Tensor dev_weights,
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
Tensor weights_offsets,
int64_t D,
Tensor indices,
Tensor offsets,
Tensor lxu_cache_locations,
int64_t unused);
void split_embedding_nobag_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
Tensor weights_offsets,
int64_t D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
Tensor lxu_cache_locations,
int64_t BT_block_size,
int64_t max_segment_length_per_warp,
bool stochastic_rounding,
{{ args.split_function_args | join(", ") }});
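// Generate two autograd ops from the same template: SplitNoBagLookupFunction_{{ optimizer }}_Op
// for pooling_mode == NONE and SplitLookupFunction_{{ optimizer }}_Op for the pooled case.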
{% for nobag in [True, False] %}
class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op :
public torch::autograd::Function<Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
Tensor placeholder_autograd_tensor,
{% if not nobag %}
int64_t output_dtype,
{% endif %}
Tensor dev_weights,
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
Tensor weights_offsets,
{% if not nobag %}
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
{% else %}
int64_t D,
{% endif %}
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
{% if not nobag %}
int64_t pooling_mode,
c10::optional<Tensor> indice_weights,
c10::optional<Tensor> feature_requires_grad,
{% endif %}
Tensor lxu_cache_locations,
bool gradient_clipping,
double max_gradient,
bool stochastic_rounding,
{{ args.split_function_args | join(", ") }}) {
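// Stash every tensor and scalar that backward() will need on the autograd context.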
ctx->save_for_backward({
dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets, {% if not nobag %} D_offsets, {% endif %} hash_size_cumsum,
indices, offsets, {% if not nobag %} indice_weights.value_or(Tensor()), feature_requires_grad.value_or(Tensor()), {% endif %} lxu_cache_locations, {{ args.split_saved_tensors | join(", ") }} });
{% if not nobag %}
ctx->saved_data["max_D"] = max_D;
ctx->saved_data["pooling_mode"] = pooling_mode;
{% else %}
ctx->saved_data["D"] = D;
{% endif %}
ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
ctx->saved_data["gradient_clipping"] = gradient_clipping;
ctx->saved_data["max_gradient"] = max_gradient;
ctx->saved_data["stochastic_rounding"] = stochastic_rounding;
{% for (var, _) in args.saved_data %}
ctx->saved_data["{{ var }}"] = {{ var }};
{% endfor %}
{% if not nobag %}
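// Larger B*T block size on ROCm, presumably to match its 64-lane wavefronts,
// vs. 32-lane warps on CUDA.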
#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
#else
constexpr int32_t BT_block_size = 32;
#endif
if (!indice_weights) {
return {split_embedding_codegen_forward_unweighted_cuda(
dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets,
D_offsets, total_D, max_D, indices, offsets, pooling_mode, lxu_cache_locations, output_dtype, BT_block_size)};
} else {
return {split_embedding_codegen_forward_weighted_cuda(
dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets,
D_offsets, total_D, max_D, indices, offsets, pooling_mode, *indice_weights, lxu_cache_locations, output_dtype, BT_block_size)};
}
{% else %}
return {split_embedding_nobag_codegen_forward_unweighted_cuda(
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D,
indices,
offsets,
lxu_cache_locations,
0)};
{% endif %}
}
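// backward() restores the saved state, runs the generated exact backward kernel
// (which fuses the {{ optimizer }} update into the backward pass), and returns an
// undefined Variable for every non-differentiable input; only indice_weights
// receives a real gradient, in the weighted case.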
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
const auto saved = ctx->get_saved_variables();
auto savedItr = std::begin(saved);
auto dev_weights = *savedItr++;
auto uvm_weights = *savedItr++;
auto lxu_cache_weights = *savedItr++;
auto weights_placements = *savedItr++;
auto weights_offsets = *savedItr++;
{% if not nobag %}
auto D_offsets = *savedItr++;
{% endif %}
auto hash_size_cumsum = *savedItr++;
auto indices = *savedItr++;
auto offsets = *savedItr++;
{% if not nobag %}
auto indice_weights = *savedItr++;
auto feature_requires_grad = *savedItr++;
{% endif %}
auto lxu_cache_locations = *savedItr++;
{% for tensor in args.split_saved_tensors %}
auto {{ tensor }} = *savedItr++;
{% endfor %}
{% if not nobag %}
auto max_D = ctx->saved_data["max_D"].toInt();
auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
{% else %}
auto D = ctx->saved_data["D"].toInt();
{% endif %}
auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
auto gradient_clipping = ctx->saved_data["gradient_clipping"].toBool();
auto max_gradient = ctx->saved_data["max_gradient"].toDouble();
auto stochastic_rounding = ctx->saved_data["stochastic_rounding"].toBool();
{% for (var, ivalue_cast) in args.saved_data %}
auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
{% endfor %}
TORCH_CHECK(grad_outputs.size() == 1);
#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
#endif
using torch::autograd::Variable;
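// Optionally clamp the incoming gradient, then make it contiguous unless it is
// already 16-byte aligned with unit stride along dim 1, which is the layout the
// backward kernels assume.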
auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
grad_output.stride(1) != 1 ||
grad_output.stride(0) % 4 != 0) {
grad_output = grad_output.contiguous();
}
{% if not nobag %}
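// Unweighted path: there are no per-sample weights, so no gradient is produced for them.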
if (!indice_weights.defined()) {
split_embedding_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
grad_output,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D_offsets,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
pooling_mode,
lxu_cache_locations,
BT_block_size,
max_segment_length_per_warp,
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }});
return {
Tensor(), // placeholder autograd tensor
Variable(), // output_dtype
Tensor(), // dev_weights
Variable(), // uvm_weights
Variable(), // lxu_cache_weights
Variable(), // weights_placements
Variable(), // weights_offsets
Variable(), // D_offsets
Variable(), // total_D
Variable(), // max_D
Variable(), // hash_size_cumsum
Variable(), // total_hash_size_bits
Variable(), // indices
Variable(), // offsets
Variable(), // pooling_mode
Variable(), // indice_weights
Variable(), // feature_requires_grad
Variable(), // lxu_cache_locations
Variable(), // gradient_clipping
Variable(), // max_gradient
Variable(), // stochastic_rounding
{{ args.split_variables | join(", ") }}
};
} else {
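// Weighted path: first compute the gradient w.r.t. the per-sample weights, then
// run the weighted backward/optimizer kernel.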
auto grad_indice_weights = split_embedding_codegen_grad_indice_weights_cuda(
grad_output,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D_offsets,
max_D,
indices,
offsets,
lxu_cache_locations,
feature_requires_grad);
split_embedding_backward_codegen_{{ optimizer }}_weighted_exact_cuda(
grad_output,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D_offsets,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
pooling_mode,
indice_weights,
lxu_cache_locations,
BT_block_size,
max_segment_length_per_warp,
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }});
return {
Tensor(), // placeholder autograd tensor
Variable(), // output_dtype
Tensor(), // dev_weights
Variable(), // uvm_weights
Variable(), // lxu_cache_weights
Variable(), // weights_placements
Variable(), // weights_offsets
Variable(), // D_offsets
Variable(), // total_D
Variable(), // max_D
Variable(), // hash_size_cumsum
Variable(), // total_hash_size_bits
Variable(), // indices
Variable(), // offsets
Variable(), // pooling_mode
grad_indice_weights, // indice_weights
Variable(), // feature_requires_grad
Variable(), // lxu_cache_locations
Variable(), // gradient_clipping
Variable(), // max_gradient
Variable(), // stochastic_rounding
{{ args.split_variables | join(", ") }}
};
}
{% else %}
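// No-bag path: sequence embeddings with a single dimension D and no per-sample weights.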
split_embedding_nobag_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
grad_output,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
lxu_cache_locations,
BT_block_size,
max_segment_length_per_warp,
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }});
return {
Tensor(), // placeholder autograd tensor
Tensor(), // dev_weights
Variable(), // uvm_weights
Variable(), // lxu_cache_weights
Variable(), // weights_placements
Variable(), // weights_offsets
Variable(), // D
Variable(), // hash_size_cumsum
Variable(), // total_hash_size_bits
Variable(), // indices
Variable(), // offsets
Variable(), // lxu_cache_locations
Variable(), // gradient_clipping
Variable(), // max_gradient
Variable(), // stochastic_rounding
{{ args.split_variables | join(", ") }}
};
{% endif %}
}
};
{% endfor %}
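// Dispatch entry point registered below: picks the no-bag op when
// pooling_mode == NONE, otherwise the pooled op, and forwards all arguments.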
Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
Tensor placeholder_autograd_tensor,
Tensor dev_weights,
Tensor uvm_weights,
Tensor lxu_cache_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
c10::optional<Tensor> indice_weights,
c10::optional<Tensor> feature_requires_grad,
Tensor lxu_cache_locations,
bool gradient_clipping,
double max_gradient,
bool stochastic_rounding,
{{ args.split_function_args | join(", ") }},
int64_t output_dtype) {
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
return SplitNoBagLookupFunction_{{ optimizer }}_Op::apply(
placeholder_autograd_tensor,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
lxu_cache_locations,
gradient_clipping,
max_gradient,
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }})[0];
} else {
return SplitLookupFunction_{{ optimizer }}_Op::apply(
placeholder_autograd_tensor,
output_dtype,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D_offsets,
total_D,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
pooling_mode,
indice_weights,
feature_requires_grad,
lxu_cache_locations,
gradient_clipping,
max_gradient,
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }})[0];
}
}
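// Register the operator schema in the fbgemm namespace and route it to the CUDA
// implementation above.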
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor");
DISPATCH_TO_CUDA("split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function);
}
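// Illustrative (hand-written, not generated) Python-side call of the registered op,
// assuming the adagrad instantiation and tensors prepared by the host-side Python
// wrapper; optimizer_args stands in for the expanded {{ args.split_function_schemas }}:
//
//   out = torch.ops.fbgemm.split_embedding_codegen_lookup_adagrad_function(
//       placeholder_autograd_tensor, dev_weights, uvm_weights, lxu_cache_weights,
//       weights_placements, weights_offsets, D_offsets, total_D, max_D,
//       hash_size_cumsum, total_hash_size_bits, indices, offsets, pooling_mode,
//       indice_weights, feature_requires_grad, lxu_cache_locations,
//       gradient_clipping, max_gradient, stochastic_rounding,
//       *optimizer_args, output_dtype)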
// clang-format on