// maga_transformer/cpp/kernels/add_residual_kernels.cu
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "maga_transformer/cpp/kernels/add_residual_kernels.h"
#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh"
#include "maga_transformer/cpp/utils/AssertUtils.h"
#if USING_ROCM
#include "maga_transformer/cpp/rocm/hip_utils.h"
#endif
// won't support new features
namespace rtp_llm{
#if USING_ROCM
using namespace rocm;
#endif
// Computes: output = in + residual1 [+ residual2] + bias, element-wise over an (m, n) matrix.
// T2 != T is the quantized path: `input` then holds raw integer GEMM results that are
// dequantized with *scale_inter * *scale_out (both pointers must be non-null in that case).
// Launch contract (see invokeAddBiasResidual): grid = (m, ceil(n / blockDim.x)),
// blockIdx.x selects the row, (blockIdx.y, threadIdx.x) select the column.
template<typename T, int RESIDUAL_NUM, typename T2 = T>
__global__ void addBiasResidual(T*           output,
                                const T2*    input,
                                const T*     residual1,
                                const T*     residual2,
                                const T*     bias,
                                const float* scale_inter,
                                const float* scale_out,
                                const int    m,
                                const int    n)
{
    const int col_index = blockIdx.y * blockDim.x + threadIdx.x;
    if (col_index < n) {
        // bias is optional; treat nullptr as an all-zero bias.
        T bias_val = (bias == nullptr) ? (T)(0.0f) : bias[col_index];
        T in;
        if (std::is_same<T, T2>::value) {
            in = cuda_cast<T>(input[blockIdx.x * n + col_index]);  // cast required for compilation when T != T2
        }
        else {
            // Quantized input: dequantize with the two scalar scales, then convert to T.
            in = cuda_cast<float>(input[blockIdx.x * n + col_index]) * (*scale_inter) * (*scale_out);
        }
        if (RESIDUAL_NUM == 1) {
            output[blockIdx.x * n + col_index] = in + residual1[blockIdx.x * n + col_index] + bias_val;
        }
        else if (RESIDUAL_NUM == 2) {
            output[blockIdx.x * n + col_index] =
                in + residual1[blockIdx.x * n + col_index] + residual2[blockIdx.x * n + col_index] + bias_val;
        }
    }
}
// Host launcher for addBiasResidual over an (m, n) matrix.
// When scale_inter/scale_out are provided, `input` is reinterpreted as int32 GEMM
// output and dequantized inside the kernel; otherwise it is used as-is.
// residual2 may be nullptr (single-residual case); bias may be nullptr.
template<typename T>
void invokeAddBiasResidual(T*           output,
                           const T*     input,
                           const T*     residual1,
                           const T*     residual2,
                           const T*     bias,
                           const float* scale_inter,
                           const float* scale_out,
                           const int    m,
                           const int    n,
                           cudaStream_t stream)
{
    // The two dequant scales must be supplied together or not at all.
    RTP_LLM_CHECK_WITH_INFO(!((scale_inter == nullptr) ^ (scale_out == nullptr)),
                            "Cannot use `scale_inter` without `scale_out`");

    const int  blocks_per_row = (n + 1023) / 1024;  // ceil(n / 1024)
    const dim3 grid(m, blocks_per_row);
    const dim3 block(min(n, 1024));

    const bool dequantize_input = (scale_inter != nullptr);
    if (dequantize_input) {
        // Quantized path: the buffer actually holds int32 values (T2 = int32_t).
        const int32_t* int_input = reinterpret_cast<const int32_t*>(input);
        if (residual2 == nullptr) {
            addBiasResidual<T, 1><<<grid, block, 0, stream>>>(
                output, int_input, residual1, residual2, bias, scale_inter, scale_out, m, n);
        }
        else {
            addBiasResidual<T, 2><<<grid, block, 0, stream>>>(
                output, int_input, residual1, residual2, bias, scale_inter, scale_out, m, n);
        }
    }
    else {
        if (residual2 == nullptr) {
            addBiasResidual<T, 1><<<grid, block, 0, stream>>>(
                output, input, residual1, residual2, bias, nullptr, nullptr, m, n);
        }
        else {
            addBiasResidual<T, 2><<<grid, block, 0, stream>>>(
                output, input, residual1, residual2, bias, nullptr, nullptr, m, n);
        }
    }
}
// In-place variant: output += input * alpha + bias, element-wise over (m, n).
// bias may be nullptr (treated as zero).
template<typename T>
__global__ void alphaAddBiasResidual(T* output, const T* input, const T* bias, const T alpha, const int m, const int n)
{
    const int col = blockIdx.y * blockDim.x + threadIdx.x;
    if (col >= n) {
        return;
    }
    const int idx = blockIdx.x * n + col;
    const T   b   = (bias != nullptr) ? bias[col] : (T)(0.0f);
    output[idx]   = output[idx] + input[idx] * alpha + b;
}
// Out-of-place variant: output = residual + input * alpha + bias, element-wise over (m, n).
// bias may be nullptr (treated as zero).
template<typename T>
__global__ void alphaAddBiasResidual(T* output, const T* input, const T* residual, const T* bias, const T alpha, const int m, const int n)
{
    const int col = blockIdx.y * blockDim.x + threadIdx.x;
    if (col >= n) {
        return;
    }
    const int idx = blockIdx.x * n + col;
    const T   b   = (bias != nullptr) ? bias[col] : (T)(0.0f);
    output[idx]   = residual[idx] + input[idx] * alpha + b;
}
// Host launcher for alphaAddBiasResidual. Dispatches to the in-place kernel when
// residual is nullptr, otherwise to the out-of-place kernel writing into `output`.
template<typename T>
void invokeAlphaAddBiasResidual(
    T* output, const T* input, const T* residual, const T* bias, const T alpha, const int m, const int n, cudaStream_t stream)
{
    const int  blocks_per_row = (n + 1023) / 1024;  // ceil(n / 1024)
    const dim3 grid(m, blocks_per_row);
    const dim3 block(min(n, 1024));
    if (residual == nullptr) {
        alphaAddBiasResidual<<<grid, block, 0, stream>>>(output, input, bias, alpha, m, n);
    }
    else {
        alphaAddBiasResidual<<<grid, block, 0, stream>>>(output, input, residual, bias, alpha, m, n);
    }
}
// block_output = ffn_output + attn_output + bias [+ block_input / block_input_tp_split].
// block_input may be nullptr, in which case no scaled residual is added. The division by
// block_input_tp_split presumably compensates for a tensor-parallel all-reduce — confirm
// against the caller.
template<typename T>
__global__ void addBiasAttentionFfnResidual(T*       block_output,
                                            const T* ffn_output,
                                            const T* attn_output,
                                            const T* block_input,
                                            const T* bias,
                                            const int m,
                                            const int n,
                                            const int block_input_tp_split)
{
    const int col = blockIdx.y * blockDim.x + threadIdx.x;
    if (col >= n) {
        return;
    }
    const int idx = blockIdx.x * n + col;

    T residual_term = static_cast<T>(0.0f);
    if (block_input != nullptr) {
        residual_term = cuda_cast<T>((float)block_input[idx] / (float)block_input_tp_split);
    }
    block_output[idx] = ffn_output[idx] + attn_output[idx] + bias[col] + residual_term;
}
// In-place overload (used when block_output aliases block_input):
// block_output = block_output / block_input_tp_split + ffn_output + attn_output + bias.
template<typename T>
__global__ void addBiasAttentionFfnResidual(T*       block_output,
                                            const T* ffn_output,
                                            const T* attn_output,
                                            const T* bias,
                                            const int m,
                                            const int n,
                                            const int block_input_tp_split)
{
    const int col = blockIdx.y * blockDim.x + threadIdx.x;
    if (col >= n) {
        return;
    }
    const int idx = blockIdx.x * n + col;
    // Scale the existing residual in float, then accumulate the three other terms.
    const T scaled_residual = cuda_cast<T>((float)block_output[idx] / (float)block_input_tp_split);
    block_output[idx]       = add(scaled_residual, ffn_output[idx], attn_output[idx], bias[col]);
}
// Host launcher. If block_output aliases block_input, uses the in-place kernel
// (which rescales block_output before accumulating); otherwise the out-of-place one.
template<typename T>
void invokeAddBiasAttentionFfnResidual(T*           block_output,
                                       const T*     ffn_output,
                                       const T*     attn_output,
                                       const T*     block_input,
                                       const T*     bias,
                                       const int    m,
                                       const int    n,
                                       const int    block_input_tp_split,
                                       cudaStream_t stream)
{
    const int  blocks_per_row = (n + 1023) / 1024;  // ceil(n / 1024)
    const dim3 grid(m, blocks_per_row);
    const dim3 block(min(n, 1024));

    const bool in_place = (block_output == block_input);
    if (in_place) {
        addBiasAttentionFfnResidual<<<grid, block, 0, stream>>>(
            block_output, ffn_output, attn_output, bias, m, n, block_input_tp_split);
    }
    else {
        addBiasAttentionFfnResidual<<<grid, block, 0, stream>>>(
            block_output, ffn_output, attn_output, block_input, bias, m, n, block_input_tp_split);
    }
}
// Explicit instantiations of invokeAddBiasResidual for the supported element types.
#define INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL(T)                                                                        \
    template void invokeAddBiasResidual(T*           output,                                                           \
                                        const T*     input,                                                            \
                                        const T*     residual1,                                                        \
                                        const T*     residual2,                                                        \
                                        const T*     bias,                                                             \
                                        const float* scale_inter,                                                      \
                                        const float* scale_out,                                                        \
                                        const int    m,                                                                \
                                        const int    n,                                                                \
                                        cudaStream_t stream)
INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL(float);
INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL
// Explicit instantiations of invokeAlphaAddBiasResidual.
template void invokeAlphaAddBiasResidual(float*       output,
                                         const float* input,
                                         const float* residual,
                                         const float* bias,
                                         const float  alpha,
                                         const int    m,
                                         const int    n,
                                         cudaStream_t stream);
template void invokeAlphaAddBiasResidual(half*       output,
                                         const half* input,
                                         const half* residual,
                                         const half* bias,
                                         const half  alpha,
                                         const int   m,
                                         const int   n,
                                         cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeAlphaAddBiasResidual(__nv_bfloat16*       output,
                                         const __nv_bfloat16* input,
                                         const __nv_bfloat16* residual,
                                         const __nv_bfloat16* bias,
                                         const __nv_bfloat16  alpha,
                                         const int            m,
                                         const int            n,
                                         cudaStream_t         stream);
#endif
// Explicit instantiations of invokeAddBiasAttentionFfnResidual.
template void invokeAddBiasAttentionFfnResidual(float*       block_output,
                                                const float* ffn_output,
                                                const float* attn_output,
                                                const float* input,
                                                const float* bias,
                                                const int    m,
                                                const int    n,
                                                const int    block_input_tp_split,
                                                cudaStream_t stream);
template void invokeAddBiasAttentionFfnResidual(half*       block_output,
                                                const half* ffn_output,
                                                const half* attn_output,
                                                const half* input,
                                                const half* bias,
                                                const int   m,
                                                const int   n,
                                                const int   block_input_tp_split,
                                                cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeAddBiasAttentionFfnResidual(__nv_bfloat16*       block_output,
                                                const __nv_bfloat16* ffn_output,
                                                const __nv_bfloat16* attn_output,
                                                const __nv_bfloat16* input,
                                                const __nv_bfloat16* bias,
                                                const int            m,
                                                const int            n,
                                                const int            block_input_tp_split,
                                                cudaStream_t         stream);
#endif
/******************* invokeAddBiasResidualCol32 ***********************/
// input1/input2/out matrix with layout of cublasLt CUBLASLT_ORDER_COL32 (m*n)
//(grid, block) must be (m, n/4)
// using char4
// Residual-add for int8 data in cuBLASLt COL32 layout:
//   output = input2 + input1 * (*input1_deQFactor_ptr) + bias
// Launch contract: grid = (m), block = (n / 4) — blockIdx.x is the row, each thread
// handles 4 consecutive columns loaded as one char4. Requires n % 4 == 0 and
// n / 4 <= 1024. No internal bounds check.
template<typename T>
__global__ void add_bias_input_COL32_int8I_DataTypeO(
    T* output, const int8_t* input1, const T* input2, const T* bias, int m, int n, const float* input1_deQFactor_ptr)
{
    const float input1_deQFactor = __ldg(input1_deQFactor_ptr);
    int         col_start        = threadIdx.x << 2;
    float       local_out[4];
    // COL32 addressing: element (row, col) lives at (col & ~31) * m + row * 32 + (col & 31);
    // >> 2 converts the scalar offset into a 4-element (char4) vector index.
    int    outIdx       = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2;
    char4* input1TmpPtr = (char4*)input1;
    char4  input1Tmp    = __ldg(input1TmpPtr + outIdx);
    int    col_start_tmp = col_start;
    local_out[0] = static_cast<float>(input2[(outIdx << 2) + 0]) + static_cast<float>(input1Tmp.x) * input1_deQFactor
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[1] = static_cast<float>(input2[(outIdx << 2) + 1]) + static_cast<float>(input1Tmp.y) * input1_deQFactor
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[2] = static_cast<float>(input2[(outIdx << 2) + 2]) + static_cast<float>(input1Tmp.z) * input1_deQFactor
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[3] = static_cast<float>(input2[(outIdx << 2) + 3]) + static_cast<float>(input1Tmp.w) * input1_deQFactor
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    for (int i = 0; i < 4; i++) {
        output[(outIdx << 2) + i] = static_cast<T>(local_out[i]);
    }
}
// half4 specialization: same math as the generic kernel but vectorized over half4 and
// launched with a 2D tile config — grid = (ceil(n/32), ceil(m/32)), block = (8, 32) —
// with an explicit in-kernel bounds check for ragged edges.
template<>
__global__ void add_bias_input_COL32_int8I_DataTypeO(half4*        output,
                                                     const int8_t* input1,
                                                     const half4*  input2,
                                                     const half4*  bias,
                                                     int           m,
                                                     int           n,
                                                     const float*  input1_deQFactor_ptr)
{
    const float input1_deQFactor = __ldg(input1_deQFactor_ptr);
    // Each (threadIdx.x, threadIdx.y) covers 4 columns x 1 row inside a 32x32 tile.
    int col_start = (blockIdx.x << 5) + (threadIdx.x << 2);
    int row_start = (blockIdx.y << 5) + (threadIdx.y);
    if (col_start < n && row_start < m) {
        half4 local_out;
        // COL32 addressing, as a half4/char4 vector index.
        int    outIdx       = ((col_start & 0xffffffe0) * m + (row_start << 5) + (col_start & 31)) >> 2;
        char4* input1TmpPtr = (char4*)input1;
        char4  input1Tmp    = input1TmpPtr[outIdx];
        half4  input2Tmp    = input2[outIdx];
        half4  biasTmp      = bias[col_start >> 2];
        local_out.x = static_cast<half>((float)input1Tmp.x * input1_deQFactor + (float)biasTmp.x + (float)input2Tmp.x);
        local_out.y = static_cast<half>((float)input1Tmp.y * input1_deQFactor + (float)biasTmp.y + (float)input2Tmp.y);
        local_out.z = static_cast<half>((float)input1Tmp.z * input1_deQFactor + (float)biasTmp.z + (float)input2Tmp.z);
        local_out.w = static_cast<half>((float)input1Tmp.w * input1_deQFactor + (float)biasTmp.w + (float)input2Tmp.w);
        output[outIdx] = local_out;
    }
}
// Host launcher for the int8-input COL32 residual-add.
// The half path uses the vectorized half4 specialization with a 2D tile launch; every
// other type uses the generic kernel, whose contract is grid = (m), block = (n / 4).
template<typename T>
void invokeAddBiasResidualCol32(T*            output,
                                const int8_t* input1,
                                const T*      input2,
                                const T*      bias,
                                int           m,
                                int           n,
                                cudaStream_t  stream,
                                const float*  input1_deQFactor_ptr)
{
    if (sizeof(T) == 2) {
        // half4 specialization: 32x32 tiles, bounds-checked inside the kernel.
        dim3 grid((n + 31) / 32, (m + 31) / 32);
        dim3 block(8, 32);
        assert(block.x <= 1024);
        add_bias_input_COL32_int8I_DataTypeO<<<grid, block, 0, stream>>>(
            (half4*)output, input1, (const half4*)input2, (const half4*)bias, m, n, input1_deQFactor_ptr);
    }
    else {
        // Fix: the generic kernel ignores blockIdx.y/threadIdx.y and expects
        // grid = (m), block = (n / 4) (blockIdx.x = row, threadIdx.x = col / 4).
        // It was previously launched with the half-path's 2D config, which produced
        // wrong indices for non-half types.
        assert(n % 4 == 0);
        dim3 grid(m);
        dim3 block(n / 4);
        assert(block.x <= 1024);
        add_bias_input_COL32_int8I_DataTypeO<T>
            <<<grid, block, 0, stream>>>(output, input1, input2, bias, m, n, input1_deQFactor_ptr);
    }
}
template void invokeAddBiasResidualCol32(float* output,
const int8_t* input1,
const float* input2,
const float* bias,
int m,
int n,
cudaStream_t stream,
const float* input1_deQFactor_ptr);
template void invokeAddBiasResidualCol32(half* output,
const int8_t* input1,
const half* input2,
const half* bias,
int m,
int n,
cudaStream_t stream,
const float* input1_deQFactor_ptr);
/******************* invokeAddBiasResidualCol32 ***********************/
// input1/input2/out matrix with layout of cublasLt CUBLASLT_ORDER_COL32 (m*n)
//(grid, block) must be (m, n/4)
// using char4
// Residual-add for int32 GEMM output in cuBLASLt COL32 layout:
//   output = input2 + input1 * (amax_input / 127) * (weight_scale / 127) + bias
// weight_amax is either a per-output-channel vector (scale_is_vector == 1) or a single
// scalar (scale_is_vector == 0).
// Launch contract: grid = (m), block = (n / 4) — blockIdx.x is the row, each thread
// handles 4 consecutive columns loaded as one int4. Requires n % 4 == 0.
template<typename T>
__global__ void add_bias_input_COL32_int32I_DataTypeO(T*             output,
                                                      const int32_t* input1,
                                                      const T*       input2,
                                                      const T*       bias,
                                                      int            m,
                                                      int            n,
                                                      const float*   weight_amax,
                                                      const float*   input1_amax_ptr,
                                                      const int      scale_is_vector)
{
    int col_start = threadIdx.x << 2;

    // Fix: the previous code unconditionally loaded a float4 from weight_amax even when
    // scale_is_vector == 0, reading 12 bytes past the end of a scalar buffer and using
    // garbage for lanes y/z/w. Broadcast the scalar instead, matching the half4
    // specialization below.
    const float4* weight_scale_ptr = (const float4*)weight_amax;
    float4        weight_scale;
    if (scale_is_vector == 1) {
        weight_scale = __ldg(weight_scale_ptr + threadIdx.x);
    }
    else {
        const float s = __ldg(weight_amax);
        weight_scale  = make_float4(s, s, s, s);
    }

    const float input1_deQ = __ldg(input1_amax_ptr) / 127.0f;
    float       local_out[4];
    // COL32 addressing, >> 2 converts to a 4-element (int4) vector index.
    int  outIdx        = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2;
    int4* input1TmpPtr = (int4*)input1;
    int4  input1Tmp    = input1TmpPtr[outIdx];
    int   col_start_tmp = col_start;
    local_out[0] = static_cast<float>(input2[(outIdx << 2) + 0])
                   + static_cast<float>(input1Tmp.x) * input1_deQ * weight_scale.x / 127.0f
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[1] = static_cast<float>(input2[(outIdx << 2) + 1])
                   + static_cast<float>(input1Tmp.y) * input1_deQ * weight_scale.y / 127.0f
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[2] = static_cast<float>(input2[(outIdx << 2) + 2])
                   + static_cast<float>(input1Tmp.z) * input1_deQ * weight_scale.z / 127.0f
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[3] = static_cast<float>(input2[(outIdx << 2) + 3])
                   + static_cast<float>(input1Tmp.w) * input1_deQ * weight_scale.w / 127.0f
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    for (int i = 0; i < 4; i++) {
        output[(outIdx << 2) + i] = static_cast<T>(local_out[i]);
    }
}
// half4 specialization: same math as the generic kernel, vectorized over half4.
// Correctly handles scale_is_vector == 0 by broadcasting the scalar weight_amax.
// Same launch contract: grid = (m), block = (n / 4) in half4 units.
template<>
__global__ void add_bias_input_COL32_int32I_DataTypeO(half4*         output,
                                                      const int32_t* input1,
                                                      const half4*   input2,
                                                      const half4*   bias,
                                                      int            m,
                                                      int            n,
                                                      const float*   weight_amax,
                                                      const float*   input1_amax_ptr,
                                                      const int      scale_is_vector)
{
    int           col_start        = threadIdx.x << 2;
    const float4* weight_scale_ptr = (const float4*)weight_amax;
    // Scalar scale is broadcast to all 4 lanes; vector scale is loaded per channel.
    const float weight_scale_single = __ldg(weight_amax);
    const float4 weight_scale =
        scale_is_vector == 1 ?
            __ldg(weight_scale_ptr + threadIdx.x * scale_is_vector) :
            make_float4(weight_scale_single, weight_scale_single, weight_scale_single, weight_scale_single);
    const float input1_deQ = __ldg(input1_amax_ptr) / 127.0f;
    float       local_out[4];
    // COL32 addressing, as an int4/half4 vector index.
    int   outIdx       = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2;
    int4* input1TmpPtr = (int4*)input1;
    int4  input1Tmp    = input1TmpPtr[outIdx];
    half4 input2Tmp    = input2[outIdx];
    half4 biasTmp      = bias[threadIdx.x];
    local_out[0] = static_cast<float>(input2Tmp.x)
                   + static_cast<float>(input1Tmp.x) * input1_deQ * weight_scale.x / 127.0f
                   + static_cast<float>(biasTmp.x);
    local_out[1] = static_cast<float>(input2Tmp.y)
                   + static_cast<float>(input1Tmp.y) * input1_deQ * weight_scale.y / 127.0f
                   + static_cast<float>(biasTmp.y);
    local_out[2] = static_cast<float>(input2Tmp.z)
                   + static_cast<float>(input1Tmp.z) * input1_deQ * weight_scale.z / 127.0f
                   + static_cast<float>(biasTmp.z);
    local_out[3] = static_cast<float>(input2Tmp.w)
                   + static_cast<float>(input1Tmp.w) * input1_deQ * weight_scale.w / 127.0f
                   + static_cast<float>(biasTmp.w);
    half4 outTmp;
    outTmp.x = static_cast<half>(local_out[0]);
    outTmp.y = static_cast<half>(local_out[1]);
    outTmp.z = static_cast<half>(local_out[2]);
    outTmp.w = static_cast<half>(local_out[3]);
    output[outIdx] = outTmp;
}
// Host launcher for the int32-input COL32 residual-add. Both the generic kernel and
// the half4 specialization expect grid = (m), block = (n / 4): one block per row,
// one thread per 4 columns.
template<typename T>
void invokeAddBiasResidualCol32(T*             output,
                                const int32_t* input1,
                                const T*       input2,
                                const T*       bias,
                                int            m,
                                int            n,
                                cudaStream_t   stream,
                                const float*   weight_amax,
                                const float*   input1_amax_ptr,
                                const int      scale_is_vector)
{
    const dim3 grid(m);
    const dim3 block(n / 4);
    assert(block.x <= 1024);

    const bool is_half_type = (sizeof(T) == 2);
    if (is_half_type) {
        add_bias_input_COL32_int32I_DataTypeO<<<grid, block, 0, stream>>>((half4*)output,
                                                                          input1,
                                                                          (const half4*)input2,
                                                                          (const half4*)bias,
                                                                          m,
                                                                          n,
                                                                          weight_amax,
                                                                          input1_amax_ptr,
                                                                          scale_is_vector);
    }
    else {
        add_bias_input_COL32_int32I_DataTypeO<T><<<grid, block, 0, stream>>>(
            output, input1, input2, bias, m, n, weight_amax, input1_amax_ptr, scale_is_vector);
    }
}
// Explicit instantiations of the int32-input invokeAddBiasResidualCol32.
template void invokeAddBiasResidualCol32(float*       output,
                                         const int*   input1,
                                         const float* input2,
                                         const float* bias,
                                         int          m,
                                         int          n,
                                         cudaStream_t stream,
                                         const float* weight_amax,
                                         const float* input1_amax_ptr,
                                         const int    scale_is_vector);
template void invokeAddBiasResidualCol32(half*        output,
                                         const int*   input1,
                                         const half*  input2,
                                         const half*  bias,
                                         int          m,
                                         int          n,
                                         cudaStream_t stream,
                                         const float* weight_amax,
                                         const float* input1_amax_ptr,
                                         const int    scale_is_vector);
} // namespace rtp_llm