fbgemm_gpu/src/input_combine_cpu.cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include "fbgemm_gpu/input_combine.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <torch/script.h>
using Tensor = at::Tensor;
namespace fbgemm_gpu {
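
// Concatenates a list of 1-D integer tensors (int32 or int64) into a single
// int32 tensor of length total_num, optionally allocated in pinned host
// memory so that a later copy to the GPU can be asynchronous.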
Tensor _cat_int_tensors(
const std::vector<Tensor>& tensor_list,
int64_t total_num,
bool use_pin_memory) {
auto combined_tensors = at::empty(
{total_num},
at::TensorOptions()
.dtype(c10::kInt)
.device(tensor_list[0].device())
.pinned_memory(use_pin_memory));
auto combined_tensors_acc = combined_tensors.accessor<int32_t, 1>();
size_t idx = 0;
for (size_t i = 0; i < tensor_list.size(); i++) {
AT_DISPATCH_INDEX_TYPES(
tensor_list[i].scalar_type(), "tbe_cat_inputs_", [&] {
auto indices_acc = tensor_list[i].accessor<index_t, 1>();
          for (int64_t j = 0; j < tensor_list[i].numel(); j++) {
combined_tensors_acc[idx++] = static_cast<int32_t>(indices_acc[j]);
}
});
}
return combined_tensors;
}
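
// Concatenates the per-sample weights of all tables into a single float
// tensor of length total_num. The output is initialized to 1.0 and the write
// pointer advances by indices_list[i].numel() per table, so a table that
// passes an empty weights tensor keeps the default weight of 1.0 for each of
// its indices.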
Tensor _cat_per_sample_weights_list(
const std::vector<Tensor>& per_sample_weights,
const std::vector<Tensor>& indices_list,
int64_t total_num,
bool use_pin_memory) {
auto combined_weights = at::ones(
{total_num},
at::TensorOptions()
.dtype(c10::kFloat)
.device(per_sample_weights[0].device())
.pinned_memory(use_pin_memory));
auto* combined_weights_ptr = combined_weights.data_ptr<float>();
for (size_t i = 0; i < per_sample_weights.size(); i++) {
auto element_size = per_sample_weights[i].numel();
if (element_size != 0) {
memcpy(
combined_weights_ptr,
per_sample_weights[i].data_ptr<float>(),
element_size * sizeof(float));
}
combined_weights_ptr += indices_list[i].numel();
}
return combined_weights;
}
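
// Combines per-table indices, offsets, and per-sample weights for a TBE
// (table batched embedding) lookup into single int32 indices, int32 offsets,
// and float weights tensors. Offsets are rebased so they index into the
// combined indices tensor; include_last_offsets[i] indicates whether table
// i's offsets already carry a trailing boundary entry. If no table supplies
// per-sample weights, an empty tensor is returned in the weights slot.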
std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
const std::vector<Tensor>& indices_list,
const std::vector<Tensor>& offsets_list,
const std::vector<Tensor>& per_sample_weights,
const Tensor& include_last_offsets) {
TORCH_CHECK(indices_list.size() > 0);
TORCH_CHECK(offsets_list.size() == indices_list.size());
TORCH_CHECK(per_sample_weights.size() == indices_list.size());
TORCH_CHECK(
static_cast<uint64_t>(include_last_offsets.numel()) ==
indices_list.size());
auto include_last_offsets_acc = include_last_offsets.accessor<bool, 1>();
int64_t total_indices = 0;
int64_t total_offsets = 1;
bool need_weights = false;
bool pin_memory = false;
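  // Allocate the combined outputs in pinned host memory when a CUDA device is
  // present, so that a subsequent host-to-device copy can be asynchronous.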
if (at::Context::hasCUDA() && at::getNumGPUs() > 0) {
pin_memory = true;
}
for (size_t i = 0; i < indices_list.size(); i++) {
TORCH_CHECK(
indices_list[i].dtype() == c10::kInt ||
indices_list[i].dtype() == c10::kLong);
TORCH_CHECK(
offsets_list[i].dtype() == c10::kInt ||
offsets_list[i].dtype() == c10::kLong);
TORCH_CHECK(indices_list[i].ndimension() == 1);
TORCH_CHECK(offsets_list[i].ndimension() == 1);
TORCH_CHECK(indices_list[i].is_contiguous());
TORCH_CHECK(offsets_list[i].is_contiguous());
total_indices += indices_list[i].numel();
auto num_offset =
offsets_list[i].numel() - (include_last_offsets_acc[i] ? 1 : 0);
total_offsets += num_offset == 0 ? 1 : num_offset;
if (per_sample_weights[i].numel() > 0) {
TORCH_CHECK(per_sample_weights[i].ndimension() == 1);
TORCH_CHECK(per_sample_weights[i].numel() == indices_list[i].numel());
TORCH_CHECK(per_sample_weights[i].is_contiguous());
need_weights = true;
}
}
auto combined_indices =
_cat_int_tensors(indices_list, total_indices, pin_memory);
auto combined_offsets = at::empty(
{total_offsets},
at::TensorOptions()
.dtype(c10::kInt)
.device(offsets_list[0].device())
.pinned_memory(pin_memory));
auto combined_offsets_acc = combined_offsets.accessor<int32_t, 1>();
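  // Emit a single leading zero, then append each table's offsets rebased by
  // the number of indices combined so far; every table also contributes its
  // own trailing boundary (the running index count after that table).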
int32_t offset = 0;
size_t offsets_acc_idx = 0;
combined_offsets_acc[offsets_acc_idx++] = 0;
for (size_t i = 0; i < offsets_list.size(); i++) {
AT_DISPATCH_INDEX_TYPES(
offsets_list[i].scalar_type(), "tbe_input_offsets_", [&] {
auto offsets_acc = offsets_list[i].accessor<index_t, 1>();
for (int64_t j = 1,
size = offsets_list[i].numel() -
(include_last_offsets_acc[i] ? 1 : 0);
j < size;
j++) {
combined_offsets_acc[offsets_acc_idx++] =
offset + static_cast<int32_t>(offsets_acc[j]);
}
offset += static_cast<int32_t>(indices_list[i].numel());
combined_offsets_acc[offsets_acc_idx++] = offset;
});
}
if (need_weights) {
return {
std::move(combined_indices),
std::move(combined_offsets),
_cat_per_sample_weights_list(
per_sample_weights, indices_list, total_indices, pin_memory)};
}
return {combined_indices, combined_offsets, at::empty({0})};
}
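
// Same as tbe_input_combine_cpu, but each table supplies per-bag lengths
// instead of offsets; lengths need no rebasing and are simply concatenated
// alongside the combined indices and optional per-sample weights.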
std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cpu(
const std::vector<Tensor>& indices_list,
const std::vector<Tensor>& lengths_list,
const std::vector<Tensor>& per_sample_weights) {
TORCH_CHECK(indices_list.size() > 0);
TORCH_CHECK(lengths_list.size() == indices_list.size());
TORCH_CHECK(per_sample_weights.size() == indices_list.size());
int64_t total_indices = 0;
int64_t total_lengths = 0;
bool need_weights = false;
bool pin_memory = false;
if (at::Context::hasCUDA() && at::getNumGPUs() > 0) {
pin_memory = true;
}
for (size_t i = 0; i < indices_list.size(); i++) {
TORCH_CHECK(
indices_list[i].dtype() == c10::kInt ||
indices_list[i].dtype() == c10::kLong);
TORCH_CHECK(
lengths_list[i].dtype() == c10::kInt ||
lengths_list[i].dtype() == c10::kLong);
TORCH_CHECK(indices_list[i].ndimension() == 1);
TORCH_CHECK(lengths_list[i].ndimension() == 1);
TORCH_CHECK(indices_list[i].is_contiguous());
TORCH_CHECK(lengths_list[i].is_contiguous());
total_indices += indices_list[i].numel();
total_lengths += lengths_list[i].numel();
if (per_sample_weights[i].numel() > 0) {
TORCH_CHECK(per_sample_weights[i].ndimension() == 1);
TORCH_CHECK(per_sample_weights[i].numel() == indices_list[i].numel());
TORCH_CHECK(per_sample_weights[i].is_contiguous());
need_weights = true;
}
}
auto combined_indices =
_cat_int_tensors(indices_list, total_indices, pin_memory);
auto combined_lengths =
_cat_int_tensors(lengths_list, total_lengths, pin_memory);
if (need_weights) {
return {
std::move(combined_indices),
std::move(combined_lengths),
_cat_per_sample_weights_list(
per_sample_weights, indices_list, total_indices, pin_memory)};
}
return {combined_indices, combined_lengths, at::empty({0})};
}
// Similar to tbe_input_combine_cpu, but pads each table's offsets so that
// every table contributes exactly batch_size entries to the combined offsets
// tensor (plus the single shared leading zero).
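// Illustrative example (hypothetical values): with batch_size = 4 and one
// table whose offsets are [0, 2, 5] with include_last_offsets[0] set, the
// combined offsets become [0, 2, 5, 5, 5]; the final boundary is repeated
// until the table fills its batch_size slots.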
std::tuple<Tensor, Tensor, Tensor> padding_fused_tbe_input_combine_cpu(
const std::vector<Tensor>& indices_list,
const std::vector<Tensor>& offsets_list,
const std::vector<Tensor>& per_sample_weights,
const Tensor& include_last_offsets,
int64_t batch_size) {
TORCH_CHECK(indices_list.size() > 0);
TORCH_CHECK(offsets_list.size() == indices_list.size());
TORCH_CHECK(per_sample_weights.size() == indices_list.size());
TORCH_CHECK(
static_cast<uint64_t>(include_last_offsets.numel()) ==
indices_list.size());
auto include_last_offsets_acc = include_last_offsets.accessor<bool, 1>();
int64_t total_indices = 0;
int64_t total_offsets = 1 + batch_size * indices_list.size();
bool need_weights = false;
bool pin_memory = false;
if (at::Context::hasCUDA() && at::getNumGPUs() > 0) {
pin_memory = true;
}
for (size_t i = 0; i < indices_list.size(); i++) {
TORCH_CHECK(
indices_list[i].dtype() == c10::kInt ||
indices_list[i].dtype() == c10::kLong);
TORCH_CHECK(
offsets_list[i].dtype() == c10::kInt ||
offsets_list[i].dtype() == c10::kLong);
TORCH_CHECK(indices_list[i].ndimension() == 1);
TORCH_CHECK(offsets_list[i].ndimension() == 1);
TORCH_CHECK(indices_list[i].is_contiguous());
TORCH_CHECK(offsets_list[i].is_contiguous());
total_indices += indices_list[i].numel();
if (per_sample_weights[i].numel() > 0) {
TORCH_CHECK(per_sample_weights[i].ndimension() == 1);
TORCH_CHECK(per_sample_weights[i].numel() == indices_list[i].numel());
TORCH_CHECK(per_sample_weights[i].is_contiguous());
need_weights = true;
}
}
auto combined_indices =
_cat_int_tensors(indices_list, total_indices, pin_memory);
auto combined_offsets = at::empty(
{total_offsets},
at::TensorOptions()
.dtype(c10::kInt)
.device(offsets_list[0].device())
.pinned_memory(pin_memory));
auto combined_offsets_acc = combined_offsets.accessor<int32_t, 1>();
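  // Rebase each table's offsets by the number of indices combined so far,
  // then repeat the table's end boundary until it fills exactly batch_size
  // entries.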
int32_t offset = 0;
size_t offsets_acc_idx = 0;
combined_offsets_acc[offsets_acc_idx++] = 0;
for (size_t i = 0; i < offsets_list.size(); i++) {
AT_DISPATCH_INDEX_TYPES(
offsets_list[i].scalar_type(), "tbe_input_offsets_", [&] {
auto offsets_acc = offsets_list[i].accessor<index_t, 1>();
int64_t offsets_size =
offsets_list[i].numel() - (include_last_offsets_acc[i] ? 1 : 0);
for (int64_t j = 1; j < offsets_size; j++) {
combined_offsets_acc[offsets_acc_idx++] =
offset + static_cast<int32_t>(offsets_acc[j]);
}
offset += static_cast<int32_t>(indices_list[i].numel());
for (int64_t j = offsets_size; j <= batch_size; j++) {
combined_offsets_acc[offsets_acc_idx++] = offset;
}
});
}
if (need_weights) {
return {
std::move(combined_indices),
std::move(combined_offsets),
_cat_per_sample_weights_list(
per_sample_weights, indices_list, total_indices, pin_memory)};
}
return {combined_indices, combined_offsets, at::empty({0})};
}
} // namespace fbgemm_gpu
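
// Register the operator schemas with PyTorch and bind them to the CPU
// implementations defined above.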
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"tbe_input_combine(Tensor[] indices_list, Tensor[] offsets_list, Tensor[] per_sample_weights, Tensor include_last_offsets) -> (Tensor, Tensor, Tensor)");
m.def(
"tbe_input_combine_with_length(Tensor[] indices_list, Tensor[] lengths_list, Tensor[] per_sample_weights) -> (Tensor, Tensor, Tensor)");
m.def(
"padding_fused_tbe_input_combine(Tensor[] indices_list, Tensor[] offsets_list, Tensor[] per_sample_weights, Tensor include_last_offsets, int batch_size) -> (Tensor, Tensor, Tensor)");
DISPATCH_TO_CPU("tbe_input_combine", fbgemm_gpu::tbe_input_combine_cpu);
DISPATCH_TO_CPU(
"tbe_input_combine_with_length",
fbgemm_gpu::tbe_input_combine_with_length_cpu);
DISPATCH_TO_CPU(
"padding_fused_tbe_input_combine",
fbgemm_gpu::padding_fused_tbe_input_combine_cpu);
}