in fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp [221:422]
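// backward() for the split-embedding-lookup autograd Function. It restores the
// state captured by forward(), optionally clips and compacts the incoming
// gradient, and dispatches to the code-generated CUDA backward kernels.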
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
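// Tensors come back from get_saved_variables() in exactly the order forward()
// pushed them with save_for_backward(), so the iterator walk below must match.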
const auto saved = ctx->get_saved_variables();
auto savedItr = std::begin(saved);
auto dev_weights = *savedItr++;
auto uvm_weights = *savedItr++;
auto lxu_cache_weights = *savedItr++;
auto weights_placements = *savedItr++;
auto weights_offsets = *savedItr++;
{% if not nobag %}
auto D_offsets = *savedItr++;
{% endif %}
auto hash_size_cumsum = *savedItr++;
auto indices = *savedItr++;
auto offsets = *savedItr++;
{% if not nobag %}
auto indice_weights = *savedItr++;
auto feature_requires_grad = *savedItr++;
{% endif %}
auto lxu_cache_locations = *savedItr++;
{% for tensor in args.split_saved_tensors %}
auto {{ tensor }} = *savedItr++;
{% endfor %}
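// Non-tensor state was stashed in ctx->saved_data as IValues; unpack it with
// the matching to*() accessors.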
{% if not nobag %}
auto max_D = ctx->saved_data["max_D"].toInt();
auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
{% else %}
auto D = ctx->saved_data["D"].toInt();
{% endif %}
auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
auto gradient_clipping = ctx->saved_data["gradient_clipping"].toBool();
auto max_gradient = ctx->saved_data["max_gradient"].toDouble();
auto stochastic_rounding = ctx->saved_data["stochastic_rounding"].toBool();
{% for (var, ivalue_cast) in args.saved_data %}
auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
{% endfor %}
TORCH_CHECK(grad_outputs.size() == 1);
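// AMD (HIP) wavefronts are 64 lanes wide versus 32-lane warps on NVIDIA
// hardware, so both kernel tuning constants double on ROCm builds.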
#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
#endif
using torch::autograd::Variable;
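// Clip the incoming gradient if requested, then make sure it satisfies the
// layout the backward kernels appear to assume: rows contiguous
// (stride(1) == 1), a 16-byte-aligned base pointer, and a row stride that
// keeps every row aligned for vectorized (e.g. 128-bit) loads; otherwise
// fall back to a contiguous copy.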
auto grad_output = gradient_clipping
    ? clamp(grad_outputs[0], -max_gradient, max_gradient)
    : grad_outputs[0];
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
grad_output.stride(1) != 1 ||
grad_output.stride(0) % 4 != 0) {
grad_output = grad_output.contiguous();
}
{% if not nobag %}
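// Pooled (bagged) path: without per-sample weights the unweighted kernel is
// enough; with them, d(loss)/d(indice_weights) has to be produced as well.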
if (!indice_weights.defined()) {
split_embedding_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
grad_output,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D_offsets,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
pooling_mode,
lxu_cache_locations,
BT_block_size,
max_segment_length_per_warp,
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }});
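// One return slot per forward() input; empty Variables mark the
// non-differentiable arguments.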
return {
Tensor(), // placeholder autograd tensor
Variable(), // output_dtype
Tensor(), // dev_weights
Variable(), // uvm_weights
Variable(), // lxu_cache_weights
Variable(), // weights_placements
Variable(), // weights_offsets
Variable(), // D_offsets
Variable(), // total_D
Variable(), // max_D
Variable(), // hash_size_cumsum
Variable(), // total_hash_size_bits
Variable(), // indices
Variable(), // offsets
Variable(), // pooling_mode
Variable(), // indice_weights
Variable(), // feature_requires_grad
Variable(), // lxu_cache_locations
Variable(), // gradient_clipping
Variable(), // max_gradient
Variable(), // stochastic_rounding
{{ args.split_variables | join(", ") }}
};
} else {
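// Weighted path: per-sample weights scale each looked-up row before pooling,
// so their gradient (per index, the dot product of the output gradient with
// the embedding row) is computed first, honoring feature_requires_grad.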
auto grad_indice_weights = split_embedding_codegen_grad_indice_weights_cuda(
grad_output,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D_offsets,
max_D,
indices,
offsets,
lxu_cache_locations,
feature_requires_grad);
split_embedding_backward_codegen_{{ optimizer }}_weighted_exact_cuda(
grad_output,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D_offsets,
max_D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
pooling_mode,
indice_weights,
lxu_cache_locations,
BT_block_size,
max_segment_length_per_warp,
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }});
return {
Tensor(), // placeholder autograd tensor
Variable(), // output_dtype
Tensor(), // dev_weights
Variable(), // uvm_weights
Variable(), // lxu_cache_weights
Variable(), // weights_placements
Variable(), // weights_offsets
Variable(), // D_offsets
Variable(), // total_D
Variable(), // max_D
Variable(), // hash_size_cumsum
Variable(), // total_hash_size_bits
Variable(), // indices
Variable(), // offsets
Variable(), // pooling_mode
grad_indice_weights, // indice_weights
Variable(), // feature_requires_grad
Variable(), // lxu_cache_locations
Variable(), // gradient_clipping
Variable(), // max_gradient
Variable(), // stochastic_rounding
{{ args.split_variables | join(", ") }}
};
}
{% else %}
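// Sequence (nobag) path: every index produces its own output row of width D,
// so there is no pooling mode and no per-sample weighting to differentiate.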
split_embedding_nobag_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
grad_output,
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
D,
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
lxu_cache_locations,
BT_block_size,
max_segment_length_per_warp,
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }});
return {
Tensor(), // placeholder autograd tensor
Tensor(), // dev_weights
Variable(), // uvm_weights
Variable(), // lxu_cache_weights
Variable(), // weights_placements
Variable(), // weights_offsets
Variable(), // D
Variable(), // hash_size_cumsum
Variable(), // total_hash_size_bits
Variable(), // indices
Variable(), // offsets
Variable(), // lxu_cache_locations
Variable(), // gradient_clipping
Variable(), // max_gradient
Variable(), // stochastic_rounding
{{ args.split_variables | join(", ") }}
};
{% endif %}
}
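// ---------------------------------------------------------------------------
// A minimal, self-contained sketch (ScaleOp is a hypothetical example, not
// part of fbgemm_gpu) of the autograd contract the return lists above follow:
// backward() must yield exactly one entry per forward() argument after the
// ctx pointer, with empty Variables in the non-differentiable slots.
//
// #include <torch/torch.h>
//
// struct ScaleOp : public torch::autograd::Function<ScaleOp> {
//   static torch::Tensor forward(
//       torch::autograd::AutogradContext* ctx,
//       torch::Tensor input,
//       double scale) {
//     // Non-tensor state goes through saved_data, like max_D above.
//     ctx->saved_data["scale"] = scale;
//     return input * scale;
//   }
//
//   static torch::autograd::variable_list backward(
//       torch::autograd::AutogradContext* ctx,
//       torch::autograd::variable_list grad_outputs) {
//     auto scale = ctx->saved_data["scale"].toDouble();
//     return {
//         grad_outputs[0] * scale,     // d(loss)/d(input)
//         torch::autograd::Variable(), // scale: not differentiable
//     };
//   }
// };
//
// Usage: auto y = ScaleOp::apply(x, 2.0); y.sum().backward();
// ---------------------------------------------------------------------------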