maga_transformer/cpp/kernels/_add.h (135 lines of code) (raw):

#pragma once //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float add(float a, float b) { return a + b; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 add(float2 a, float2 b) { float2 c; c.x = add(a.x, b.x); c.y = add(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float4 add(float4 a, float4 b) { float4 c; c.x = add(a.x, b.x); c.y = add(a.y, b.y); c.z = add(a.z, b.z); c.w = add(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_FP8 inline __device__ Float8_ add(Float8_ a, Float8_ b) { Float8_ c; c.x = add(a.x, b.x); c.y = add(a.y, b.y); c.z = add(a.z, b.z); c.w = add(a.w, b.w); return c; } #endif //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_BF16 inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { return a + b; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { return bf16hadd2(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { bf16_4_t c; c.x = add(a.x, b.x); c.y = add(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) { bf16_8_t c; c.x = add(a.x, b.x); c.y = add(a.y, b.y); c.z = add(a.z, b.z); c.w = add(a.w, b.w); return c; } #endif // ENABLE_BF16 //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint16_t add(uint16_t a, uint16_t b) { #if USING_ROCM __half_raw out = __hadd(*reinterpret_cast<__half_raw*>(&a), *reinterpret_cast<__half_raw*>(&b)); return *reinterpret_cast<uint16_t*>(&(out.data)); #else uint16_t c; asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint32_t add(uint32_t a, uint32_t b) { #if USING_ROCM __half2 out = __hadd2(*reinterpret_cast<__half2_raw*>(&a), *reinterpret_cast<__half2_raw*>(&b)); return *reinterpret_cast<uint32_t*>(&(out.data)); #else uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint2 add(uint2 a, uint2 b) { uint2 c; c.x = add(a.x, b.x); c.y = add(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ uint4 add(uint4 a, uint4 b) { uint4 c; c.x = add(a.x, b.x); c.y = add(a.y, b.y); c.z = add(a.z, b.z); c.w = add(a.w, b.w); return c; } inline __device__ float add(float a, uint16_t b) { return a + half_to_float(b); } //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_BF16 inline __device__ float add(float a, __nv_bfloat16 b) { return a + __bfloat162float(b); } #endif //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_FP8 inline __device__ float add(float a, __nv_fp8_e4m3 b) { return a + (float)(b); } #endif //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 add(uint32_t a, float2 fb) { float2 fa = half2_to_float2(a); return add(fa, fb); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ add(uint2 a, Float4_ fb) { Float4_ fc; fc.x = add(a.x, fb.x); fc.y = add(a.y, fb.y); return fc; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ add(uint4 a, Float8_ fb) { Float8_ fc; fc.x = add(a.x, fb.x); fc.y = add(a.y, fb.y); fc.z = add(a.z, fb.z); fc.w = add(a.w, fb.w); return fc; } //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_BF16 inline __device__ float2 add(__nv_bfloat162 a, float2 fb) { float2 fa = bf1622float2(a); return add(fa, fb); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) { Float4_ fc; fc.x = add(a.x, fb.x); fc.y = add(a.y, fb.y); return fc; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { Float8_ fc; fc.x = add(a.x, fb.x); fc.y = add(a.y, fb.y); fc.z = add(a.z, fb.z); fc.w = add(a.w, fb.w); return fc; } #endif // ENABLE_BF16