// fbgemm_gpu/src/sparse_ops.cu
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "fbgemm_gpu/sparse_ops.cuh"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
// clang-format off
#include "fbgemm_gpu/cub_namespace_prefix.cuh"
#include "cub/device/device_scan.cuh"
#include "fbgemm_gpu/cub_namespace_postfix.cuh"
// clang-format on
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
using Tensor = at::Tensor;
namespace fbgemm_gpu {
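// Computes the launch configuration for the offsets-range style kernels below:
// picks a per-row vector width from the average segment length
// (output_size / num_seq), keeps 512 threads per block, and returns
// (num_blocks, rows_per_block, vector_size).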
std::tuple<uint32_t, uint32_t, uint32_t> calc_offsets_range_thread_block(
const int64_t output_size,
const int64_t num_seq) {
uint32_t threads_per_block;
uint32_t vector_size;
if (output_size / num_seq < 2) {
threads_per_block = 512;
vector_size = 2;
} else if (output_size / num_seq < 4) {
threads_per_block = 512;
vector_size = 4;
} else if (output_size / num_seq < 64) {
threads_per_block = 512;
vector_size = 8;
} else if (output_size / num_seq < 128) {
threads_per_block = 512;
vector_size = 16;
} else {
threads_per_block = 512;
vector_size = 32;
}
uint32_t rows_per_block = threads_per_block / vector_size;
const auto num_blocks = cuda_calc_xblock_count(num_seq, rows_per_block);
return std::make_tuple(num_blocks, rows_per_block, vector_size);
}
// Kernel for calculating the offsets ranges
template <typename scalar_t>
__global__ void _offsets_range_cuda_kernel(
int64_t N,
int64_t range_size,
const scalar_t* __restrict__ offsets_data,
scalar_t* __restrict__ range_data) {
int start_row_idx = blockIdx.x * blockDim.y + threadIdx.y;
const int stride = gridDim.x * blockDim.y;
for (int row_idx = start_row_idx; row_idx < N; row_idx += stride) {
scalar_t row_start = offsets_data[row_idx];
scalar_t row_end =
(row_idx < N - 1 ? offsets_data[row_idx + 1] : range_size);
if (blockDim.x == 32) {
scalar_t i = row_start - (row_start & 31) + threadIdx.x;
// unaligned part
if (i >= row_start && i < row_end) {
range_data[i] = i - row_start;
}
// aligned part
for (i += 32; i < row_end; i += 32) {
range_data[i] = i - row_start;
}
} else {
for (scalar_t i = row_start + threadIdx.x; i < row_end; i += blockDim.x) {
range_data[i] = i - row_start;
}
}
}
}
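// Given a 1-D offsets tensor and the total range_size, fills a tensor of
// length range_size so that each segment [offsets[i], offsets[i + 1]) (the
// last segment ends at range_size) contains 0, 1, ..., segment_length - 1.
// Illustrative example: offsets = [0, 3, 5], range_size = 7
//   -> range = [0, 1, 2, 0, 1, 0, 1]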
Tensor offsets_range_cuda(const Tensor& offsets, int64_t range_size) {
TENSOR_ON_CUDA_GPU(offsets);
TENSOR_NDIM_EQUALS(offsets, 1);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(offsets.get_device());
auto offsets_arg = at::TensorArg(offsets, "offsets", 1);
checkScalarTypes("_offsets_range_cuda", offsets_arg, {at::kLong, at::kInt});
auto range = at::empty(range_size, offsets.options());
if (range_size == 0) {
return range;
}
auto offsets_contig = offsets.contiguous();
int64_t N = offsets_contig.numel();
uint32_t vector_size;
uint32_t rows_per_block;
uint32_t num_blocks;
std::tie(num_blocks, rows_per_block, vector_size) =
calc_offsets_range_thread_block(range_size, N);
dim3 threads(vector_size, rows_per_block);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(), "offsets_range_kernel", [&] {
_offsets_range_cuda_kernel<index_t>
<<<num_blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
N,
range_size,
offsets_contig.data_ptr<index_t>(),
range.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return range;
}
// Kernel for calculating the segmented sum for a sparse matrix in CSR format.
// See https://moderngpu.github.io/segreduce.html
template <typename scalar_t>
__global__ void _segment_sum_csr_cuda_kernel(
int num_segments,
int batch_size,
const int* csr_seg_data,
const scalar_t* values_data,
scalar_t* output_data) {
typedef FBGEMM_GPU_CUB_NS_PREFIX cub::BlockReduce<scalar_t, 256> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int seg_start = csr_seg_data[blockIdx.x] * batch_size;
int seg_end = csr_seg_data[blockIdx.x + 1] * batch_size;
scalar_t sum = 0;
for (int i = seg_start; i < seg_end; i += blockDim.x) {
scalar_t thread_data;
if (threadIdx.x < seg_end - i) {
thread_data = values_data[i + threadIdx.x];
}
scalar_t aggregate =
BlockReduce(temp_storage).Sum(thread_data, seg_end - i);
__syncthreads();
if (threadIdx.x == 0) {
sum += aggregate;
}
}
if (threadIdx.x == 0) {
output_data[blockIdx.x] = sum;
}
}
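// Computes one sum per CSR segment: output[s] sums the values in the flattened
// range [csr_seg[s] * batch_size, csr_seg[s + 1] * batch_size). The output has
// csr_seg.numel() - 1 entries, one per segment.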
Tensor segment_sum_csr_cuda(
const int64_t batch_size,
const Tensor& csr_seg,
const Tensor& values) {
TENSOR_ON_CUDA_GPU(csr_seg);
TENSOR_ON_CUDA_GPU(values);
TENSORS_ON_SAME_DEVICE(csr_seg, values);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());
auto output = at::empty(csr_seg.numel() - 1, values.options());
constexpr uint32_t threads_per_block = 256;
const uint32_t num_blocks = csr_seg.numel() - 1;
AT_DISPATCH_ALL_TYPES(values.scalar_type(), "_segment_sum_csr_cuda", [&] {
_segment_sum_csr_cuda_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
csr_seg.numel() - 1,
batch_size,
csr_seg.data_ptr<int>(),
values.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return output;
}
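// Inclusive prefix sum of a contiguous int32/int64 tensor, computed entirely
// on the GPU with cub::DeviceScan (no host synchronization).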
Tensor asynchronous_inclusive_cumsum_gpu(const Tensor& t_in) {
TENSOR_ON_CUDA_GPU(t_in);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(t_in.get_device());
size_t temp_storage_bytes = 0;
TORCH_CHECK(t_in.is_contiguous());
TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong);
// CUB only handles up to INT_MAX elements.
TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
auto t_out = at::empty_like(t_in);
AT_DISPATCH_INTEGRAL_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
nullptr,
temp_storage_bytes,
t_in.data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>(),
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
t_in.options().dtype(at::kByte));
AT_DISPATCH_INTEGRAL_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
temp_storage.data_ptr(),
temp_storage_bytes,
t_in.data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>(),
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
return t_out;
}
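// Exclusive prefix sum of a contiguous int32/int64 tensor, computed entirely
// on the GPU with cub::DeviceScan (no host synchronization).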
Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) {
TENSOR_ON_CUDA_GPU(t_in);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(t_in.get_device());
size_t temp_storage_bytes = 0;
TORCH_CHECK(t_in.is_contiguous());
TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong);
// CUB only handles up to INT_MAX elements.
TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
auto t_out = at::empty_like(t_in);
AT_DISPATCH_INTEGRAL_TYPES(
t_in.scalar_type(), "cub_exclusive_sum_wrapper1", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::ExclusiveSum(
nullptr,
temp_storage_bytes,
t_in.data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>(),
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
t_in.options().dtype(at::kByte));
AT_DISPATCH_INTEGRAL_TYPES(
t_in.scalar_type(), "cub_exclusive_sum_wrapper2", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::ExclusiveSum(
temp_storage.data_ptr(),
temp_storage_bytes,
t_in.data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>(),
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
return t_out;
}
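// "Complete" prefix sum: returns a tensor of numel + 1 elements holding
// [0, inclusive_cumsum(t_in)], i.e. the usual lengths-to-offsets conversion.
// Illustrative example: t_in = [2, 3, 1] -> [0, 2, 5, 6]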
Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
TENSOR_ON_CUDA_GPU(t_in);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(t_in.get_device());
size_t temp_storage_bytes = 0;
TORCH_CHECK(t_in.is_contiguous());
TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong);
// CUB only handles up to INT_MAX elements.
TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
TORCH_CHECK(t_in.dim() == 1);
auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
t_out[0].zero_();
AT_DISPATCH_INTEGRAL_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
nullptr,
temp_storage_bytes,
t_in.data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>() + 1,
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
t_in.options().dtype(at::kByte));
AT_DISPATCH_INTEGRAL_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
temp_storage.data_ptr(),
temp_storage_bytes,
t_in.data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>() + 1,
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
return t_out;
}
// Kernel for permuting the indices and weights. Used for permutation of sparse
// data
template <
bool has_weight,
typename offsets_t,
typename indices_t,
typename weights_t>
__global__ void permute_2D_data_kernel(
int32_t len,
int32_t T,
int32_t B,
const indices_t* __restrict__ indices,
const weights_t* __restrict__ weights,
const int32_t* __restrict__ permute,
const offsets_t* __restrict__ input_offsets,
const offsets_t* __restrict__ output_offsets,
indices_t* __restrict__ permuted_indices,
weights_t* __restrict__ permuted_weights) {
int32_t b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
const int stride = gridDim.x * blockDim.y;
for (int b_t = b_t_start; b_t < B * T; b_t += stride) {
int32_t b = b_t % B;
int32_t t = b_t / B;
offsets_t output_start = output_offsets[b_t];
offsets_t segment_length;
if (b_t == B * T - 1) {
segment_length = len - output_offsets[b_t];
} else {
segment_length = output_offsets[b_t + 1] - output_offsets[b_t];
}
offsets_t input_start = input_offsets[permute[t] * B + b];
for (int32_t i = threadIdx.x; i < segment_length; i += blockDim.x) {
permuted_indices[output_start + i] = indices[input_start + i];
if (has_weight) {
permuted_weights[output_start + i] = weights[input_start + i];
}
}
}
}
// Kernel for permuting the lengths. Used for permutation of sparse features.
template <typename index_t>
__global__ void permute_2D_lengths_kernel(
int32_t T,
int32_t B,
const index_t* __restrict__ lengths,
const int32_t* __restrict__ permute,
index_t* __restrict__ permuted_lengths) {
CUDA_KERNEL_LOOP(b_t, B * T) {
int32_t b = b_t % B;
int32_t t = b_t / B;
permuted_lengths[b_t] = lengths[permute[t] * B + b];
}
}
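// Permutes a 2-D (T x B) lengths tensor and its flattened indices (and
// optional weights) along the feature dimension T according to `permute`,
// which may drop or repeat features. If permuted_lengths_sum is not supplied,
// it is computed with a blocking permuted_lengths.sum().item() call.
// Illustrative example: permute = [1, 0], lengths = [[1, 2], [3, 1]]
//   -> permuted_lengths = [[3, 1], [1, 2]], and the corresponding indices
//      (and weights) segments are reordered to match.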
std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_2D_sparse_data_cuda(
const Tensor& permute,
const Tensor& lengths,
const Tensor& indices,
const c10::optional<Tensor>& weights,
const c10::optional<int64_t>& permuted_lengths_sum) {
TENSOR_ON_CUDA_GPU(permute);
TENSOR_ON_CUDA_GPU(lengths);
TENSOR_ON_CUDA_GPU(indices);
TENSOR_ON_CUDA_GPU(weights);
TORCH_CHECK(lengths.dim() == 2);
TENSORS_ON_SAME_DEVICE(permute, lengths);
TENSORS_ON_SAME_DEVICE(permute, indices);
TENSORS_ON_SAME_DEVICE(permute, weights);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(indices.get_device());
const auto permute_contig = permute.contiguous();
const auto lengths_contig = lengths.contiguous();
const auto indices_contig = indices.contiguous();
// the permute may select fewer or more features than the input, with or
// without repetitions
const auto T = permute.numel();
const auto B = lengths.size(1);
Tensor permuted_lengths;
Tensor permuted_indices;
Tensor permuted_weights;
permuted_lengths = at::empty({T, B}, lengths.options());
constexpr int32_t threads_1 = 256;
const auto blocks_1 = cuda_calc_xblock_count(B * T, threads_1);
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_2D_lengths_kernel", [&] {
permute_2D_lengths_kernel<index_t>
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
T,
B,
lengths_contig.data_ptr<index_t>(),
permute_contig.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
// convert lengths to offsets
const auto input_offsets = asynchronous_exclusive_cumsum_gpu(lengths_contig);
const auto output_offsets =
asynchronous_exclusive_cumsum_gpu(permuted_lengths);
int64_t permuted_indices_size = 0;
if (permuted_lengths_sum.has_value()) {
permuted_indices_size = permuted_lengths_sum.value();
} else {
permuted_indices_size = permuted_lengths.sum().item<int64_t>();
}
constexpr int32_t BT_blocks = 32;
dim3 threads_2(32, BT_blocks);
const auto blocks_2 = cuda_calc_xblock_count(B * T, BT_blocks);
permuted_indices = at::empty(permuted_indices_size, indices.options());
AT_DISPATCH_INDEX_TYPES(
input_offsets.scalar_type(), "permute_2D_data_kernel_1", [&] {
using offsets_t = index_t;
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half,
indices.scalar_type(),
"permute_2D_data_kernel_2",
[&] {
using indices_t = scalar_t;
if (weights.has_value()) {
const Tensor weights_value = weights.value();
const auto weights_value_contig = weights_value.contiguous();
permuted_weights =
at::empty(permuted_indices_size, weights_value.options());
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half,
weights_value.scalar_type(),
"permute_2D_data_kernel_3",
[&] {
using weights_t = scalar_t;
permute_2D_data_kernel<
true,
offsets_t,
indices_t,
weights_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_indices_size,
T,
B,
indices_contig.data_ptr<indices_t>(),
weights_value_contig.data_ptr<weights_t>(),
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
permuted_weights.data_ptr<weights_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}); // for each weights_t
} else {
permute_2D_data_kernel<
false,
offsets_t,
indices_t,
std::nullptr_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_indices_size,
T,
B,
indices_contig.data_ptr<indices_t>(),
nullptr,
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}); // for each indices_t
}); // for each offsets_t
return {permuted_lengths, permuted_indices, permuted_weights};
}
// Kernel for permuting 1D lengths. Used for permutation of sparse features.
template <typename index_t>
__global__ void permute_1D_lengths_kernel(
const index_t* __restrict__ lengths,
int32_t permuted_lengths_size,
const int32_t* __restrict__ permute,
index_t* __restrict__ permuted_lengths) {
CUDA_KERNEL_LOOP(i, permuted_lengths_size) {
permuted_lengths[i] = lengths[permute[i]];
}
}
// Kernel for permuting the indices and weights. Used for permutation of sparse
// data
template <
bool has_weight,
typename offsets_t,
typename indices_t,
typename weights_t>
__global__ void permute_1D_data_kernel(
int32_t permuted_indices_size,
int32_t permuted_lengths_size,
const indices_t* __restrict__ indices,
const weights_t* __restrict__ weights,
const int32_t* __restrict__ permute,
const offsets_t* __restrict__ input_offsets,
const offsets_t* __restrict__ output_offsets,
indices_t* __restrict__ permuted_indices,
weights_t* __restrict__ permuted_weights) {
int32_t b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
const int stride = gridDim.x * blockDim.y;
for (int b_t = b_t_start; b_t < permuted_lengths_size; b_t += stride) {
offsets_t output_start = output_offsets[b_t];
offsets_t segment_length;
if (b_t == permuted_lengths_size - 1) {
segment_length = permuted_indices_size - output_offsets[b_t];
} else {
segment_length = output_offsets[b_t + 1] - output_offsets[b_t];
}
offsets_t input_start = input_offsets[permute[b_t]];
for (int32_t i = threadIdx.x; i < segment_length; i += blockDim.x) {
permuted_indices[output_start + i] = indices[input_start + i];
if (has_weight) {
permuted_weights[output_start + i] = weights[input_start + i];
}
}
}
}
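// Same idea as permute_2D_sparse_data_cuda, but `permute` addresses the flat
// 1-D lengths tensor directly (one entry per segment), so individual segments
// can be dropped or repeated.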
std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
const Tensor& permute,
const Tensor& lengths,
const Tensor& indices,
const c10::optional<Tensor>& weights,
const c10::optional<int64_t>& permuted_lengths_sum) {
TENSOR_ON_CUDA_GPU(permute);
TENSOR_ON_CUDA_GPU(lengths);
TENSOR_ON_CUDA_GPU(indices);
TENSOR_ON_CUDA_GPU(weights);
TENSORS_ON_SAME_DEVICE(permute, lengths);
TENSORS_ON_SAME_DEVICE(permute, indices);
TENSORS_ON_SAME_DEVICE(permute, weights);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(indices.get_device());
const auto permute_contig = permute.contiguous();
const auto lengths_contig = lengths.contiguous();
const auto indices_contig = indices.contiguous();
// the permute may select fewer or more segments than the input, with or
// without repetitions
const auto lengths_size = lengths.numel();
const auto permuted_lengths_size = permute.numel();
Tensor permuted_lengths;
Tensor permuted_indices;
Tensor permuted_weights;
permuted_lengths = at::empty({permuted_lengths_size}, lengths.options());
constexpr int32_t threads_1 = kMaxThreads;
const auto blocks_1 =
cuda_calc_xblock_count(permuted_lengths_size, threads_1);
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_1D_lengths_kernel", [&] {
permute_1D_lengths_kernel<index_t>
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
lengths_contig.data_ptr<index_t>(),
permuted_lengths_size,
permute_contig.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
// convert lengths to offsets
const auto input_offsets = asynchronous_exclusive_cumsum_gpu(lengths_contig);
const auto output_offsets =
asynchronous_exclusive_cumsum_gpu(permuted_lengths);
int64_t permuted_indices_size = 0;
if (permuted_lengths_sum.has_value()) {
permuted_indices_size = permuted_lengths_sum.value();
} else {
permuted_indices_size = permuted_lengths.sum().item<int64_t>();
}
constexpr int32_t BT_blocks = 32;
dim3 threads_2(32, BT_blocks);
const auto blocks_2 =
cuda_calc_xblock_count(permuted_lengths_size, BT_blocks);
permuted_indices = at::empty(permuted_indices_size, indices.options());
AT_DISPATCH_INDEX_TYPES(
input_offsets.scalar_type(), "permute_1D_data_kernel_1", [&] {
using offsets_t = index_t;
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half,
indices.scalar_type(),
"permute_1D_data_kernel_2",
[&] {
using indices_t = scalar_t;
if (weights.has_value()) {
const Tensor weights_value = weights.value();
const auto weights_value_contig = weights_value.contiguous();
permuted_weights =
at::empty(permuted_indices_size, weights_value.options());
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half,
weights_value.scalar_type(),
"permute_1D_data_kernel_3",
[&] {
using weights_t = scalar_t;
permute_1D_data_kernel<
true,
offsets_t,
indices_t,
weights_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_indices_size,
permuted_lengths_size,
indices_contig.data_ptr<indices_t>(),
weights_value_contig.data_ptr<weights_t>(),
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
permuted_weights.data_ptr<weights_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}); // for each weights_t
} else {
permute_1D_data_kernel<
false,
offsets_t,
indices_t,
std::nullptr_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_indices_size,
permuted_lengths_size,
indices_contig.data_ptr<indices_t>(),
nullptr,
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}); // for each indices_t
}); // for each offsets_t
return {permuted_lengths, permuted_indices, permuted_weights};
}
// Kernel for generating a 1D data permute from a dimension-level permute
// index. Used for permutation of sparse features.
template <typename index_t, typename offsets_t>
__global__ void expand_into_jagged_permute_kernel(
const offsets_t* __restrict__ input_offsets,
const offsets_t* __restrict__ output_offsets,
int32_t input_size,
const index_t* __restrict__ permute,
index_t* __restrict__ output_permute) {
const int32_t t_start = blockIdx.x * blockDim.y + threadIdx.y;
const int stride = gridDim.x * blockDim.y;
for (int t = t_start; t < input_size; t += stride) {
const offsets_t output_start = output_offsets[t];
const offsets_t segment_length = output_offsets[t + 1] - output_offsets[t];
const offsets_t input_start = input_offsets[permute[t]];
for (int32_t i = threadIdx.x; i < segment_length; i += blockDim.x) {
output_permute[output_start + i] = input_start + i;
}
}
}
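// Expands a table-level permute (with per-table input/output offsets) into an
// element-level permute of length output_size: the i-th element of output
// table t maps back to input position input_offsets[permute[t]] + i.
// Illustrative example: permute = [1, 0], input_offsets = [0, 2, 5],
// output_offsets = [0, 3, 5], output_size = 5
//   -> output_permute = [2, 3, 4, 0, 1]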
Tensor expand_into_jagged_permute_cuda(
const Tensor& permute,
const Tensor& input_offsets,
const Tensor& output_offsets,
int64_t output_size) {
TENSOR_ON_CUDA_GPU(permute);
TENSOR_ON_CUDA_GPU(input_offsets);
TENSOR_ON_CUDA_GPU(output_offsets);
TENSORS_ON_SAME_DEVICE(permute, input_offsets);
TENSORS_ON_SAME_DEVICE(permute, output_offsets);
TORCH_CHECK(permute.numel() > 0);
TORCH_CHECK(permute.numel() == input_offsets.numel() - 1);
TORCH_CHECK(permute.numel() == output_offsets.numel() - 1);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(permute.get_device());
const auto permute_contig = permute.contiguous();
const auto permute_size = permute.numel();
Tensor output_permute = at::empty({output_size}, permute.options());
// number of tables per block
constexpr int32_t T_blocks = kMaxThreads / kWarpSize;
dim3 threads(kWarpSize, T_blocks);
const auto blocks = cuda_calc_xblock_count(permute_size, T_blocks);
AT_DISPATCH_INDEX_TYPES(
permute.scalar_type(), "expand_into_jagged_permute_kernel", [&] {
using offsets_t = index_t;
expand_into_jagged_permute_kernel<index_t, offsets_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permute_size,
permute_contig.data_ptr<index_t>(),
output_permute.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return output_permute;
}
// Kernel for bucketizing lengths, with the Block distribution (vs. cyclic or
// block-cyclic distribution). Used for bucketizing sparse features, especially
// for checkpointing with row-wise partitioning (the sparse feature is
// partitioned contiguously along the sparse dimension into my_size blocks).
template <typename offset_t, typename index_t>
__global__ void _block_bucketize_sparse_features_cuda_kernel1(
int32_t lengths_size,
int32_t B,
const index_t* __restrict__ block_sizes_data,
int my_size,
const offset_t* __restrict__ offsets_data,
const index_t* __restrict__ indices_data,
offset_t* __restrict__ new_lengths_data) {
using uindex_t = std::make_unsigned_t<index_t>;
CUDA_KERNEL_LOOP(b_t, lengths_size) {
int32_t t = b_t / B;
index_t blk_size = block_sizes_data[t];
offset_t rowstart = (b_t == 0 ? 0 : offsets_data[b_t - 1]);
offset_t rowend = offsets_data[b_t];
for (index_t i = rowstart; i < rowend; ++i) {
// We have use cases with non-hashed raw indices that can be either
// negative or larger than the embedding table hash_size (blk_size *
// my_size). For non-hashed indices we need to ensure that bucketization
// can distribute them to different ranks and keep them within the range
// of blk_size; we expect the downstream embedding module to take care of
// the index hashing.
uindex_t idx = static_cast<uindex_t>(indices_data[i]);
uindex_t p = idx < blk_size * my_size ? idx / blk_size : idx % my_size;
new_lengths_data[p * lengths_size + b_t]++;
}
}
}
// Kernel for bucketizing offsets, indices, and positional weights, with the
// Block distribution (vs. cyclic or block-cyclic distribution). Used for
// bucketizing sparse features, especially for checkpointing with row-wise
// partitioning (the sparse feature is partitioned contiguously along the
// sparse dimension into my_size blocks).
template <
bool sequence,
bool has_weight,
bool bucketize_pos,
typename offset_t,
typename index_t,
typename scalar_t>
__global__ void _block_bucketize_sparse_features_cuda_kernel2(
int lengths_size,
int32_t B,
const index_t* __restrict__ block_sizes_data,
int my_size,
const offset_t* __restrict__ offsets_data,
const index_t* __restrict__ indices_data,
const scalar_t* __restrict__ weights_data,
offset_t* __restrict__ new_offsets_data,
index_t* __restrict__ new_indices_data,
scalar_t* __restrict__ new_weights_data,
index_t* __restrict__ new_pos_data,
index_t* __restrict__ unbucketize_permute_data) {
using uindex_t = std::make_unsigned_t<index_t>;
using uoffset_t = std::make_unsigned_t<offset_t>;
CUDA_KERNEL_LOOP(b_t, lengths_size) {
int32_t t = b_t / B;
index_t blk_size = block_sizes_data[t];
offset_t rowstart = (b_t == 0 ? 0 : offsets_data[b_t - 1]);
offset_t rowend = offsets_data[b_t];
for (index_t i = rowstart; i < rowend; ++i) {
// We have use cases with non-hashed raw indices that can be either
// negative or larger than the embedding table hash_size (blk_size *
// my_size). For non-hashed indices we need to ensure that bucketization
// can distribute them to different ranks and keep them within the range
// of blk_size; we expect the downstream embedding module to take care of
// the index hashing.
uindex_t idx = static_cast<uindex_t>(indices_data[i]);
uindex_t p = idx < blk_size * my_size ? idx / blk_size : idx % my_size;
uindex_t new_idx =
idx < blk_size * my_size ? idx % blk_size : idx / my_size;
uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
new_indices_data[pos] = new_idx;
new_offsets_data[p * lengths_size + b_t]++;
if (sequence) {
unbucketize_permute_data[i] = pos;
}
if (has_weight) {
new_weights_data[pos] = weights_data[i];
}
if (bucketize_pos) {
new_pos_data[pos] = i - rowstart;
}
}
}
}
// This function partitions sparse features
// contiguously along the sparse dimension into my_size blocks
std::tuple<
Tensor,
Tensor,
c10::optional<Tensor>,
c10::optional<Tensor>,
c10::optional<Tensor>>
block_bucketize_sparse_features_cuda(
Tensor lengths,
Tensor indices,
bool bucketize_pos,
bool sequence,
Tensor block_sizes,
int64_t my_size,
c10::optional<Tensor> weights) {
TENSOR_ON_CUDA_GPU(lengths);
TENSOR_ON_CUDA_GPU(indices);
TENSORS_ON_SAME_DEVICE(lengths, indices);
TENSOR_ON_CUDA_GPU(weights);
TENSORS_ON_SAME_DEVICE(lengths, weights);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(lengths.get_device());
// allocate tensors and buffers
const int lengths_size = lengths.numel();
const int T = block_sizes.numel();
const int B = lengths_size / T;
const int new_lengths_size = lengths_size * my_size;
auto offsets = at::empty({lengths_size}, lengths.options());
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
auto new_offsets = at::empty({new_lengths_size}, lengths.options());
auto new_indices = at::empty_like(indices);
auto lengths_contig = lengths.contiguous();
auto indices_contig = indices.contiguous();
auto offsets_contig = offsets.contiguous();
Tensor new_weights;
Tensor new_pos;
Tensor unbucketize_permute;
// count nonzeros
offsets_contig = asynchronous_inclusive_cumsum_gpu(lengths);
int threads_per_block = 256;
int num_blocks = (lengths_size + threads_per_block - 1) / threads_per_block;
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_block_bucketize_sparse_features_cuda_kernel1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_block_bucketize_sparse_features_cuda_kernel2",
[&] {
_block_bucketize_sparse_features_cuda_kernel1<<<
num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
B,
block_sizes.data_ptr<index_t>(),
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
new_lengths.data_ptr<offset_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
// bucketize nonzeros
new_offsets = asynchronous_exclusive_cumsum_gpu(new_lengths);
if (sequence) {
const auto lengths_sum = indices.numel();
unbucketize_permute = at::empty({lengths_sum}, indices.options());
if (weights.has_value() && bucketize_pos) {
Tensor weights_value = weights.value();
auto weights_value_contig = weights_value.contiguous();
new_weights = at::empty_like(weights_value);
new_pos = at::empty_like(indices);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_2",
[&] {
AT_DISPATCH_FLOATING_TYPES(
weights_value.scalar_type(),
"_block_bucketize_sparse_features_cuda_weight_kernel2_3",
[&] {
_block_bucketize_sparse_features_cuda_kernel2<
true,
true,
true,
offset_t,
index_t,
scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
B,
block_sizes.data_ptr<index_t>(),
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
weights_value_contig.data_ptr<scalar_t>(),
new_offsets.data_ptr<offset_t>(),
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
new_pos.data_ptr<index_t>(),
unbucketize_permute.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
} else if (weights.has_value()) {
Tensor weights_value = weights.value();
auto weights_value_contig = weights_value.contiguous();
new_weights = at::empty_like(weights_value);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_2",
[&] {
AT_DISPATCH_FLOATING_TYPES(
weights_value.scalar_type(),
"_block_bucketize_sparse_features_cuda_weight_kernel2_3",
[&] {
_block_bucketize_sparse_features_cuda_kernel2<
true,
true,
false,
offset_t,
index_t,
scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
B,
block_sizes.data_ptr<index_t>(),
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
weights_value_contig.data_ptr<scalar_t>(),
new_offsets.data_ptr<offset_t>(),
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
nullptr,
unbucketize_permute.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
} else if (bucketize_pos) {
new_pos = at::empty_like(indices);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_block_bucketize_sparse_features_cuda_kernel2_2",
[&] {
_block_bucketize_sparse_features_cuda_kernel2<
true,
false,
true,
offset_t,
index_t,
std::nullptr_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
B,
block_sizes.data_ptr<index_t>(),
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
nullptr,
new_offsets.data_ptr<offset_t>(),
new_indices.data_ptr<index_t>(),
nullptr,
new_pos.data_ptr<index_t>(),
unbucketize_permute.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
} else {
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_block_bucketize_sparse_features_cuda_kernel2_2",
[&] {
_block_bucketize_sparse_features_cuda_kernel2<
true,
false,
false,
offset_t,
index_t,
std::nullptr_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
B,
block_sizes.data_ptr<index_t>(),
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
nullptr,
new_offsets.data_ptr<offset_t>(),
new_indices.data_ptr<index_t>(),
nullptr,
nullptr,
unbucketize_permute.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
}
} else {
if (weights.has_value() && bucketize_pos) {
Tensor weights_value = weights.value();
auto weights_value_contig = weights_value.contiguous();
new_weights = at::empty_like(weights_value);
new_pos = at::empty_like(indices);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_2",
[&] {
AT_DISPATCH_FLOATING_TYPES(
weights_value.scalar_type(),
"_block_bucketize_sparse_features_cuda_weight_kernel2_3",
[&] {
_block_bucketize_sparse_features_cuda_kernel2<
false,
true,
true,
offset_t,
index_t,
scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
B,
block_sizes.data_ptr<index_t>(),
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
weights_value_contig.data_ptr<scalar_t>(),
new_offsets.data_ptr<offset_t>(),
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
new_pos.data_ptr<index_t>(),
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
} else if (weights.has_value()) {
Tensor weights_value = weights.value();
auto weights_value_contig = weights_value.contiguous();
new_weights = at::empty_like(weights_value);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_2",
[&] {
AT_DISPATCH_FLOATING_TYPES(
weights_value.scalar_type(),
"_block_bucketize_sparse_features_cuda_weight_kernel2_3",
[&] {
_block_bucketize_sparse_features_cuda_kernel2<
false,
true,
false,
offset_t,
index_t,
scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
B,
block_sizes.data_ptr<index_t>(),
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
weights_value_contig.data_ptr<scalar_t>(),
new_offsets.data_ptr<offset_t>(),
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
nullptr,
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
} else if (bucketize_pos) {
new_pos = at::empty_like(indices);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_block_bucketize_sparse_features_cuda_kernel2_2",
[&] {
_block_bucketize_sparse_features_cuda_kernel2<
false,
false,
true,
offset_t,
index_t,
std::nullptr_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
B,
block_sizes.data_ptr<index_t>(),
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
nullptr,
new_offsets.data_ptr<offset_t>(),
new_indices.data_ptr<index_t>(),
nullptr,
new_pos.data_ptr<index_t>(),
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
} else {
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_block_bucketize_sparse_features_cuda_kernel2_2",
[&] {
_block_bucketize_sparse_features_cuda_kernel2<
false,
false,
false,
offset_t,
index_t,
std::nullptr_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
B,
block_sizes.data_ptr<index_t>(),
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
nullptr,
new_offsets.data_ptr<offset_t>(),
new_indices.data_ptr<index_t>(),
nullptr,
nullptr,
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
}
}
return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute};
}
// Kernel for bucketizing lengths, with the Cyclic distribution (vs. block or
// block-cyclic distribution). Used for bucketizing sparse features with
// row-wise partitioning (the sparse feature is partitioned cyclically along
// the sparse dimension into my_size blocks).
template <typename scalar_t>
__global__ void _bucketize_sparse_features_cuda_kernel1(
int lengths_size,
int my_size,
const scalar_t* __restrict__ offsets_data,
const scalar_t* __restrict__ indices_data,
scalar_t* __restrict__ new_lengths_data) {
using uscalar_t = std::make_unsigned_t<scalar_t>;
CUDA_KERNEL_LOOP(r, lengths_size) {
scalar_t rowstart = (r == 0 ? 0 : offsets_data[r - 1]);
scalar_t rowend = offsets_data[r];
for (scalar_t i = rowstart; i < rowend; ++i) {
// Need to handle negative indices if we use raw indices instead of hashed
// indices; convert to unsigned.
uscalar_t idx = static_cast<uscalar_t>(indices_data[i]);
uscalar_t p = idx % my_size;
new_lengths_data[p * lengths_size + r]++;
}
}
}
// Kernel for bucketizing offsets, indices, and positional weights, with the
// Cyclic distribution (vs. block or block-cyclic distribution). Used for
// bucketizing sparse features with row-wise partitioning (the sparse feature
// is partitioned cyclically along the sparse dimension into my_size blocks).
template <
bool has_weight,
bool bucketize_pos,
typename index_t,
typename scalar_t>
__global__ void _bucketize_sparse_features_cuda_kernel2(
int lengths_size,
int my_size,
const index_t* __restrict__ offsets_data,
const index_t* __restrict__ indices_data,
const scalar_t* __restrict__ weights_data,
index_t* __restrict__ new_offsets_data,
index_t* __restrict__ new_indices_data,
scalar_t* __restrict__ new_weights_data,
index_t* __restrict__ new_pos_data) {
using uindex_t = std::make_unsigned_t<index_t>;
CUDA_KERNEL_LOOP(r, lengths_size) {
index_t rowstart = r == 0 ? 0 : offsets_data[r - 1];
index_t rowend = offsets_data[r];
for (index_t i = rowstart; i < rowend; ++i) {
// Need to handle negative indices if we use raw indices instead of hashed
// indices; convert to unsigned.
uindex_t idx = static_cast<uindex_t>(indices_data[i]);
uindex_t p = idx % my_size;
uindex_t new_idx = idx / my_size;
uindex_t pos = new_offsets_data[p * lengths_size + r];
new_indices_data[pos] = new_idx;
new_offsets_data[p * lengths_size + r]++;
if (has_weight) {
new_weights_data[pos] = weights_data[i];
}
if (bucketize_pos) {
new_pos_data[pos] = i - rowstart;
}
}
}
}
// This function partitions sparse features
// cyclically along the sparse dimension into my_size blocks
std::tuple<Tensor, Tensor, c10::optional<Tensor>, c10::optional<Tensor>>
bucketize_sparse_features_cuda(
const Tensor& lengths,
const Tensor& indices,
const bool bucketize_pos,
const int64_t my_size,
const c10::optional<Tensor>& weights) {
TENSOR_ON_CUDA_GPU(lengths);
TENSOR_ON_CUDA_GPU(indices);
TENSORS_ON_SAME_DEVICE(lengths, indices);
TENSOR_ON_CUDA_GPU(weights);
TENSORS_ON_SAME_DEVICE(lengths, weights);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(lengths.get_device());
// allocate tensors and buffers
const int lengths_size = lengths.numel();
const int new_lengths_size = lengths_size * my_size;
auto offsets = at::empty({lengths_size}, lengths.options());
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
auto new_offsets = at::empty({new_lengths_size}, lengths.options());
auto new_indices = at::empty_like(indices);
auto lengths_contig = lengths.contiguous();
auto indices_contig = indices.contiguous();
auto offsets_contig = offsets.contiguous();
Tensor new_weights;
Tensor new_pos;
// count nonzeros
offsets_contig = fbgemm_gpu::asynchronous_inclusive_cumsum_gpu(lengths);
int threads_per_block = 256;
const auto num_blocks =
cuda_calc_xblock_count(lengths_size, threads_per_block);
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_bucketize_sparse_features_cuda_kernel1",
([&] {
_bucketize_sparse_features_cuda_kernel1<<<
num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
my_size,
offsets_contig.data_ptr<index_t>(),
indices_contig.data_ptr<index_t>(),
new_lengths.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
// bucketize nonzeros
new_offsets = fbgemm_gpu::asynchronous_exclusive_cumsum_gpu(new_lengths);
if (weights.has_value() && bucketize_pos) {
Tensor weights_value = weights.value();
auto weights_value_contig = weights_value.contiguous();
new_weights = at::empty_like(weights_value);
new_pos = at::empty_like(indices);
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
([&] {
AT_DISPATCH_FLOATING_TYPES(
weights_value.scalar_type(),
"_bucketize_sparse_features_cuda_weight_kernel2_2",
([&] {
_bucketize_sparse_features_cuda_kernel2<
true,
true,
index_t,
scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
my_size,
offsets_contig.data_ptr<index_t>(),
indices_contig.data_ptr<index_t>(),
weights_value_contig.data_ptr<scalar_t>(),
new_offsets.data_ptr<index_t>(),
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
new_pos.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
}));
} else if (weights.has_value()) {
Tensor weights_value = weights.value();
auto weights_value_contig = weights_value.contiguous();
new_weights = at::empty_like(weights_value);
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_bucketize_sparse_features_weight_cuda_kernel2_1",
([&] {
AT_DISPATCH_FLOATING_TYPES(
weights_value.scalar_type(),
"_bucketize_sparse_features_cuda_weight_kernel2_2",
([&] {
_bucketize_sparse_features_cuda_kernel2<
true,
false,
index_t,
scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
my_size,
offsets_contig.data_ptr<index_t>(),
indices_contig.data_ptr<index_t>(),
weights_value_contig.data_ptr<scalar_t>(),
new_offsets.data_ptr<index_t>(),
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
}));
} else if (bucketize_pos) {
new_pos = at::empty_like(indices);
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_bucketize_sparse_features_cuda_kernel2",
([&] {
_bucketize_sparse_features_cuda_kernel2<
false,
true,
index_t,
std::nullptr_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
my_size,
offsets_contig.data_ptr<index_t>(),
indices_contig.data_ptr<index_t>(),
nullptr,
new_offsets.data_ptr<index_t>(),
new_indices.data_ptr<index_t>(),
nullptr,
new_pos.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
} else {
AT_DISPATCH_INDEX_TYPES(
indices_contig.scalar_type(),
"_bucketize_sparse_features_cuda_kernel2",
([&] {
_bucketize_sparse_features_cuda_kernel2<
false,
false,
index_t,
std::nullptr_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_size,
my_size,
offsets_contig.data_ptr<index_t>(),
indices_contig.data_ptr<index_t>(),
nullptr,
new_offsets.data_ptr<index_t>(),
new_indices.data_ptr<index_t>(),
nullptr,
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
}
return {new_lengths, new_indices, new_weights, new_pos};
}
template <typename Dtype>
__global__ void reorder_batched_ad_lengths_kernel(
// reorder lengths from (ragged) [B x T x #num_ads_b] to
// [T][B][#num_ads_b], i.e. [T][sum(#num_ads_b)].
const at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
cat_ad_lengths,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
batch_offsets,
at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
reordered_cat_ad_lengths,
int32_t T) {
const int32_t B = batch_offsets.size(0) - 1;
const int32_t num_ads_in_batch = batch_offsets[B];
// warp-per-segment.
const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
const int32_t b = b_t % B;
const int32_t t = b_t / B;
if (t >= T) {
return;
}
const int32_t num_ads_b = batch_offsets[b + 1] - batch_offsets[b];
const int32_t input_segment_start = T * batch_offsets[b] + t * num_ads_b;
const int32_t output_segment_start = t * num_ads_in_batch + batch_offsets[b];
for (int32_t i = threadIdx.x; i < num_ads_b; i += blockDim.x) {
reordered_cat_ad_lengths[output_segment_start + i] =
cat_ad_lengths[input_segment_start + i];
}
}
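// Reorders the concatenated ad lengths from the batch-major (ragged)
// [B][T][#num_ads_b] layout to the table-major [T][B][#num_ads_b] layout;
// one warp handles one (b, t) segment.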
Tensor reorder_batched_ad_lengths_gpu(
const Tensor& cat_ad_lengths,
const Tensor& batch_offsets,
const int64_t num_ads_in_batch) {
TENSOR_ON_CUDA_GPU(cat_ad_lengths);
TENSOR_ON_CUDA_GPU(batch_offsets);
TENSORS_ON_SAME_DEVICE(cat_ad_lengths, batch_offsets);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(cat_ad_lengths.get_device());
const int64_t B = batch_offsets.numel() - 1;
const int64_t T = cat_ad_lengths.numel() / num_ads_in_batch;
Tensor reordered_cat_ad_lengths = at::empty_like(cat_ad_lengths);
const dim3 threads(32, 32);
const dim3 blocks((B * T + 32 - 1) / 32);
AT_DISPATCH_ALL_TYPES(
cat_ad_lengths.scalar_type(), "reorder_batched_ad_lengths_gpu_kernel", [&] {
reorder_batched_ad_lengths_kernel<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
cat_ad_lengths
.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>(),
batch_offsets
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
reordered_cat_ad_lengths
.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>(),
T);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return reordered_cat_ad_lengths;
}
template <typename Dtype, typename index_t = int32_t>
__global__ void reorder_batched_ad_indices_kernel(
// reorder indices from (ragged) [B x T x #num_ads_b x length_{b, t, a}]
// to [T][B][#num_ads_b][length_{b, t, a}], i.e. [sum(length_{b, t, a})],
// laid out as [T][B][A][L] (if all lengths were equal).
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
cat_ad_offsets,
const at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
cat_ad_indices,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
reordered_cat_ad_offsets,
at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
reordered_cat_ad_indices,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
batch_offsets,
int32_t T) {
const int32_t B = batch_offsets.size(0) - 1;
const int32_t num_ads_in_batch = batch_offsets[B];
// warp-per-segment.
const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
const int32_t b = b_t % B;
const int32_t t = b_t / B;
if (t >= T) {
return;
}
// for each ad,
const int32_t num_ads_b = batch_offsets[b + 1] - batch_offsets[b];
const int32_t b_t_start = T * batch_offsets[b] + t * num_ads_b;
const int32_t input_segment_offset_start =
T * batch_offsets[b] + t * num_ads_b;
const int32_t input_segment_offset_end =
T * batch_offsets[b] + t * num_ads_b + num_ads_b;
// Idea: we want to copy the entire segment of size sum_a(length_{b, t, a})
// from the input starting point (given by cat_ad_offsets[b, t])
// to the output starting point (given by reordered_cat_ad_offsets[t][b])
const int32_t input_segment_start =
cat_ad_offsets[input_segment_offset_start];
const int32_t input_segment_end = cat_ad_offsets[input_segment_offset_end];
const int32_t output_segment_offset_start =
t * num_ads_in_batch + batch_offsets[b];
const int32_t output_segment_start =
reordered_cat_ad_offsets[output_segment_offset_start];
for (int32_t i = threadIdx.x; i < input_segment_end - input_segment_start;
i += blockDim.x) {
reordered_cat_ad_indices[output_segment_start + i] =
cat_ad_indices[input_segment_start + i];
}
}
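// Reorders the concatenated ad indices into the table-major layout produced by
// reorder_batched_ad_lengths_gpu, using the input offsets (cat_ad_offsets) and
// the offsets of the reordered lengths (reordered_cat_ad_offsets) to locate
// each segment.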
Tensor reorder_batched_ad_indices_gpu(
const Tensor& cat_ad_offsets,
const Tensor& cat_ad_indices,
const Tensor& reordered_cat_ad_offsets,
const Tensor& batch_offsets,
const int64_t num_ads_in_batch) {
TENSOR_ON_CUDA_GPU(cat_ad_offsets);
TENSOR_ON_CUDA_GPU(cat_ad_indices);
TENSOR_ON_CUDA_GPU(reordered_cat_ad_offsets);
TENSOR_ON_CUDA_GPU(batch_offsets);
TENSORS_ON_SAME_DEVICE(cat_ad_offsets, cat_ad_indices);
TENSORS_ON_SAME_DEVICE(cat_ad_offsets, reordered_cat_ad_offsets);
TENSORS_ON_SAME_DEVICE(cat_ad_offsets, batch_offsets);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(cat_ad_offsets.get_device());
const int64_t B = batch_offsets.numel() - 1;
const int64_t T = (cat_ad_offsets.numel() - 1) / num_ads_in_batch;
Tensor reordered_cat_ad_indices = at::empty_like(cat_ad_indices);
const dim3 threads(32, 32);
const dim3 blocks((B * T + 32 - 1) / 32);
AT_DISPATCH_ALL_TYPES(
cat_ad_indices.scalar_type(), "reorder_batched_ad_indices_gpu_kernel_1", [&] {
AT_DISPATCH_INDEX_TYPES(
cat_ad_offsets.scalar_type(),
"reorder_batched_ad_indices_gpu_kernel_2",
[&] {
reorder_batched_ad_indices_kernel<scalar_t, index_t><<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cat_ad_offsets
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
cat_ad_indices
.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>(),
reordered_cat_ad_offsets
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
reordered_cat_ad_indices
.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>(),
batch_offsets
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
T);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return reordered_cat_ad_indices;
}
// Forward kernel for batched unary embedding op
template <typename scalar_t, typename index_t>
__global__ void batched_unary_embeddings_forward_kernel(
const int32_t N,
const int32_t B,
const int32_t T,
const scalar_t* __restrict__ weight, // N * sum(E) * 1 (embedding dimension
// is 1)
const index_t* __restrict__ table_offsets,
const index_t* __restrict__ offsets,
const index_t* __restrict__ indices,
scalar_t* __restrict__ output // N * B * T
) {
index_t sum_E = table_offsets[T];
int32_t b = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B) {
return;
}
int32_t t = blockIdx.y;
int32_t n = blockIdx.z;
index_t table_offset = table_offsets[t];
index_t indices_start = offsets[t * B + b];
index_t indices_end = offsets[t * B + b + 1];
int32_t L = indices_end - indices_start;
at::acc_type<scalar_t, true> sum = 0.0;
for (int32_t l = 0; l < L; ++l) {
auto idx = __ldg(&indices[indices_start + l]);
sum += weight[n * sum_E + table_offset + idx + 0];
}
output[(n * B + b) * T + t] = sum;
}
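// Forward pass of the batched unary embedding op: for each task n, table t,
// and batch element b, sum-pools the 1-dimensional embedding rows selected by
// indices in the segment [offsets[t * B + b], offsets[t * B + b + 1]) into
// output[n][b][t]. weight is laid out as [N][sum_E][1].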
Tensor batched_unary_embeddings_forward_cuda(
const Tensor& weight,
const Tensor& table_offsets,
const Tensor& offsets,
const Tensor& indices) {
TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(table_offsets);
TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(weight);
TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(offsets);
TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(indices);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(weight.get_device());
// N: number of tasks, T: number of tables, B: batch size
const int32_t N = weight.size(0);
const int32_t T = table_offsets.numel() - 1;
const int32_t B = (offsets.numel() - 1) / T;
TORCH_CHECK(N > 0);
TORCH_CHECK(B > 0);
TORCH_CHECK(T > 0);
TORCH_CHECK(T <= 65535);
TORCH_CHECK(N <= 65535);
int32_t threads = std::min<int32_t>(B, 512);
dim3 blocks(cuda_calc_xblock_count(B, threads), T, N);
auto output = at::empty({N, B, T}, weight.options());
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "batched_unary_embeddings_forward_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
weight.scalar_type(), "batched_unary_embeddings_forward_kernel", [&] {
batched_unary_embeddings_forward_kernel<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
N,
B,
T,
weight.data_ptr<scalar_t>(),
table_offsets.data_ptr<index_t>(),
offsets.data_ptr<index_t>(),
indices.data_ptr<index_t>(),
output.data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return output;
}
// Backward kernel for batched unary embedding op
// We sort input indices so we don't have race conditions, an approach similar
// to the usual split table batched embedding backward.
// We can think of the following alternatives, but each has challenges:
// 1) Assign output elements to different threads. Each thread scans all
// indices corresponding to the table it owns but only accumulates gradients
// when an index value matches the output element it owns.
// A challenge is that each thread needs a binary search to map from
// [0 .. sum_E] to a table id.
// 2) Densify indices and offsets to create a [B, sum_E] matrix. Then do a
// batched GEMM where the ith GEMM multiplies the [N, B] submatrix of
// grad_output with the [B, E_i] submatrix, where E_i is the number of
// embeddings of the ith table. Concatenating the GEMM outputs results in
// [N, B, T].
// A challenge is that there is no batched GEMM routine available with a
// varying K dimension.
template <typename scalar_t, typename index_t>
__global__ void batched_unary_embeddings_backward_kernel(
const int32_t N,
const int32_t B,
const int32_t T,
const scalar_t* __restrict__ grad_output, // [N * B * T]
const index_t* __restrict__ table_offsets,
scalar_t* __restrict__ grad_weight, // [N * sum_E * 1] (embedding
// dimension is 1)
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
sorted_linear_indices_run,
const int32_t* __restrict__ sorted_linear_indices_cumulative_run_lengths,
const int32_t* __restrict__ sorted_infos,
const int32_t* __restrict__ sorted_linear_indices_num_runs,
FixedDivisor fd) {
int32_t run_id = blockIdx.x * blockDim.x + threadIdx.x;
int32_t n = blockIdx.y;
if (n >= N) {
return;
}
if (run_id >= sorted_linear_indices_run.size(0)) {
return;
}
if (run_id >= sorted_linear_indices_num_runs[0]) {
return;
}
int64_t linear_index = sorted_linear_indices_run[run_id];
int32_t segment_start = sorted_linear_indices_cumulative_run_lengths[run_id];
int32_t segment_end =
sorted_linear_indices_cumulative_run_lengths[run_id + 1];
int32_t SL = segment_end - segment_start;
if (SL == 0) {
return;
}
// now, each segment corresponds to exactly one table `t` and row in
// that table (`idx`). Thus, we can hoist out some of the book-keeping.
auto info = sorted_infos[segment_start];
int t = fd.Div(info);
at::acc_type<scalar_t, true> grad_sum = 0.0;
for (int32_t sl = 0; sl < SL; ++sl) {
int32_t b = fd.Mod(sorted_infos[segment_start + sl]);
grad_sum += grad_output[(n * B + b) * T + t];
}
index_t table_offset = table_offsets[t];
index_t sum_E = table_offsets[T];
int64_t idx = linear_index - table_offset;
grad_weight[n * sum_E + table_offset + idx] = grad_sum;
}
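// Backward pass of the batched unary embedding op: sorts the linearized
// indices via transpose_embedding_input() so that each run of identical
// indices is reduced by a single thread (one thread per (run_id, task) pair),
// writing the summed gradient into grad_weight without atomics.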
Tensor batched_unary_embeddings_backward_cuda(
const Tensor& grad_output,
const Tensor& weight,
const Tensor& table_offsets,
const Tensor& offsets,
const Tensor& indices) {
TENSOR_ON_CUDA_GPU(grad_output);
TENSOR_ON_CUDA_GPU(weight);
TENSOR_ON_CUDA_GPU(table_offsets);
TENSOR_ON_CUDA_GPU(offsets);
TENSOR_ON_CUDA_GPU(indices);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());
// N: number of tasks, T: number of tables, B: batch size
const int32_t N = grad_output.size(0);
const int32_t B = grad_output.size(1);
const int32_t T = grad_output.size(2);
TORCH_CHECK(N > 0);
TORCH_CHECK(B > 0);
TORCH_CHECK(T > 0);
// weight: [N, sum_E]
// total_hash_size_bits = log2(sum_E)
int64_t total_hash_size_bits = log2(weight.numel() / N) + 1;
Tensor linear_indices, linear_indices_sorted;
Tensor infos_sorted;
Tensor sorted_linear_indices_run, sorted_linear_indices_run_lengths,
sorted_linear_indices_num_runs,
sorted_linear_indices_cumulative_run_lengths;
std::tie(
linear_indices,
linear_indices_sorted,
infos_sorted,
sorted_linear_indices_run,
sorted_linear_indices_run_lengths,
sorted_linear_indices_num_runs,
sorted_linear_indices_cumulative_run_lengths) =
transpose_embedding_input(
table_offsets, total_hash_size_bits, indices, offsets);
int threads = std::min<int32_t>(sorted_linear_indices_run.numel(), 512);
dim3 blocks(
cuda_calc_xblock_count(sorted_linear_indices_run.numel(), threads), N);
auto grad_weight = at::zeros_like(weight);
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "batched_unary_embeddings_backward_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(),
"batched_unary_embeddings_backward_kernel",
[&] {
batched_unary_embeddings_backward_kernel<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
N,
B,
T,
grad_output.data_ptr<scalar_t>(),
table_offsets.data_ptr<index_t>(),
grad_weight.data_ptr<scalar_t>(),
sorted_linear_indices_run.packed_accessor32<
index_t,
1,
at::RestrictPtrTraits>(),
sorted_linear_indices_cumulative_run_lengths
.data_ptr<int32_t>(),
infos_sorted.data_ptr<int32_t>(),
sorted_linear_indices_num_runs.data_ptr<int32_t>(),
FixedDivisor(B));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return grad_weight;
}
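// Expands a 1-D lengths tensor into concatenated [0, length) ranges. If the
// optional output shape is provided, the output size is computed on the host;
// otherwise the total size is read back from the GPU (slow path).
// Illustrative example: t_in = [2, 0, 3] -> [0, 1, 0, 1, 2]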
Tensor lengths_range_cuda(
const Tensor& t_in,
const c10::optional<std::vector<int64_t>>& shape) {
TENSOR_ON_CUDA_GPU(t_in);
TENSOR_NDIM_EQUALS(t_in, 1);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(t_in.get_device());
const auto t_in_contig = t_in.contiguous();
const auto num_seq = t_in_contig.numel();
Tensor offsets;
int64_t output_size = 1;
if (shape.has_value()) {
offsets = fbgemm_gpu::asynchronous_exclusive_cumsum_gpu(t_in_contig);
output_size = c10::multiply_integers(shape.value());
} else {
// if we don't provide the shape info, this is a slow path;
// we need to transfer the size of the output from GPU to CPU
offsets = fbgemm_gpu::asynchronous_complete_cumsum_gpu(t_in_contig);
AT_DISPATCH_INDEX_TYPES(
t_in_contig.scalar_type(), "lengths_range_output_size", [&] {
output_size = *(offsets[num_seq].cpu().data_ptr<index_t>());
});
}
auto output = at::empty({output_size}, t_in.options());
uint32_t vector_size;
uint32_t rows_per_block;
uint32_t num_blocks;
std::tie(num_blocks, rows_per_block, vector_size) =
calc_offsets_range_thread_block(output_size, num_seq);
dim3 threads(vector_size, rows_per_block);
AT_DISPATCH_INDEX_TYPES(
t_in_contig.scalar_type(), "lengths_range_compute", [&] {
fbgemm_gpu::_offsets_range_cuda_kernel<index_t>
<<<num_blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
num_seq,
output_size,
offsets.data_ptr<index_t>(),
output.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return output;
}
// Kernel for permuting the indices and weights. Used for permutation of
// sparse features
template <bool has_weight, typename index_t, typename scalar_t>
__global__ void permute_indices_weights_kernel(
int32_t len,
int32_t T,
int32_t B,
const index_t* __restrict__ indices,
const scalar_t* __restrict__ weights,
const int32_t* __restrict__ permute,
const index_t* __restrict__ input_offsets,
const index_t* __restrict__ output_offsets,
index_t* __restrict__ permuted_indices,
scalar_t* __restrict__ permuted_weights) {
int32_t b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
const int stride = gridDim.x * blockDim.y;
for (int b_t = b_t_start; b_t < B * T; b_t += stride) {
int32_t b = b_t % B;
int32_t t = b_t / B;
index_t output_start = output_offsets[b_t];
index_t segment_length;
if (b_t == B * T - 1) {
segment_length = len - output_offsets[b_t];
} else {
segment_length = output_offsets[b_t + 1] - output_offsets[b_t];
}
index_t input_start = input_offsets[permute[t] * B + b];
for (int32_t i = threadIdx.x; i < segment_length; i += blockDim.x) {
permuted_indices[output_start + i] = indices[input_start + i];
if (has_weight) {
permuted_weights[output_start + i] = weights[input_start + i];
}
}
}
}
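// Permutes sparse features (lengths, indices, and optional weights) along the
// feature dimension, similarly to permute_2D_sparse_data_cuda, but requires
// lengths and indices to share a dtype and dispatches weights over
// float/double/int.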
std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_sparse_features_cuda(
const Tensor& permute,
const Tensor& lengths,
const Tensor& indices,
const c10::optional<Tensor>& weights) {
TENSOR_ON_CUDA_GPU(permute);
TENSOR_ON_CUDA_GPU(lengths);
TENSOR_ON_CUDA_GPU(indices);
TENSOR_ON_CUDA_GPU(weights);
TENSORS_ON_SAME_DEVICE(permute, lengths);
TENSORS_ON_SAME_DEVICE(permute, indices);
TENSORS_ON_SAME_DEVICE(permute, weights);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(indices.get_device());
// the following implementation requires lengths and indices to have the same
// dtype; if a use case comes up that requires different dtypes (e.g. int32 for
// lengths and int64 for indices), this will give a better error msg for
// debugging
TENSORS_HAVE_SAME_TYPE(lengths, indices);
TORCH_CHECK(
lengths.dim() == 2,
"The dimension of lengths tensor should be equal to 2 to correctly infer number of features and batch size.")
const auto permute_contig = permute.contiguous();
const auto lengths_contig = lengths.contiguous();
const auto indices_contig = indices.contiguous();
// the features to permute over can be fewer or more than the input features,
// with or without repetitions
const auto num_output_features = permute.numel();
const auto num_features = lengths.size(0);
const auto B = lengths.size(1);
Tensor permuted_lengths;
Tensor permuted_indices;
Tensor permuted_weights;
permuted_lengths = at::empty({num_output_features, B}, lengths.options());
constexpr int32_t threads_1 = 256;
const auto blocks_1 =
cuda_calc_xblock_count(B * num_output_features, threads_1);
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_2D_lengths_kernel", [&] {
fbgemm_gpu::permute_2D_lengths_kernel<index_t>
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
num_output_features,
B,
lengths_contig.data_ptr<index_t>(),
permute.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
// convert lengths to offsets
const auto input_offsets =
fbgemm_gpu::asynchronous_exclusive_cumsum_gpu(lengths_contig);
const auto output_offsets =
fbgemm_gpu::asynchronous_exclusive_cumsum_gpu(permuted_lengths);
int64_t permuted_lengths_sum = indices.numel();
/* TODO: Remove the condition protecting the slow path because even when the
 * condition below is true permuted_lengths.sum() could still be needed. For
 * instance, if there are three features with indices `[0, 1, 2]`, `permute`
 * can be `[0, 1, 1]`, for which the permuted lengths sum would be needed to
 * create permuted_{indices, weights}, and `permuted_lengths_sum =
 * indices.numel()` or `weights.numel()` would be incorrect.
 */
if (num_features != num_output_features) {
permuted_lengths_sum = permuted_lengths.sum().item<int64_t>();
}
constexpr int32_t BT_blocks = 32;
dim3 threads_2(32, BT_blocks);
const auto blocks_2 =
cuda_calc_xblock_count(B * num_output_features, BT_blocks);
permuted_indices = at::empty(permuted_lengths_sum, indices.options());
if (weights.has_value()) {
const Tensor weights_value = weights.value();
const auto weights_value_contig = weights_value.contiguous();
permuted_weights = at::empty(permuted_lengths_sum, weights_value.options());
AT_DISPATCH_INDEX_TYPES(
input_offsets.scalar_type(), "permute_indices_weights_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Int,
weights_value.scalar_type(),
"permute_indices_weights_kernel_2",
[&] {
permute_indices_weights_kernel<true, index_t, scalar_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_lengths_sum,
num_output_features,
B,
indices_contig.data_ptr<index_t>(),
weights_value_contig.data_ptr<scalar_t>(),
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<index_t>(),
output_offsets.data_ptr<index_t>(),
permuted_indices.data_ptr<index_t>(),
permuted_weights.data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
} else {
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "permute_indices_kernel", [&] {
permute_indices_weights_kernel<false, index_t, std::nullptr_t>
<<<blocks_2, threads_2, 0, at::cuda::getCurrentCUDAStream()>>>(
permuted_lengths_sum,
num_output_features,
B,
indices_contig.data_ptr<index_t>(),
nullptr,
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<index_t>(),
output_offsets.data_ptr<index_t>(),
permuted_indices.data_ptr<index_t>(),
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
return {permuted_lengths, permuted_indices, permuted_weights};
}
} // namespace fbgemm_gpu