static torch::autograd::variable_list backward()

in fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp [221:422]


  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
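    // Restore the tensors saved by forward(), in the exact order they were pushed.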
    const auto saved = ctx->get_saved_variables();
    auto savedItr = std::begin(saved);
    auto dev_weights = *savedItr++;
    auto uvm_weights = *savedItr++;
    auto lxu_cache_weights = *savedItr++;
    auto weights_placements = *savedItr++;
    auto weights_offsets = *savedItr++;
    {% if not nobag %}
    auto D_offsets = *savedItr++;
    {% endif %}
    auto hash_size_cumsum = *savedItr++;
    auto indices = *savedItr++;
    auto offsets = *savedItr++;
    {% if not nobag %}
    auto indice_weights = *savedItr++;
    auto feature_requires_grad = *savedItr++;
    {% endif %}
    auto lxu_cache_locations = *savedItr++;

    {% for tensor in args.split_saved_tensors %}
    auto {{ tensor }} = *savedItr++;
    {% endfor %}

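    // Non-tensor state was stashed in ctx->saved_data by forward(); recover it by key.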
    {% if not nobag %}
    auto max_D = ctx->saved_data["max_D"].toInt();
    auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
    {% else %}
    auto D = ctx->saved_data["D"].toInt();
    {% endif %}
    auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
    auto gradient_clipping = ctx->saved_data["gradient_clipping"].toBool();
    auto max_gradient = ctx->saved_data["max_gradient"].toDouble();
    auto stochastic_rounding = ctx->saved_data["stochastic_rounding"].toBool();

    {% for (var, ivalue_cast) in args.saved_data %}
    auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
    {% endfor %}

    TORCH_CHECK(grad_outputs.size() == 1);

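    // Launch-tuning constants: larger block/segment sizes on HIP/ROCm (64-wide wavefronts)
    // than on CUDA (32-wide warps).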
#ifdef __HIP_PLATFORM_HCC__
    constexpr int32_t BT_block_size = 64;
    constexpr int32_t max_segment_length_per_warp = 64;
#else
    constexpr int32_t BT_block_size = 32;
    constexpr int32_t max_segment_length_per_warp = 32;
#endif
    using torch::autograd::Variable;

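    // Optionally clip the incoming gradient, then force a contiguous copy when grad_output
    // is not 16-byte aligned or its strides would break aligned, vectorized access.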
    auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
        grad_output.stride(1) != 1 ||
        grad_output.stride(0) % 4 != 0) {
        grad_output = grad_output.contiguous();
    }

    {% if not nobag %}
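    // Pooled ("bag") path: dispatch to the unweighted or weighted exact backward kernel,
    // depending on whether per-sample indice_weights were passed to forward().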
    if (!indice_weights.defined()) {
      split_embedding_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
          grad_output,
          dev_weights,
          uvm_weights,
          lxu_cache_weights,
          weights_placements,
          weights_offsets,
          D_offsets,
          max_D,
          hash_size_cumsum,
          total_hash_size_bits,
          indices,
          offsets,
          pooling_mode,
          lxu_cache_locations,
          BT_block_size,
          max_segment_length_per_warp,
          stochastic_rounding,
          {{ args.split_function_arg_names | join(", ") }});
      return {
          Tensor(), // placeholder autograd tensor
          Variable(), // output_dtype
          Tensor(), // dev_weights
          Variable(), // uvm_weights
          Variable(), // lxu_cache_weights
          Variable(), // weights_placements
          Variable(), // weights_offsets
          Variable(), // D_offsets
          Variable(), // total_D
          Variable(), // max_D
          Variable(), // hash_size_cumsum
          Variable(), // total_hash_size_bits
          Variable(), // indices
          Variable(), // offsets
          Variable(), // pooling_mode
          Variable(), // indice_weights
          Variable(), // feature_requires_grad
          Variable(), // lxu_cache_locations
          Variable(), // gradient_clipping
          Variable(), // max_gradient
          Variable(), // stochastic_rounding
          {{ args.split_variables | join(", ") }}
      };
    } else {
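      // Weighted path: first compute the gradient w.r.t. the per-sample weights, then run
      // the weighted exact backward kernel for the embedding weights.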
      auto grad_indice_weights = split_embedding_codegen_grad_indice_weights_cuda(
          grad_output,
          dev_weights,
          uvm_weights,
          lxu_cache_weights,
          weights_placements,
          weights_offsets,
          D_offsets,
          max_D,
          indices,
          offsets,
          lxu_cache_locations,
          feature_requires_grad);
      split_embedding_backward_codegen_{{ optimizer }}_weighted_exact_cuda(
          grad_output,
          dev_weights,
          uvm_weights,
          lxu_cache_weights,
          weights_placements,
          weights_offsets,
          D_offsets,
          max_D,
          hash_size_cumsum,
          total_hash_size_bits,
          indices,
          offsets,
          pooling_mode,
          indice_weights,
          lxu_cache_locations,
          BT_block_size,
          max_segment_length_per_warp,
          stochastic_rounding,
          {{ args.split_function_arg_names | join(", ") }});
      return {
          Tensor(), // placeholder autograd tensor
          Variable(), // output_dtype
          Tensor(), // dev_weights
          Variable(), // uvm_weights
          Variable(), // lxu_cache_weights
          Variable(), // weights_placements
          Variable(), // weights_offsets
          Variable(), // D_offsets
          Variable(), // total_D
          Variable(), // max_D
          Variable(), // hash_size_cumsum
          Variable(), // total_hash_size_bits
          Variable(), // indices
          Variable(), // offsets
          Variable(), // pooling_mode
          grad_indice_weights, // indice_weights
          Variable(), // feature_requires_grad
          Variable(), // lxu_cache_locations
          Variable(), // gradient_clipping
          Variable(), // max_gradient
          Variable(), // stochastic_rounding
          {{ args.split_variables | join(", ") }}
      };
    }
    {% else %}
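    // Sequence ("nobag") path: a single embedding dimension D, no pooling and no per-sample
    // weights, so only the unweighted backward kernel is generated.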
    split_embedding_nobag_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
        grad_output,
        dev_weights,
        uvm_weights,
        lxu_cache_weights,
        weights_placements,
        weights_offsets,
        D,
        hash_size_cumsum,
        total_hash_size_bits,
        indices,
        offsets,
        lxu_cache_locations,
        BT_block_size,
        max_segment_length_per_warp,
        stochastic_rounding,
        {{ args.split_function_arg_names | join(", ") }});
    return {
        Tensor(), // placeholder autograd tensor
        Tensor(), // dev_weights
        Variable(), // uvm_weights
        Variable(), // lxu_cache_weights
        Variable(), // weights_placements
        Variable(), // weights_offsets
        Variable(), // D
        Variable(), // hash_size_cumsum
        Variable(), // total_hash_size_bits
        Variable(), // indices
        Variable(), // offsets
        Variable(), // lxu_cache_locations
        Variable(), // gradient_clipping
        Variable(), // max_gradient
        Variable(), // stochastic_rounding
        {{ args.split_variables | join(", ") }}
    };
    {% endif %}
  }