maga_transformer/cpp/rocm/hip_utils.h (233 lines of code) (raw):
#pragma once
#include "maga_transformer/cpp/utils/Logger.h"
#include "maga_transformer/cpp/utils/AssertUtils.h"
#include "maga_transformer/cpp/utils/StringUtil.h"
#include <hip/hip_runtime.h>
#include "cuda_shims.h"
#include <hipblas/hipblas.h>
#include <hipblaslt/hipblaslt.h>
#include <hipblaslt/hipblaslt-ext.hpp>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
namespace rtp_llm {
namespace rocm {
// Byte count of 512 MiB (512 * 1024 * 1024), evaluated as a long.
// NOTE(review): the trailing "C*splitK" remark presumably refers to the space needed for
// split-K GEMM partial results - confirm.
#define HIPBLAS_WORKSPACE_SIZE (512L*1024L*1024L) // C*splitK
// HIPBLAS_WORKSPACE_SIZE plus a further 512 MiB, i.e. 1 GiB in total.
// NOTE(review): presumably the amount of device memory set aside for runtime use - confirm against callers.
#define ROCM_RUNTIME_MEM_SIZE (HIPBLAS_WORKSPACE_SIZE + 512L*1024L*1024L)
// Error-checking entry points. Every helper is named with its full namespace path so the
// macros expand correctly at any call site, not only where `rocm` happens to resolve from
// inside namespace rtp_llm (this matches the already fully qualified rtp_llm::fmtstr).

// Passes a status value plus the call site's file/line to rtp_llm::rocm::check.
#define ROCM_CHECK(val) rtp_llm::rocm::check((val), __FILE__, __LINE__)
// Calls rtp_llm::rocm::sync_and_check with the call site's file/line.
#define ROCM_SYNC_AND_CHECK() rtp_llm::rocm::sync_and_check(__FILE__, __LINE__)
// Evaluates `val` exactly once as a bool; when it is false, formats `info` and the optional
// arguments with rtp_llm::fmtstr and hands the message to rtp_llm::rocm::throwRocmError.
// The do { } while (0) wrapper makes the expansion a single statement (safe in unbraced if/else).
#define ROCM_CHECK_VALUE(val, info, ...) \
    do { \
        bool is_valid_val = (val); \
        if (!is_valid_val) { \
            rtp_llm::rocm::throwRocmError(__FILE__, __LINE__, rtp_llm::fmtstr(info, ##__VA_ARGS__)); \
        } \
    } while (0)
// Unconditionally passes a formatted message and the call site's file/line to throwRocmError.
#define ROCM_FAIL(info, ...) rtp_llm::rocm::throwRocmError(__FILE__, __LINE__, rtp_llm::fmtstr(info, ##__VA_ARGS__))
// Called by ROCM_CHECK_VALUE and ROCM_FAIL with the call site's `file`/`line` and a formatted
// message `info` (empty by default). Only declared here.
// NOTE(review): presumably throws and never returns - confirm against the definition.
void throwRocmError(const char* const file, int const line, std::string const& info = "");
// Called by ROCM_CHECK with a status value `result` and the call site's `file`/`line`.
// Only declared here; the definition / instantiations live outside this header.
template<typename T>
void check(T result, const char* const file, int const line);
// Called by ROCM_SYNC_AND_CHECK with the call site's `file`/`line`. Only declared here.
// NOTE(review): presumably synchronizes the device and then checks for a pending error - confirm.
void sync_and_check(const char* const file, int const line);
// Data-type tags with explicitly fixed integer values (0..4).
// This is an unscoped enum, so the enumerators are visible directly in the enclosing
// namespace and convert implicitly to int.
enum FtHipDataType {
FP32 = 0,
FP16 = 1,
BF16 = 2,
INT8 = 3,
FP8 = 4
};
// Scoped enum of operation precisions. No explicit values are given, so the enumerators are
// numbered 0..4 in declaration order. Being scoped, they must be written OperationType::FP32, etc.
enum class OperationType {
FP32,
FP16,
BF16,
INT8,
FP8
};
// Divides `a` by `n`, rounding the quotient up when `a` is non-negative and `n` is positive
// (e.g. div_up(10, 3) == 4). Implemented by adding n - 1 before the truncating division.
inline int div_up(int a, int n) {
    const int padded = a + n - 1;
    return padded / n;
}
// Device query helpers. Only declared here; the definitions live outside this header.
// NOTE(review): get_sm presumably mirrors the CUDA-style "SM version" query for the current
// device - confirm what value it reports on ROCm.
int get_sm();
// NOTE(review): presumably returns the index of the currently selected HIP device - confirm.
int getDevice();
// NOTE(review): presumably returns the number of visible HIP devices - confirm.
int getDeviceCount();
// Four `half` values (x, y, z, w) in one struct with a declared alignment of 4 bytes.
typedef struct __attribute__((aligned(4))) {
half x, y, z, w;
} half4;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Two float2 members (four floats in total), aligned to 16 bytes.
struct __attribute__((aligned(16))) Float4_ {
float2 x;
float2 y;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Four float2 members (eight floats in total), aligned to 32 bytes.
struct __attribute__((aligned(32))) Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Two __nv_bfloat162 members (four bfloat16 values in total), aligned to 8 bytes.
// Note: this definition is not wrapped in ENABLE_BF16, unlike the traits that refer to it.
struct __attribute__((aligned(8))) bf16_4_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Four __nv_bfloat162 members (eight bfloat16 values in total), aligned to 16 bytes.
// Note: this definition is not wrapped in ENABLE_BF16, unlike the traits that refer to it.
struct __attribute__((aligned(16))) bf16_8_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
};
// clang-format off
// packed_type_2<T>::type is the type used when handling values of T two at a time.
// Only the specializations below exist; any other T is an incomplete type (compile error on use).
template<typename T> struct packed_type_2;
// float maps to itself rather than to float2.
template <> struct packed_type_2<float> { using type = float; }; // we don't need to pack float by default
// half maps to the two-lane half2.
template <> struct packed_type_2<half> { using type = half2; };
// __nv_bfloat16 maps to the two-lane __nv_bfloat162 (this specialization is not guarded by ENABLE_BF16).
template<>
struct packed_type_2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
// packed_type<T, N>::type names the type used to carry N values of T as a single object.
// The primary template is declared but never defined, so an unsupported (T, N) pair is a
// compile-time error at the point of use.
template <typename T, int N> struct packed_type;

// A single lane is carried as the element type itself.
template <typename T> struct packed_type<T, 1> { using type = T; };

// int8_t lanes are carried in a signed integer of the same total width.
template <> struct packed_type<int8_t, 1> { using type = int8_t; };   // 1 x 8 bits
template <> struct packed_type<int8_t, 2> { using type = int16_t; };  // 2 x 8 bits
template <> struct packed_type<int8_t, 4> { using type = int32_t; };  // 4 x 8 bits
template <> struct packed_type<int8_t, 8> { using type = int64_t; };  // 8 x 8 bits
// packed_type specializations for __nv_fp8_e4m3, compiled only when ENABLE_FP8 is defined.
// fp8_2_t / fp8_4_t / fp8_8_t are not declared in this header and must come from elsewhere.
#ifdef ENABLE_FP8
// One lane: the element type itself.
template <>
struct packed_type<__nv_fp8_e4m3, 1>
{
using type = __nv_fp8_e4m3;
};
// Two lanes.
template <>
struct packed_type<__nv_fp8_e4m3, 2>
{
using type = fp8_2_t;
};
// Four lanes.
template <>
struct packed_type<__nv_fp8_e4m3, 4>
{
using type = fp8_4_t;
};
// Eight lanes.
template <>
struct packed_type<__nv_fp8_e4m3, 8>
{
using type = fp8_8_t;
};
#endif // ENABLE_FP8
// packed_type specializations for uint16_t: the lanes are carried in unsigned 32-bit words.
// Two lanes: one 32-bit word.
template <>
struct packed_type<uint16_t, 2>
{
using type = uint32_t;
};
// Four lanes: uint2.
template <>
struct packed_type<uint16_t, 4>
{
using type = uint2;
};
// Eight lanes: uint4.
template <>
struct packed_type<uint16_t, 8>
{
using type = uint4;
};
// packed_type specializations for half. These map to the same unsigned-integer carriers as the
// uint16_t specializations (uint32_t / uint2 / uint4), not to half2.
// Two lanes: one 32-bit word.
template <>
struct packed_type<half, 2>
{
using type = uint32_t;
};
// Four lanes: uint2.
template <>
struct packed_type<half, 4>
{
using type = uint2;
};
// Eight lanes: uint4.
template <>
struct packed_type<half, 8>
{
using type = uint4;
};
// packed_type specializations for __nv_bfloat16, compiled only when ENABLE_BF16 is defined.
#ifdef ENABLE_BF16
// Two lanes: __nv_bfloat162.
template <>
struct packed_type<__nv_bfloat16, 2>
{
using type = __nv_bfloat162;
};
// Four lanes: the bf16_4_t struct defined in this header.
template <>
struct packed_type<__nv_bfloat16, 4>
{
using type = bf16_4_t;
};
// Eight lanes: the bf16_8_t struct defined in this header.
template <>
struct packed_type<__nv_bfloat16, 8>
{
using type = bf16_8_t;
};
#endif
// packed_type specializations for float.
// Two lanes: float2.
template <>
struct packed_type<float, 2>
{
using type = float2;
};
// Four lanes: float4 (not the Float4_ struct defined in this header).
template <>
struct packed_type<float, 4>
{
using type = float4;
};
// Eight lanes: the Float8_ struct defined in this header.
template <>
struct packed_type<float, 8>
{
using type = Float8_;
};
// num_elems<T>::value is the compile-time lane count associated with a (possibly packed) type T.
// Only the specializations below exist; any other T is an incomplete type (compile error on use).
template<typename T> struct num_elems;
// float-based types.
template <> struct num_elems<float> { static constexpr int value = 1; };
template <> struct num_elems<float2> { static constexpr int value = 2; };
template <> struct num_elems<float4> { static constexpr int value = 4; };
template <> struct num_elems<Float4_> { static constexpr int value = 4; };
template <> struct num_elems<Float8_> { static constexpr int value = 8; };
// half-based types.
template <> struct num_elems<half> { static constexpr int value = 1; };
template <> struct num_elems<half2> { static constexpr int value = 2; };
// Integer carriers. uint32_t / uint2 / uint4 report 2 / 4 / 8, which matches their use as the
// packed_type carriers for 16-bit elements (half, uint16_t) in this header.
// NOTE(review): int32_t and int64_t are the packed_type carriers for 4 and 8 int8_t lanes, yet they
// report 2 and 4 here (a 16-bit-lane count) - confirm which interpretation callers rely on.
template <> struct num_elems<uint32_t> { static constexpr int value = 2; };
template <> struct num_elems<int32_t> { static constexpr int value = 2; };
template <> struct num_elems<int64_t> { static constexpr int value = 4; };
template <> struct num_elems<uint2> { static constexpr int value = 4; };
template <> struct num_elems<uint4> { static constexpr int value = 8; };
// bfloat16-based types, compiled only when ENABLE_BF16 is defined.
#ifdef ENABLE_BF16
template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; };
template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; };
template <> struct num_elems<bf16_4_t> { static constexpr int value = 4; };
template <> struct num_elems<bf16_8_t> { static constexpr int value = 8; };
#endif
// fp8-based types, compiled only when ENABLE_FP8 is defined.
#ifdef ENABLE_FP8
template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; };
template <> struct num_elems<fp8_2_t> { static constexpr int value = 2; };
template <> struct num_elems<fp8_4_t> { static constexpr int value = 4; };
template <> struct num_elems<fp8_8_t> { static constexpr int value = 8; };
#endif
// packed_as<T, num>::type converts between scalar and two-lane forms of a type.
// Unsupported (T, num) pairs leave the primary template undefined (compile error on use).
template<typename T, int num> struct packed_as;
// Generic rule: with num == 1 the type is left unchanged ...
template<typename T> struct packed_as<T, 1> { using type = T; };
// Scalar -> two-lane type.
template<> struct packed_as<half, 2> { using type = half2; };
template<> struct packed_as<float, 2> { using type = float2; };
template<> struct packed_as<int8_t, 2> { using type = int16_t; };
template<> struct packed_as<int32_t, 2> { using type = int2; };
// ... except for two-lane inputs, where num == 1 yields the scalar element type. These full
// specializations take precedence over the generic <T, 1> partial specialization.
template<> struct packed_as<half2, 1> { using type = half; };
template<> struct packed_as<float2, 1> { using type = float; };
// bfloat16 counterparts, compiled only when ENABLE_BF16 is defined.
#ifdef ENABLE_BF16
template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; };
template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; };
#endif
// Component-wise product of two float2 values: (a.x * b.x, a.y * b.y). Device code only.
inline __device__ float2 operator*(float2 a, float2 b) {
    const float prod_x = a.x * b.x;
    const float prod_y = a.y * b.y;
    return make_float2(prod_x, prod_y);
}
// Scales both components of a float2 by the scalar b: (a.x * b, a.y * b). Device code only.
inline __device__ float2 operator*(float2 a, float b) {
    const float scaled_x = a.x * b;
    const float scaled_y = a.y * b;
    return make_float2(scaled_x, scaled_y);
}
} // namespace rocm
} // namespace rtp_llm