// maga_transformer/cpp/kernels/_logn_attention.h (99 lines of code) (raw):
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
// Applies logarithmic length scaling to a single float, in place.
//
// vec          : value to scale; modified only when the guard below passes.
// seq_idx      : sequence index used for both the guard and the factor.
// logn_seq_len : reference length; the base of the logarithm ratio.
//
// When seq_idx > logn_seq_len, vec is multiplied by
// logf(seq_idx + 1) / logf(logn_seq_len). Otherwise vec is left untouched.
//
// NOTE(review): the factor is built from seq_idx + 1 while the guard tests
// seq_idx itself, so seq_idx == logn_seq_len is not scaled -- confirm that
// this boundary is intended.
__device__ __inline__ void logn_attention(float& vec, const int seq_idx, const int logn_seq_len) {
    if (seq_idx <= logn_seq_len) {
        return;  // inside the reference length: nothing to do
    }
    const float scale = logf(seq_idx + 1) / logf(logn_seq_len);
    vec *= scale;
}
// Applies logarithmic length scaling to both components of a float2, in place.
//
// When seq_idx > logn_seq_len, vec.x and vec.y are each multiplied by
// logf(seq_idx + 1) / logf(logn_seq_len). Otherwise vec is left untouched.
__device__ __inline__ void logn_attention(float2& vec, const int seq_idx, const int logn_seq_len) {
    if (seq_idx <= logn_seq_len) {
        return;  // inside the reference length: nothing to do
    }
    const float scale = logf(seq_idx + 1) / logf(logn_seq_len);
    vec.x *= scale;
    vec.y *= scale;
}
// Applies logarithmic length scaling to a uint32_t, in place.
//
// vec is unpacked to a float2 with half2_to_float2, optionally scaled in
// fp32, and packed back with float2_to_half2 (both helpers are defined
// outside this block; presumably vec holds two packed fp16 values -- confirm
// against the helper definitions).
//
// The scaling (factor logf(seq_idx + 1) / logf(logn_seq_len)) is applied only
// when seq_idx > logn_seq_len; the unpack/pack round trip always runs.
__device__ __inline__ void logn_attention(uint32_t& vec, const int seq_idx, const int logn_seq_len) {
    float2 unpacked = half2_to_float2(vec);
    if (seq_idx > logn_seq_len) {
        const float scale = logf(seq_idx + 1) / logf(logn_seq_len);
        unpacked.x *= scale;
        unpacked.y *= scale;
    }
    vec = float2_to_half2(unpacked);
}
// Applies logarithmic length scaling to all four components of a float4,
// in place.
//
// When seq_idx > logn_seq_len, each of vec.x/.y/.z/.w is multiplied by
// logf(seq_idx + 1) / logf(logn_seq_len). Otherwise vec is left untouched.
__device__ __inline__ void logn_attention(float4& vec, const int seq_idx, const int logn_seq_len) {
    if (seq_idx <= logn_seq_len) {
        return;  // inside the reference length: nothing to do
    }
    const float scale = logf(seq_idx + 1) / logf(logn_seq_len);
    vec.x *= scale;
    vec.y *= scale;
    vec.z *= scale;
    vec.w *= scale;
}
// Applies logarithmic length scaling to a uint2, in place.
//
// Each 32-bit member (vec.x, vec.y) is unpacked to a float2 with
// half2_to_float2, optionally scaled in fp32, and packed back with
// float2_to_half2 (helpers defined outside this block; presumably each member
// holds two packed fp16 values -- confirm against the helper definitions).
//
// The scaling (factor logf(seq_idx + 1) / logf(logn_seq_len)) is applied only
// when seq_idx > logn_seq_len; the unpack/pack round trip always runs.
__device__ __inline__ void logn_attention(uint2& vec, const int seq_idx, const int logn_seq_len) {
    float2 lo = half2_to_float2(vec.x);
    float2 hi = half2_to_float2(vec.y);
    if (seq_idx > logn_seq_len) {
        const float scale = logf(seq_idx + 1) / logf(logn_seq_len);
        lo.x *= scale;
        lo.y *= scale;
        hi.x *= scale;
        hi.y *= scale;
    }
    vec.x = float2_to_half2(lo);
    vec.y = float2_to_half2(hi);
}
// Applies logarithmic length scaling to a uint4, in place.
//
// Each 32-bit member (vec.x/.y/.z/.w) is unpacked to a float2 with
// half2_to_float2, optionally scaled in fp32, and packed back with
// float2_to_half2 (helpers defined outside this block; presumably each member
// holds two packed fp16 values -- confirm against the helper definitions).
//
// The scaling (factor logf(seq_idx + 1) / logf(logn_seq_len)) is applied only
// when seq_idx > logn_seq_len; the unpack/pack round trip always runs.
__device__ __inline__ void logn_attention(uint4& vec, const int seq_idx, const int logn_seq_len) {
    float2 pair_x = half2_to_float2(vec.x);
    float2 pair_y = half2_to_float2(vec.y);
    float2 pair_z = half2_to_float2(vec.z);
    float2 pair_w = half2_to_float2(vec.w);
    if (seq_idx > logn_seq_len) {
        const float scale = logf(seq_idx + 1) / logf(logn_seq_len);
        pair_x.x *= scale;
        pair_x.y *= scale;
        pair_y.x *= scale;
        pair_y.y *= scale;
        pair_z.x *= scale;
        pair_z.y *= scale;
        pair_w.x *= scale;
        pair_w.y *= scale;
    }
    vec.x = float2_to_half2(pair_x);
    vec.y = float2_to_half2(pair_y);
    vec.z = float2_to_half2(pair_z);
    vec.w = float2_to_half2(pair_w);
}
#ifdef ENABLE_BF16
// Applies logarithmic length scaling to both halves of a __nv_bfloat162,
// in place.
//
// When seq_idx > logn_seq_len, the fp32 factor
// logf(seq_idx + 1) / logf(logn_seq_len) is converted to __nv_bfloat16 and
// vec.x / vec.y are each multiplied by it using bf16 arithmetic. Otherwise
// vec is left untouched.
__device__ __inline__ void logn_attention(__nv_bfloat162& vec, const int seq_idx, const int logn_seq_len) {
    if (seq_idx <= logn_seq_len) {
        return;  // inside the reference length: nothing to do
    }
    const float         ratio = logf(seq_idx + 1) / logf(logn_seq_len);
    const __nv_bfloat16 scale = __nv_bfloat16(ratio);
    vec.x = vec.x * scale;
    vec.y = vec.y * scale;
}
// Applies logarithmic length scaling to a single __nv_bfloat16, in place.
//
// When seq_idx > logn_seq_len, the fp32 factor
// logf(seq_idx + 1) / logf(logn_seq_len) is converted to __nv_bfloat16 and
// vec is multiplied by it using bf16 arithmetic. Otherwise vec is left
// untouched.
__device__ __inline__ void logn_attention(__nv_bfloat16& vec, const int seq_idx, const int logn_seq_len) {
    if (seq_idx <= logn_seq_len) {
        return;  // inside the reference length: nothing to do
    }
    const float         ratio = logf(seq_idx + 1) / logf(logn_seq_len);
    const __nv_bfloat16 scale = __nv_bfloat16(ratio);
    vec = vec * scale;
}
// Applies logarithmic length scaling to a bf16_8_t, in place.
//
// When seq_idx > logn_seq_len, the fp32 factor
// logf(seq_idx + 1) / logf(logn_seq_len) is converted to __nv_bfloat16,
// duplicated into a __nv_bfloat162, and multiplied into each of
// vec.x/.y/.z/.w. Otherwise vec is left untouched.
//
// bf16_8_t is declared outside this block; presumably its four members are
// __nv_bfloat162 pairs -- confirm against the type definition.
__device__ __inline__ void logn_attention(bf16_8_t& vec, const int seq_idx, const int logn_seq_len) {
    if (seq_idx <= logn_seq_len) {
        return;  // inside the reference length: nothing to do
    }
    const float          ratio      = logf(seq_idx + 1) / logf(logn_seq_len);
    const __nv_bfloat16  scale      = __nv_bfloat16(ratio);
    const __nv_bfloat162 scale_pair = __nv_bfloat162(scale, scale);
    vec.x = vec.x * scale_pair;
    vec.y = vec.y * scale_pair;
    vec.z = vec.z * scale_pair;
    vec.w = vec.w * scale_pair;
}
// Applies logarithmic length scaling to a bf16_4_t, in place.
//
// When seq_idx > logn_seq_len, the fp32 factor
// logf(seq_idx + 1) / logf(logn_seq_len) is converted to __nv_bfloat16,
// duplicated into a __nv_bfloat162, and multiplied into vec.x and vec.y.
// Otherwise vec is left untouched.
//
// bf16_4_t is declared outside this block; presumably its two members are
// __nv_bfloat162 pairs -- confirm against the type definition.
__device__ __inline__ void logn_attention(bf16_4_t& vec, const int seq_idx, const int logn_seq_len) {
    if (seq_idx <= logn_seq_len) {
        return;  // inside the reference length: nothing to do
    }
    const float          ratio      = logf(seq_idx + 1) / logf(logn_seq_len);
    const __nv_bfloat16  scale      = __nv_bfloat16(ratio);
    const __nv_bfloat162 scale_pair = __nv_bfloat162(scale, scale);
    vec.x = vec.x * scale_pair;
    vec.y = vec.y * scale_pair;
}
#endif