// fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <ATen/ATen.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <type_traits>
inline bool torch_tensor_on_cpu_check(const c10::optional<at::Tensor>& ten) {
return !ten.has_value() ||
      !ten->is_cuda(); // TODO: There should be a better way to do this
}
inline std::string torch_tensor_device_name(const at::Tensor& ten) {
return c10::DeviceTypeName(ten.device().type());
}
inline std::string torch_tensor_device_name(
const c10::optional<at::Tensor>& ten) {
if (ten.has_value()) {
return c10::DeviceTypeName(ten->device().type());
} else {
return "No device: optional tensor unused.";
}
}
inline bool torch_tensor_on_same_device_check(
const at::Tensor& ten1,
const at::Tensor& ten2) {
return ten1.get_device() == ten2.get_device();
}
inline bool torch_tensor_on_same_device_check(
const at::Tensor& ten1,
const c10::optional<at::Tensor>& ten2) {
return !ten2.has_value() || ten1.get_device() == ten2->get_device();
}
inline bool torch_tensor_on_cuda_gpu_check(const at::Tensor& ten) {
return ten.is_cuda();
}
inline bool torch_tensor_on_cuda_gpu_check(
const c10::optional<at::Tensor>& ten) {
return !ten.has_value() || ten->is_cuda();
}
inline bool torch_tensor_empty_or_on_cuda_gpu_check(const at::Tensor& ten) {
return (ten.numel() == 0) || ten.is_cuda();
}
inline bool torch_tensor_empty_or_on_cuda_gpu_check(
const c10::optional<at::Tensor>& ten) {
return !ten.has_value() || (ten->numel() == 0) || ten->is_cuda();
}
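// Note the convention above: an absent c10::optional passes every check, so
// callers can validate optional arguments unconditionally. A minimal sketch:
//
//   c10::optional<at::Tensor> absent;                       // no value
//   bool ok = torch_tensor_on_cuda_gpu_check(absent);       // true: nothing to check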
#define DISPATCH_TO_CUDA(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))
#define DISPATCH_TO_CPU(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(function)))
#define DISPATCH_TO_ALL(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::CatchAll, TORCH_FN(function)))
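// Illustrative registration (a sketch, not part of this header): the macros
// above assume an in-scope registration object `m`, i.e. they are meant to be
// expanded inside a torch library registration block. The operator name and
// the function `example_op_cuda` below are hypothetical.
//
//   TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
//     DISPATCH_TO_CUDA("example_op", example_op_cuda);
//   }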
#define TENSOR_ON_CPU(x) \
TORCH_CHECK( \
torch_tensor_on_cpu_check(x), \
#x " must be a CPU tensor; it is currently on device ", \
torch_tensor_device_name(x))
#define TENSOR_ON_CUDA_GPU(x) \
TORCH_CHECK( \
torch_tensor_on_cuda_gpu_check(x), \
#x " must be a CUDA tensor; it is currently on device ", \
torch_tensor_device_name(x))
#define TENSOR_EMPTY_OR_ON_CUDA_GPU(x) \
TORCH_CHECK( \
torch_tensor_empty_or_on_cuda_gpu_check(x), \
#x " must be empty or a CUDA tensor; it is currently on device ", \
torch_tensor_device_name(x))
#define TENSORS_ON_SAME_DEVICE(x, y) \
TORCH_CHECK( \
torch_tensor_on_same_device_check(x, y), \
#x " must be on the same device as " #y "! " #x " is currently on ", \
torch_tensor_device_name(x), \
#y " is currently on ", \
torch_tensor_device_name(y))
#define TENSORS_HAVE_SAME_TYPE(x, y) \
TORCH_CHECK( \
(x).dtype() == (y).dtype(), \
#x " must have the same type as " #y " types were ", \
(x).dtype().name(), \
" and ", \
(y).dtype().name())
#define TENSOR_NDIM_EQUALS(ten, dims) \
TORCH_CHECK( \
(ten).ndimension() == (dims), \
"Tensor '" #ten "' must have " #dims \
" dimension(s). " \
"Found ", \
(ten).ndimension())
#define TENSOR_CONTIGUOUS(x) \
TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(x) \
TENSOR_ON_CUDA_GPU(x); \
TENSOR_CONTIGUOUS(x)
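// Illustrative usage (a sketch): a hypothetical operator validating its
// inputs with the macros above before doing any work. `example_op_cuda` is
// not a real function in this library.
//
//   at::Tensor example_op_cuda(
//       const at::Tensor& values,
//       const c10::optional<at::Tensor>& weights) {
//     TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(values);
//     TENSOR_ON_CUDA_GPU(weights); // passes when `weights` is absent
//     TENSORS_ON_SAME_DEVICE(values, weights);
//     TENSOR_NDIM_EQUALS(values, 2);
//     // ... launch kernels ...
//     return values;
//   }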
/// Determine an appropriate CUDA block count along the x axis
///
/// When launching CUDA kernels, the number of blocks B is often calculated
/// w.r.t. the number of threads T and the number N of items to be processed:
/// B = (N + T - 1) / T, i.e., integer division rounding up.
/// This function abstracts that calculation, performs it in an
/// overflow-safe manner, and limits the return value appropriately.
///
/// This is a general function for all integral data types.
/// The goal of this family of functions is to ensure correct calculations
/// across a variety of data types without forcing the programmer to
/// cast to an appropriate type (which is dangerous because we don't
/// have conversion warnings enabled). The values of the variables
/// can then be checked for correctness at run time.
/// Specialized overloads below handle the various combinations of signed
/// and unsigned inputs. This arrangement prevents "pointless comparison
/// against zero" warnings from the compiler for unsigned types
/// (simpler ways of suppressing this warning didn't work) while
/// preserving the run-time checks for signed types.
template <
typename Integer1,
typename Integer2,
std::enable_if_t<std::is_integral<Integer1>::value, bool> = true,
std::enable_if_t<std::is_integral<Integer2>::value, bool> = true>
constexpr uint32_t cuda_calc_xblock_count_base(
Integer1 num_items,
Integer2 threads_per_block) {
  // The CUDA limit on the number of threads per block is 1024 for all
  // current compute capabilities (2048 is the per-SM resident-thread limit,
  // not a per-block limit).
  TORCH_CHECK(threads_per_block <= 1024, "Number of threads must be <=1024!");
  // Guard the division below: a block size of zero is never valid.
  TORCH_CHECK(
      threads_per_block > 0, "Number of threads per block must be positive!");
// The CUDA specification at
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
  // states that for compute capability 3.5+ the x-dimension of a kernel
  // launch grid must be <= 2^31 - 1.
constexpr uint64_t max_blocks = 2147483647;
const auto u_num_items = static_cast<uint64_t>(num_items);
const auto u_threads = static_cast<uint64_t>(threads_per_block);
// Overflow safe variant of (a + b - 1) / b
const uint64_t blocks =
u_num_items / u_threads + (u_num_items % u_threads != 0);
return static_cast<uint32_t>(std::min(blocks, max_blocks));
}
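// Why the split form above instead of the naive ceil-division
// `(num_items + threads - 1) / threads`: the naive form can wrap around when
// `num_items` is near the maximum of its type. A minimal illustration (the
// values are assumptions for the sketch):
//
//   constexpr uint64_t a = UINT64_MAX - 2; // near the type's maximum
//   constexpr uint64_t b = 1024;
//   // (a + b - 1) wraps to 1020, so the naive form yields 0 blocks;
//   // a / b + (a % b != 0) still yields the true ceiling.
//   static_assert(a / b + (a % b != 0) == a / b + 1, "");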
// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_signed<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_unsigned<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
TORCH_CHECK(
num_items >= 0,
"When calculating block counts, the number of items must be positive!");
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}
// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_unsigned<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_signed<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
TORCH_CHECK(
      threads_per_block >= 0,
      "When calculating block counts, the number of threads per block must be non-negative!");
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}
// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_signed<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_signed<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
TORCH_CHECK(
num_items >= 0,
"When calculating block counts, the number of items must be positive!");
TORCH_CHECK(
threads_per_block >= 0,
"When calculating thread counts, the number of threads must be positive!");
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}
// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_unsigned<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_unsigned<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}
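// Overload resolution sketch (illustrative values): each signed argument gets
// a run-time sign check; unsigned arguments skip it, so the compiler never
// sees a tautological `unsigned >= 0` comparison.
//
//   const int64_t num_items = 1000000;  // signed: checked for >= 0
//   const uint32_t threads = 256u;      // unsigned: no sign check emitted
//   const uint32_t blocks = cuda_calc_xblock_count(num_items, threads);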
/// Determine an appropriate CUDA block count.
///
/// See cuda_calc_xblock_count_base() for details.
template <
typename Integer1,
typename Integer2,
std::enable_if_t<std::is_integral<Integer1>::value, bool> = true,
std::enable_if_t<std::is_integral<Integer2>::value, bool> = true>
constexpr uint32_t cuda_calc_block_count(
Integer1 num_items,
Integer2 threads_per_block) {
  // The CUDA specification at
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
  // states that the y- and z-dimensions of a kernel launch grid must be
  // <= 65535, while the x-dimension (compute capability 3.5+) may be as
  // large as 2^31 - 1. Because this function does not know which dimension
  // is being calculated, it uses the smaller limit.
constexpr uint32_t max_blocks = 65535;
return std::min(
cuda_calc_xblock_count(num_items, threads_per_block), max_blocks);
}
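// Illustrative launch pattern (a sketch; `example_kernel` is hypothetical,
// and at::cuda::getCurrentCUDAStream() requires <ATen/cuda/CUDAContext.h>):
//
//   constexpr uint32_t kThreads = 256;
//   const auto blocks = cuda_calc_xblock_count(num_items, kThreads);
//   example_kernel<<<blocks, kThreads, 0, at::cuda::getCurrentCUDAStream()>>>(
//       /* args */);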
// Used in jagged_tensor_ops.cu and jagged_tensor_ops_cpu.cpp
// The lambda below captures by value ([=]) instead of by reference to avoid
// an "internal compiler error: in maybe_undo_parenthesized_ref" error on
// specific compiler versions.
#define JAGGED_TENSOR_DISPATCH_DIMS() \
AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [=] { \
switch (num_jagged_dim) { \
case 1: \
INVOKE_KERNEL_WITH_DIM(1); \
break; \
case 2: \
INVOKE_KERNEL_WITH_DIM(2); \
break; \
case 3: \
INVOKE_KERNEL_WITH_DIM(3); \
break; \
case 4: \
INVOKE_KERNEL_WITH_DIM(4); \
break; \
case 5: \
INVOKE_KERNEL_WITH_DIM(5); \
break; \
default: \
TORCH_CHECK( \
false, "unsupported number of jagged dim ", num_jagged_dim); \
} \
});
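// Illustrative usage (a sketch): call sites define INVOKE_KERNEL_WITH_DIM
// before expanding the macro and #undef it afterwards. `index_t` is provided
// by AT_DISPATCH_INDEX_TYPES; `jagged_kernel`, `blocks`, and `threads` below
// are hypothetical.
//
//   #define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                      \
//     jagged_kernel<NUM_JAGGED_DIM, index_t>                            \
//         <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(   \
//             /* args */);
//   JAGGED_TENSOR_DISPATCH_DIMS();
//   #undef INVOKE_KERNEL_WITH_DIM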