// fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/util/irange.h>
#include <torch/script.h>
#include <vector>
#include "fbgemm_gpu/permute_pooled_embedding_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
using Tensor = at::Tensor;
namespace fbgemm_gpu {
/// CPU implementation: permutes the pooled-embedding feature blocks of
/// `pooled_embs` along dim 1.
///
/// @param pooled_embs [B_local][Sum_T_global(D)] — concatenated per-table
///        pooled embeddings; split points are given by `offset_dim_list`.
/// @param offset_dim_list int64 tensor of cumulative feature-dim offsets;
///        entry 0 is the start (0) and entry i is the end of block i-1, so it
///        has permute_list.numel() + 1 entries. (Assumed from usage — the
///        split below reads entries 1..n-1 as interior split points.)
/// @param permute_list int64 tensor; output block i is input block
///        permute_list[i].
/// @param inv_offset_dim_list / inv_permute_list unused here; kept so the
///        signature matches the op schema (the backward pass passes the
///        inverse descriptors through the same interface).
/// @return a new tensor with the blocks concatenated in permuted order.
Tensor permute_pooled_embs_cpu(
    const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
    const Tensor& offset_dim_list,
    const Tensor& permute_list,
    const Tensor& inv_offset_dim_list,
    const Tensor& inv_permute_list) {
  TORCH_CHECK(
      offset_dim_list.scalar_type() == at::ScalarType::Long,
      "offset_dim_list needs to have long/int64 type")
  TORCH_CHECK(
      permute_list.scalar_type() == at::ScalarType::Long,
      "permute_list needs to have long/int64 type")
  auto permute = permute_list.data_ptr<int64_t>();
  const auto n = permute_list.numel();
  std::vector<int64_t> dims;
  // Guard against n == 0: `n` is int64_t, so `n - 1` would be -1 and its
  // conversion to size_t in reserve() would request a huge allocation and
  // throw std::length_error with a misleading message.
  if (n > 0) {
    dims.reserve(n - 1);
  }
  // Interior split points (skip the leading 0 at index 0).
  for (const auto i : c10::irange(1, n)) {
    dims.push_back(offset_dim_list[i].item<int64_t>());
  }
  // Split into per-table blocks along the feature dimension, then
  // re-concatenate them in the permuted order.
  auto ts = pooled_embs.tensor_split(dims, 1);
  std::vector<Tensor> permuted_ts;
  permuted_ts.reserve(n);
  for (const auto i : c10::irange(n)) {
    permuted_ts.push_back(ts[permute[i]]);
  }
  return at::cat(permuted_ts, 1);
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
/// Autograd wrapper around a permute-pooled-embeddings implementation.
///
/// The template parameter selects the concrete kernel (CPU or GPU) with the
/// common five-tensor signature. The gradient of a permutation is the inverse
/// permutation, so backward simply re-invokes the same op with the forward
/// and inverse descriptors swapped.
template <torch::autograd::Variable (*permute_pooled_embs_op)(
    const Tensor&, // [B_local][Sum_T_global(D)]
    const Tensor&,
    const Tensor&,
    const Tensor&,
    const Tensor&)>
class PermutePooledEmbsFunction
    : public torch::autograd::Function<
          PermutePooledEmbsFunction<permute_pooled_embs_op>> {
 public:
  /// Validates dtypes, stashes the permutation metadata for backward, and
  /// dispatches to the underlying permute op.
  static Variable forward(
      AutogradContext* ctx,
      const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
      const Tensor& offset_dim_list,
      const Tensor& permute_list,
      const Tensor& inv_offset_dim_list,
      const Tensor& inv_permute_list) {
    TORCH_CHECK(
        offset_dim_list.scalar_type() == at::ScalarType::Long,
        "offset_dim_list needs to have long/int64 type");
    TORCH_CHECK(
        permute_list.scalar_type() == at::ScalarType::Long,
        "permute_list needs to have long/int64 type");
    // Both the forward and the inverse descriptors are needed in backward.
    ctx->saved_data["offset_dim_list"] = offset_dim_list;
    ctx->saved_data["permute_list"] = permute_list;
    ctx->saved_data["inv_offset_dim_list"] = inv_offset_dim_list;
    ctx->saved_data["inv_permute_list"] = inv_permute_list;
    return permute_pooled_embs_op(
        pooled_embs,
        offset_dim_list,
        permute_list,
        inv_offset_dim_list,
        inv_permute_list);
  }

  /// Applies the inverse permutation to the output gradient to obtain the
  /// input gradient. The four descriptor inputs are non-differentiable, so
  /// their gradient slots (indices 1-4) are left undefined.
  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    auto& saved = ctx->saved_data;
    const auto& offset_dim_list = saved["offset_dim_list"].toTensor();
    const auto& permute_list = saved["permute_list"].toTensor();
    const auto& inv_offset_dim_list = saved["inv_offset_dim_list"].toTensor();
    const auto& inv_permute_list = saved["inv_permute_list"].toTensor();
    variable_list grad_inputs(5);
    // Same op, with forward and inverse descriptor pairs swapped.
    grad_inputs[0] = permute_pooled_embs_op(
        grad_output[0],
        inv_offset_dim_list,
        inv_permute_list,
        offset_dim_list,
        permute_list);
    return grad_inputs;
  }
};
/// GPU entry point with autograd support: routes through
/// PermutePooledEmbsFunction so gradients are handled by its backward.
Tensor permute_pooled_embs_auto_grad_gpu(
    const Tensor& pooled_embs,
    const Tensor& offset_dim_list,
    const Tensor& permute_list,
    const Tensor& inv_offset_dim_list,
    const Tensor& inv_permute_list) {
  using AutogradFn = PermutePooledEmbsFunction<permute_pooled_embs_gpu>;
  return AutogradFn::apply(
      pooled_embs,
      offset_dim_list,
      permute_list,
      inv_offset_dim_list,
      inv_permute_list);
}
/// CPU entry point with autograd support: routes through
/// PermutePooledEmbsFunction so gradients are handled by its backward.
Tensor permute_pooled_embs_auto_grad_cpu(
    const Tensor& pooled_embs,
    const Tensor& offset_dim_list,
    const Tensor& permute_list,
    const Tensor& inv_offset_dim_list,
    const Tensor& inv_permute_list) {
  using AutogradFn = PermutePooledEmbsFunction<permute_pooled_embs_cpu>;
  return AutogradFn::apply(
      pooled_embs,
      offset_dim_list,
      permute_list,
      inv_offset_dim_list,
      inv_permute_list);
}
} // namespace fbgemm_gpu
// Registers the operator schemas with the `fbgemm` library and binds the
// CPU/CUDA implementations to the dispatcher.
// NOTE: `permute_pooled_embs` has only a CUDA kernel registered here;
// `permute_pooled_embs_auto_grad` is registered for both CPU and CUDA.
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"permute_pooled_embs(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
DISPATCH_TO_CUDA("permute_pooled_embs", fbgemm_gpu::permute_pooled_embs_gpu);
m.def(
"permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
// Autograd-enabled variants: these wrap the kernels in
// PermutePooledEmbsFunction so gradients flow through the permutation.
DISPATCH_TO_CPU(
"permute_pooled_embs_auto_grad",
fbgemm_gpu::permute_pooled_embs_auto_grad_cpu);
DISPATCH_TO_CUDA(
"permute_pooled_embs_auto_grad",
fbgemm_gpu::permute_pooled_embs_auto_grad_gpu);
}