maga_transformer/cpp/kernels/_mul.h (562 lines of code) (raw):
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b) {
// This will error out when multiply operation is not supported.
return Acc(a * b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(float a, float2 b) {
float2 c;
c.x = a * b.x;
c.y = a * b.y;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(float4 a, Float4_ b) {
float4 c;
c = mul<float4, float4, float4>(a, reinterpret_cast<float4&>(b));
return reinterpret_cast<Float4_&>(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float a, float4 b) {
float4 c;
c.x = a * b.x;
c.y = a * b.y;
c.z = a * b.z;
c.w = a * b.w;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(float a, Float4_ b) {
float4 c = mul<float4, float, float4>(a, reinterpret_cast<float4&>(b));
return reinterpret_cast<Float4_&>(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(float a, Float8_ b) {
Float8_ c;
c.x = mul<float2, float, float2>(a, b.x);
c.y = mul<float2, float, float2>(a, b.y);
c.z = mul<float2, float, float2>(a, b.z);
c.w = mul<float2, float, float2>(a, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
#if USING_ROCM
__half_raw out = __hmul(*reinterpret_cast<__half_raw*>(&a), *reinterpret_cast<__half_raw*>(&b));
return *reinterpret_cast<uint16_t*>(&(out.data));
#else
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
#if USING_ROCM
__half2 out = __hmul2(*reinterpret_cast<__half2_raw*>(&a), *reinterpret_cast<__half2_raw*>(&b));
return *reinterpret_cast<uint32_t*>(&(out.data));
#else
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(uint16_t a, float b) {
return half_to_float(a) * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(uint32_t a, float2 fb) {
float2 fa = half2_to_float2(a);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(float2 fa, uint32_t b) {
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(Float8_ fa, uint4 b) {
Float8_ fc;
fc.x = mul<float2, float2, uint32_t>(fa.x, b.x);
fc.y = mul<float2, float2, uint32_t>(fa.y, b.y);
fc.z = mul<float2, float2, uint32_t>(fa.z, b.z);
fc.w = mul<float2, float2, uint32_t>(fa.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(Float8_ fa, Float8_ fb) {
Float8_ fc;
fc.x = mul<float2, float2, float2>(fa.x, fb.x);
fc.y = mul<float2, float2, float2>(fa.y, fb.y);
fc.z = mul<float2, float2, float2>(fa.z, fb.z);
fc.w = mul<float2, float2, float2>(fa.w, fb.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(uint4 a, Float8_ fb) {
Float8_ fc;
fc.x = mul<float2, uint32_t, float2>(a.x, fb.x);
fc.y = mul<float2, uint32_t, float2>(a.y, fb.y);
fc.z = mul<float2, uint32_t, float2>(a.z, fb.z);
fc.w = mul<float2, uint32_t, float2>(a.w, fb.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(float a, uint4 b) {
uint16_t h0 = float_to_half(a);
uint32_t s = h0_h0(h0);
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint4 mul(float a, uint4 b) {
uint16_t h = float_to_half(a);
uint4 c = mul<uint4, uint16_t, uint4>(h, b);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmul(a, b);
#else
return bf16hmul(a, b);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
return bf16hmul2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(float a, bf16_8_t b) {
__nv_bfloat162 a_ = float22bf162(make_float2(a, a));
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a_, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a_, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a_, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a_, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = (float)a;
float fb = (float)b;
return fa * fb;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, float b) {
return __bfloat162float(a) * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat162 a, float2 fb) {
float2 fa = bf1622float2(a);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(float2 fa, __nv_bfloat162 b) {
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(bf16_8_t a, Float8_ fb) {
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, float2>(a.x, fb.x);
fc.y = mul<float2, __nv_bfloat162, float2>(a.y, fb.y);
fc.z = mul<float2, __nv_bfloat162, float2>(a.z, fb.z);
fc.w = mul<float2, __nv_bfloat162, float2>(a.w, fb.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(Float8_ fa, bf16_8_t b) {
Float8_ fc;
fc.x = mul<float2, float2, __nv_bfloat162>(fa.x, b.x);
fc.y = mul<float2, float2, __nv_bfloat162>(fa.y, b.y);
fc.z = mul<float2, float2, __nv_bfloat162>(fa.z, b.z);
fc.w = mul<float2, float2, __nv_bfloat162>(fa.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return fc;
}
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(uint4 a, fp8_8_t b) {
Float8_ fc;
union {
fp8_8_t fp8_8;
fp8_2_t fp8_2[4];
};
fp8_8 = b;
fc.x = mul<float2, uint32_t, float2>(a.x, float2(fp8_2[0]));
fc.y = mul<float2, uint32_t, float2>(a.y, float2(fp8_2[1]));
fc.z = mul<float2, uint32_t, float2>(a.z, float2(fp8_2[2]));
fc.w = mul<float2, uint32_t, float2>(a.w, float2(fp8_2[3]));
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(Float8_ fa, fp8_8_t b) {
Float8_ fc;
union {
fp8_8_t fp8_8;
fp8_2_t fp8_2[4];
};
fp8_8 = b;
fc.x = mul<float2, float2, float2>(fa.x, float2(fp8_2[0]));
fc.y = mul<float2, float2, float2>(fa.y, float2(fp8_2[1]));
fc.z = mul<float2, float2, float2>(fa.z, float2(fp8_2[2]));
fc.w = mul<float2, float2, float2>(fa.w, float2(fp8_2[3]));
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(float fa, fp8_4_t b) {
Float4_ fc;
union {
fp8_4_t fp8_4;
fp8_2_t fp8_2[2];
};
fp8_4 = b;
float2 fa2 = make_float2(fa, fa);
fc.x = mul<float2, float2, float2>(fa2, float2(fp8_2[0]));
fc.y = mul<float2, float2, float2>(fa2, float2(fp8_2[1]));
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float fa, fp8_4_t b) {
Float4_ fc = mul<Float4_, float, fp8_4_t>(fa, b);
return reinterpret_cast<float4&>(fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(float fa, fp8_8_t b) {
Float8_ fc;
union {
fp8_8_t fp8_8;
fp8_2_t fp8_2[4];
};
fp8_8 = b;
float2 fa2 = make_float2(fa, fa);
fc.x = mul<float2, float2, float2>(fa2, float2(fp8_2[0]));
fc.y = mul<float2, float2, float2>(fa2, float2(fp8_2[1]));
fc.z = mul<float2, float2, float2>(fa2, float2(fp8_2[2]));
fc.w = mul<float2, float2, float2>(fa2, float2(fp8_2[3]));
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(bf16_8_t a, fp8_8_t b) {
Float8_ fc;
union {
fp8_8_t fp8_8;
fp8_2_t fp8_2[4];
};
fp8_8 = b;
fc.x = mul<float2, __nv_bfloat162, float2>(a.x, float2(fp8_2[0]));
fc.y = mul<float2, __nv_bfloat162, float2>(a.y, float2(fp8_2[1]));
fc.z = mul<float2, __nv_bfloat162, float2>(a.z, float2(fp8_2[2]));
fc.w = mul<float2, __nv_bfloat162, float2>(a.w, float2(fp8_2[3]));
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float4 fa, fp8_4_t b) {
float4 fc;
union {
fp8_4_t fp8_4;
fp8_2_t fp8_2[2];
};
fp8_4 = b;
float2 fb0 = float2(fp8_2[0]);
float2 fb1 = float2(fp8_2[1]);
fc.x = fa.x * fb0.x;
fc.y = fa.y * fb0.y;
fc.z = fa.z * fb1.x;
fc.w = fa.w * fb1.y;
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(float4 fa, fp8_4_t b) {
float4 fc = mul<float4, float4, fp8_4_t>(fa, b);
return reinterpret_cast<Float4_&>(fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(uint4 a, int64_t b) {
Float8_ fc;
union {
int64_t int64;
int8_t int8[8];
};
int64 = b;
fc.x = mul<float2, uint32_t, float2>(a.x, make_float2(int8[0], int8[1]));
fc.y = mul<float2, uint32_t, float2>(a.y, make_float2(int8[2], int8[3]));
fc.z = mul<float2, uint32_t, float2>(a.z, make_float2(int8[4], int8[5]));
fc.w = mul<float2, uint32_t, float2>(a.w, make_float2(int8[6], int8[7]));
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(Float8_ fa, int64_t b) {
Float8_ fc;
union {
int64_t int64;
int8_t int8[8];
};
int64 = b;
fc.x = mul<float2, float2, float2>(fa.x, make_float2(int8[0], int8[1]));
fc.y = mul<float2, float2, float2>(fa.y, make_float2(int8[2], int8[3]));
fc.z = mul<float2, float2, float2>(fa.z, make_float2(int8[4], int8[5]));
fc.w = mul<float2, float2, float2>(fa.w, make_float2(int8[6], int8[7]));
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(float fa, int64_t b) {
Float8_ fc;
union {
int64_t int64;
int8_t int8[8];
};
int64 = b;
float2 fa2 = make_float2(fa, fa);
fc.x = mul<float2, float2, float2>(fa2, make_float2(int8[0], int8[1]));
fc.y = mul<float2, float2, float2>(fa2, make_float2(int8[2], int8[3]));
fc.z = mul<float2, float2, float2>(fa2, make_float2(int8[4], int8[5]));
fc.w = mul<float2, float2, float2>(fa2, make_float2(int8[6], int8[7]));
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(float fa, int32_t b) {
Float4_ fc;
union {
int32_t int32;
int8_t int8[4];
};
int32 = b;
float2 fa2 = make_float2(fa, fa);
fc.x = mul<float2, float2, float2>(fa2, make_float2(int8[0], int8[1]));
fc.y = mul<float2, float2, float2>(fa2, make_float2(int8[2], int8[3]));
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float fa, int32_t b) {
Float4_ fc = mul<Float4_, float, int32_t>(fa, b);
return reinterpret_cast<float4&>(fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template<>
inline __device__ Float8_ mul(bf16_8_t a, int64_t b) {
Float8_ fc;
union {
int64_t int64;
int8_t int8[8];
};
int64 = b;
fc.x = mul<float2, __nv_bfloat162, float2>(a.x, make_float2(int8[0], int8[1]));
fc.y = mul<float2, __nv_bfloat162, float2>(a.y, make_float2(int8[2], int8[3]));
fc.z = mul<float2, __nv_bfloat162, float2>(a.z, make_float2(int8[4], int8[5]));
fc.w = mul<float2, __nv_bfloat162, float2>(a.w, make_float2(int8[6], int8[7]));
return fc;
}
#endif // ENABLE_BF16
///////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float4 a, int32_t b) {
float4 fc;
union {
int32_t int32;
int8_t int8[4];
};
int32 = b;
fc.x = a.x * float(int8[0]);
fc.y = a.y * float(int8[1]);
fc.z = a.z * float(int8[2]);
fc.w = a.w * float(int8[3]);
return fc;
}