// maga_transformer/cpp/rocm/hip_utils.h

#pragma once #include "maga_transformer/cpp/utils/Logger.h" #include "maga_transformer/cpp/utils/AssertUtils.h" #include "maga_transformer/cpp/utils/StringUtil.h" #include <hip/hip_runtime.h> #include "cuda_shims.h" #include <hipblas/hipblas.h> #include <hipblaslt/hipblaslt.h> #include <hipblaslt/hipblaslt-ext.hpp> #include <fstream> #include <iostream> #include <string> #include <vector> namespace rtp_llm { namespace rocm { #define HIPBLAS_WORKSPACE_SIZE (512L*1024L*1024L) // C*splitK #define ROCM_RUNTIME_MEM_SIZE (HIPBLAS_WORKSPACE_SIZE + 512L*1024L*1024L) #define ROCM_CHECK(val) rocm::check((val), __FILE__, __LINE__) #define ROCM_SYNC_AND_CHECK() rocm::sync_and_check(__FILE__, __LINE__) #define ROCM_CHECK_VALUE(val, info, ...) \ do { \ bool is_valid_val = (val); \ if (!is_valid_val) { \ rocm::throwRocmError(__FILE__, __LINE__, rtp_llm::fmtstr(info, ##__VA_ARGS__)); \ } \ } while (0) #define ROCM_FAIL(info, ...) rocm::throwRocmError(__FILE__, __LINE__, rtp_llm::fmtstr(info, ##__VA_ARGS__)) void throwRocmError(const char* const file, int const line, std::string const& info = ""); template<typename T> void check(T result, const char* const file, int const line); void sync_and_check(const char* const file, int const line); enum FtHipDataType { FP32 = 0, FP16 = 1, BF16 = 2, INT8 = 3, FP8 = 4 }; enum class OperationType { FP32, FP16, BF16, INT8, FP8 }; inline int div_up(int a, int n) { return (a + n - 1) / n; } int get_sm(); int getDevice(); int getDeviceCount(); typedef struct __attribute__((aligned(4))) { half x, y, z, w; } half4; //////////////////////////////////////////////////////////////////////////////////////////////////// struct __attribute__((aligned(16))) Float4_ { float2 x; float2 y; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct __attribute__((aligned(32))) Float8_ { float2 x; float2 y; float2 z; float2 w; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// struct __attribute__((aligned(8))) bf16_4_t { __nv_bfloat162 x; __nv_bfloat162 y; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct __attribute__((aligned(16))) bf16_8_t { __nv_bfloat162 x; __nv_bfloat162 y; __nv_bfloat162 z; __nv_bfloat162 w; }; // clang-format off template<typename T> struct packed_type_2; template <> struct packed_type_2<float> { using type = float; }; // we don't need to pack float by default template <> struct packed_type_2<half> { using type = half2; }; template<> struct packed_type_2<__nv_bfloat16> { using type = __nv_bfloat162; }; template <typename T, int N> struct packed_type; template <typename T> struct packed_type<T, 1> { using type = T; }; template <> struct packed_type<int8_t, 1> { using type = int8_t; }; template <> struct packed_type<int8_t, 2> { using type = int16_t; }; template <> struct packed_type<int8_t, 4> { using type = int32_t; }; template <> struct packed_type<int8_t, 8> { using type = int64_t; }; #ifdef ENABLE_FP8 template <> struct packed_type<__nv_fp8_e4m3, 1> { using type = __nv_fp8_e4m3; }; template <> struct packed_type<__nv_fp8_e4m3, 2> { using type = fp8_2_t; }; template <> struct packed_type<__nv_fp8_e4m3, 4> { using type = fp8_4_t; }; template <> struct packed_type<__nv_fp8_e4m3, 8> { using type = fp8_8_t; }; #endif // ENABLE_FP8 template <> struct packed_type<uint16_t, 2> { using type = uint32_t; }; template <> struct packed_type<uint16_t, 4> { using type = uint2; }; template <> struct packed_type<uint16_t, 8> { using type = uint4; }; template <> struct packed_type<half, 2> { using type = uint32_t; }; template <> struct packed_type<half, 4> { using type = uint2; }; template <> struct packed_type<half, 8> { using type = uint4; }; #ifdef ENABLE_BF16 template <> struct packed_type<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; template <> struct 
packed_type<__nv_bfloat16, 4> { using type = bf16_4_t; }; template <> struct packed_type<__nv_bfloat16, 8> { using type = bf16_8_t; }; #endif template <> struct packed_type<float, 2> { using type = float2; }; template <> struct packed_type<float, 4> { using type = float4; }; template <> struct packed_type<float, 8> { using type = Float8_; }; template<typename T> struct num_elems; template <> struct num_elems<float> { static constexpr int value = 1; }; template <> struct num_elems<float2> { static constexpr int value = 2; }; template <> struct num_elems<float4> { static constexpr int value = 4; }; template <> struct num_elems<Float4_> { static constexpr int value = 4; }; template <> struct num_elems<Float8_> { static constexpr int value = 8; }; template <> struct num_elems<half> { static constexpr int value = 1; }; template <> struct num_elems<half2> { static constexpr int value = 2; }; template <> struct num_elems<uint32_t> { static constexpr int value = 2; }; template <> struct num_elems<int32_t> { static constexpr int value = 2; }; template <> struct num_elems<int64_t> { static constexpr int value = 4; }; template <> struct num_elems<uint2> { static constexpr int value = 4; }; template <> struct num_elems<uint4> { static constexpr int value = 8; }; #ifdef ENABLE_BF16 template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; }; template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; }; template <> struct num_elems<bf16_4_t> { static constexpr int value = 4; }; template <> struct num_elems<bf16_8_t> { static constexpr int value = 8; }; #endif #ifdef ENABLE_FP8 template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; }; template <> struct num_elems<fp8_2_t> { static constexpr int value = 2; }; template <> struct num_elems<fp8_4_t> { static constexpr int value = 4; }; template <> struct num_elems<fp8_8_t> { static constexpr int value = 8; }; #endif template<typename T, int num> struct packed_as; 
template<typename T> struct packed_as<T, 1> { using type = T; }; template<> struct packed_as<half, 2> { using type = half2; }; template<> struct packed_as<float, 2> { using type = float2; }; template<> struct packed_as<int8_t, 2> { using type = int16_t; }; template<> struct packed_as<int32_t, 2> { using type = int2; }; template<> struct packed_as<half2, 1> { using type = half; }; template<> struct packed_as<float2, 1> { using type = float; }; #ifdef ENABLE_BF16 template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; }; #endif inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } } // namespace rocm } // namespace rtp_llm