// maga_transformer/cpp/cutlass/cutlass_kernels/weightOnlyBatchedGemv/fp8Gemm.cu
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass/numeric_conversion.h"
#include "maga_transformer/cpp/cutlass/cutlass_kernels/weightOnlyBatchedGemv/fp8Gemm.h"
#include <cub/cub.cuh>
#include <type_traits>
namespace tensorrt_llm
{
namespace kernels
{
namespace fp8_gemm
{
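// FP8 GEMM kernel specialized for small M (GEMV-like shapes). Each thread
// block computes one TILE_M x TILE_N output tile: BLOCK_SIZE threads stride
// over the K dimension accumulating partial dot products in registers, then
// the partials are reduced within each warp and across warps.
// Layouts (row-major): act is [m, k], weight is [n, k], output is [m, n].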
template <typename InputType, typename OutputType, SizeType32 TILE_M, SizeType32 TILE_N, SizeType32 BLOCK_SIZE>
__global__ void fp8Gemm(InputType const* __restrict__ act, InputType const* __restrict__ weight, float alpha,
OutputType* __restrict__ output, SizeType32 m, SizeType32 n, SizeType32 k)
{
using VecType = int4;
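// Each 128-bit (int4) load carries kStepK FP8 values (16 for 1-byte types);
// one block covers kTileK = kStepK * BLOCK_SIZE elements of K per iteration.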
static constexpr SizeType32 kStepK = static_cast<SizeType32>(128 / (8 * sizeof(InputType)));
static constexpr SizeType32 kTileK = kStepK * BLOCK_SIZE;
auto tileIdM = static_cast<SizeType32>(blockIdx.x * TILE_M);
auto tileIdN = static_cast<SizeType32>(blockIdx.y * TILE_N);
auto tid = static_cast<SizeType32>(threadIdx.x);
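// Per-thread register fragments: one K-slice of activations, TILE_N K-slices
// of weights, and the TILE_M x TILE_N partial accumulators.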
float tile_a[kStepK], tile_w[TILE_N * kStepK];
float acc[TILE_M * TILE_N];
static_assert(kStepK % 4 == 0);
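// CUTLASS converts FP8 -> float four elements at a time; kCvtCount is the
// number of such 4-wide groups packed into one 128-bit load.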
using CvtInputType
= std::conditional_t<std::is_same_v<InputType, __nv_fp8_e4m3>, cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
using Converter = cutlass::NumericArrayConverter<float, CvtInputType, 4>;
using CvtSrcType = typename Converter::source_type;
using CvtResType = typename Converter::result_type;
static constexpr SizeType32 kCvtCount = static_cast<SizeType32>(sizeof(VecType) / sizeof(CvtSrcType));
#pragma unroll
for (SizeType32 i = 0; i < TILE_M * TILE_N; ++i)
{
acc[i] = 0;
}
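// Advance the base pointers to this block's tile origin.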
act += tileIdM * k;
weight += tileIdN * k;
output += tileIdM * n + tileIdN;
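// Main K loop: each thread owns a disjoint kStepK-wide slice per iteration.
// This assumes k is a multiple of kStepK and that the act/weight rows are
// 16-byte aligned, so the vectorized int4 loads below are legal.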
for (SizeType32 idxK = tid * kStepK; idxK < k; idxK += kTileK)
{
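// Load and dequantize the TILE_N weight fragments for this K slice.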
#pragma unroll
for (SizeType32 i = 0; i < TILE_N; ++i)
{
auto tile_w_quantized = reinterpret_cast<VecType const*>(weight + i * k + idxK)[0];
#pragma unroll
for (SizeType32 cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx)
{
reinterpret_cast<CvtResType*>(tile_w)[i * kCvtCount + cvtIdx]
= Converter::convert(reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvtIdx]);
}
}
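// For each of the TILE_M rows: dequantize the activation fragment, then
// FMA it against every weight fragment to update the accumulators.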
#pragma unroll
for (SizeType32 i = 0; i < TILE_M; ++i)
{
auto tile_a_quantized = reinterpret_cast<VecType const*>(act + i * k + idxK)[0];
#pragma unroll
for (SizeType32 cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx)
{
reinterpret_cast<CvtResType*>(tile_a)[cvtIdx]
= Converter::convert(reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvtIdx]);
}
#pragma unroll
for (SizeType32 j = 0; j < TILE_N; ++j)
{
#pragma unroll
for (SizeType32 l = 0; l < kStepK; ++l)
{
acc[i * TILE_N + j] = fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]);
}
}
}
}
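// Reduce the per-thread partials: first within each warp via cub::WarpReduce,
// then stage each warp's result in shared memory for the cross-warp sum.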
using WarpReduce = cub::WarpReduce<float>;
static constexpr SizeType32 kWarpSize = 32;
static constexpr SizeType32 kWarpNum = BLOCK_SIZE / kWarpSize;
SizeType32 warpId = tid / kWarpSize, laneId = tid % kWarpSize;
__shared__ float shmem[TILE_M * TILE_N * kWarpNum];
__shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum];
#pragma unroll
for (SizeType32 mi = 0; mi < TILE_M; ++mi)
{
#pragma unroll
for (SizeType32 ni = 0; ni < TILE_N; ++ni)
{
float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]);
if (laneId == 0)
{
shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val;
}
}
}
__syncthreads();
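// Sum the per-warp partials, apply the alpha scale, and write the output tile.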
#pragma unroll
for (SizeType32 ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE)
{
SizeType32 mid = ii / TILE_N, nid = ii % TILE_N;
float val = 0;
#pragma unroll
for (SizeType32 jj = 0; jj < kWarpNum; ++jj)
{
val += shmem[jj * TILE_M * TILE_N + ii];
}
output[mid * n + nid] = static_cast<OutputType>(val * alpha);
}
}
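// Launches the kernel with a grid that tiles the output. The dispatch below
// guarantees m == TILE_M (so grid.x == 1); n is assumed to be a multiple of
// TILE_N, otherwise the trailing output columns are never computed.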
template <typename InputType, typename OutputType, SizeType32 TILE_M, SizeType32 TILE_N, SizeType32 BLOCK_SIZE>
void fp8GemmKernel(Params& params, cudaStream_t stream)
{
dim3 block(BLOCK_SIZE);
dim3 grid(params.m / TILE_M, params.n / TILE_N);
fp8Gemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE><<<grid, block, 0, stream>>>(
reinterpret_cast<InputType const*>(params.act), reinterpret_cast<InputType const*>(params.weight), params.alpha,
reinterpret_cast<OutputType*>(params.output), params.m, params.n, params.k);
}
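// Dispatch on the (small) batch size m. Shapes with m outside [1, 4] fall
// through silently without launching anything; presumably the caller routes
// larger m to a regular GEMM path.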
template <typename InputType, typename OutputType>
void fp8GemmLauncher(Params& params, cudaStream_t stream)
{
#define DISPATCH(TargetM, TILE_M, TILE_N, BLOCK_SIZE) \
if (params.m == TargetM) \
{ \
fp8GemmKernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(params, stream); \
return; \
}
DISPATCH(1, 1, 2, 128);
DISPATCH(2, 2, 2, 128);
DISPATCH(3, 3, 2, 128);
DISPATCH(4, 4, 2, 128);
#undef DISPATCH
}
template void fp8GemmLauncher<__nv_fp8_e4m3, float>(Params& params, cudaStream_t stream);
template void fp8GemmLauncher<__nv_fp8_e4m3, half>(Params& params, cudaStream_t stream);
template void fp8GemmLauncher<__nv_fp8_e4m3, __nv_bfloat16>(Params& params, cudaStream_t stream);
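// A minimal usage sketch (assumptions flagged): Params is taken to expose the
// fields referenced above (act, weight, alpha, output, m, n, k); the buffer
// names and the surrounding allocation/stream setup are hypothetical.
//
//   __nv_fp8_e4m3 const* dAct;    // [m, k] activations on device (assumed)
//   __nv_fp8_e4m3 const* dWeight; // [n, k] weights on device (assumed)
//   half* dOut;                   // [m, n] outputs on device (assumed)
//   Params params;
//   params.act = dAct;
//   params.weight = dWeight;
//   params.alpha = 1.0f;          // scale applied in the epilogue
//   params.output = dOut;
//   params.m = 1;                 // must be in [1, 4] for this launcher
//   params.n = 4096;              // assumed divisible by TILE_N
//   params.k = 4096;              // assumed divisible by kStepK (16)
//   fp8GemmLauncher<__nv_fp8_e4m3, half>(params, stream);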
} // namespace fp8_gemm
} // namespace kernels
} // namespace tensorrt_llm