fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh (29 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include <ATen/ATen.h> #include <ATen/AccumulateType.h> #include <ATen/TensorUtils.h> #include <ATen/core/TensorAccessor.h> #include <ATen/cuda/CUDAContext.h> #include <c10/cuda/CUDAGuard.h> #include <ATen/cuda/Atomic.cuh> // clang-format off #include "fbgemm_gpu/cub_namespace_prefix.cuh" #include <cub/device/device_radix_sort.cuh> #include <cub/device/device_run_length_encode.cuh> #include <cub/device/device_scan.cuh> #include "fbgemm_gpu/cub_namespace_postfix.cuh" // clang-format on #include <cuda.h> #include <cuda_runtime.h> #include <curand_kernel.h> #include <limits> #include <mutex> #include "fbgemm_gpu/dispatch_macros.h" #include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/fbgemm_cuda_utils.cuh" #include "fbgemm_gpu/sparse_ops_utils.h"