// maga_transformer/cpp/devices/cuda_impl/CudaDevice.cc
#include "maga_transformer/cpp/devices/cuda_impl/CudaDevice.h"
#include "maga_transformer/cpp/devices/cuda_impl/CudaFlashInfer.h"
#include "maga_transformer/cpp/kernels/eplb/experts_stats_kernels.h"
#include "maga_transformer/cpp/core/BufferHelper.h"
#include "maga_transformer/cpp/cuda/custom_ar/custom_ar_comm.h"
#include "maga_transformer/cpp/devices/DeviceFactory.h"
#include "maga_transformer/cpp/cuda/allocator_cuda.h"
#include "maga_transformer/cpp/cuda/nccl/nccl_utils_torch.h"
#include "maga_transformer/cpp/cuda/nccl/nccl_utils.h"
#include "maga_transformer/cpp/core/TrackerAllocator.h"
#include "maga_transformer/cpp/devices/OpData.h"
#include "maga_transformer/cpp/utils/Logger.h"
#include "maga_transformer/cpp/utils/compiler_config.h"
#include "maga_transformer/cpp/core/torch_utils/torch_cuda_allocator.h"
#include "maga_transformer/cpp/disaggregate/cache_store/NormalCacheStore.h"
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <unistd.h>
using namespace std;
using namespace rtp_llm;
using namespace tensorrt_llm;
using namespace tensorrt_llm::kernels;
namespace rtp_llm {
CudaDevice::CudaDevice(const DeviceInitParams& params)
: DeviceBase(params)
{
RTP_LLM_LOG_INFO("Initialize CudaDevice. %d", device_id_);
check_cuda_error(cudaSetDevice(device_id_));
torch_default_stream_ = std::make_unique<at::cuda::CUDAStream>(at::cuda::getDefaultCUDAStream());
torch_comm_stream_ = std::make_unique<at::cuda::CUDAStream>(at::cuda::getStreamFromPool(true));
stream_ = torch_default_stream_->stream();
communication_stream_ = torch_comm_stream_->stream();
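// stream_ aliases torch's default CUDA stream so kernels interleave correctly with torch ops;
// communication_stream_ comes from the stream pool (the `true` flag requests a high-priority
// stream) so collectives can overlap with compute.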
check_cuda_error(cudaStreamCreateWithFlags(&no_block_copy_stream_, cudaStreamNonBlocking));
check_cuda_error(cublasCreate(&cublas_handle_));
check_cuda_error(cublasLtCreate(&cublaslt_handle_));
check_cuda_error(cublasSetStream(cublas_handle_, stream_));
check_cuda_error(cudaGetDeviceProperties(&device_prop_, device_id_));
weight_only_matmul_plugin_ = std::make_unique<trt_plugins::WeightOnlyQuantMatmulPlugin>();
smooth_quant_plugin_ = std::make_unique<trt_plugins::SmoothQuantGemmPlugin>();
weight_only_groupwise_matmul_plugin_ = std::make_unique<trt_plugins::WeightOnlyGroupwiseQuantMatmulPlugin>();
moe_plugin_ = std::make_unique<trt_plugins::MixtureOfExpertsPlugin>();
if (std::getenv("HACK_MOE_EXPERT") && std::string(std::getenv("HACK_MOE_EXPERT")) == "1") {
hack_moe_expert_ = true;
}
if (params.tp_size > 1) {
auto master_ip = params.master_ip;
if (params.dp_size > 1) {
master_ip = "127.0.0.1";
}
initNcclParam(
params.tp_rank, params.tp_size, master_ip, params.tp_master_port, "RTP_LLM_TP_GROUP_", tp_nccl_param_);
}
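// FFN layers may run with a different TP split than attention. When the sizes differ,
// each ffn-tp subgroup rendezvous on its own port (offset from ffn_tp_master_port by the
// subgroup index tp_rank / ffn_tp_size); otherwise the attention TP communicator is reused.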
if (params.ffn_tp_size > 1) {
if (params.ffn_tp_size != params.tp_size) {
initNcclParam(params.ffn_tp_rank, params.ffn_tp_size, params.master_ip,
params.ffn_tp_master_port - params.tp_rank / params.ffn_tp_size, "RTP_LLM_FFN_TP_GROUP_", ffn_tp_nccl_param_);
} else {
ffn_tp_nccl_param_ = tp_nccl_param_;
}
}
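// Expert parallelism spans all dp * tp ranks; the global rank is flattened as
// dp_rank * tp_size + tp_rank.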
if (params.ep_size > 1) {
initNcclParam(params.dp_rank * params.tp_size + params.tp_rank,
params.dp_size * params.tp_size,
params.master_ip,
params.dp_tp_master_port,
"RTP_LLM_DP_TP_GROUP_",
dp_tp_nccl_param_);
}
cuggemm_runner_.reset(new cuggemm());
cuggemm_runner_->init(stream_);
auto fmha_env = std::getenv("ENABLE_FMHA");
if (fmha_env && std::string(fmha_env) == "OFF") {
RTP_LLM_LOG_WARNING("FMHA is not enbaled");
} else {
checkUseTrtV1FMHA();
checkUseTrtV2FMHA();
checkUseOpenSourceFMHA();
checkSupportTrtFp8FMHA();
}
checkUseMultiBlockMode();
checkUseGroupGemm();
checkUseFlashinferSampleKernel();
// Initialize the custom all-reduce communicator.
// Note: it allocates cuda memory through cudaMalloc, so it must be created before the
// allocator init reserves device memory.
if (tp_nccl_param_.world_size_ > 1) {
auto& nccl_param = tp_nccl_param_;
std::vector<size_t> tp_ranks = fcNcclGatherRanks(nccl_param, stream_);
custom_allreduce_comm_ = initCustomAllReduceComm(nccl_param, tp_ranks, stream_);
}
// cudaMallocHost needs page tables in GPU memory, so create the host allocator before device memory is reserved.
auto host_allocator_ptr = new Allocator<AllocatorType::CUDA_HOST>(device_id_);
host_allocator_ptr->setStream(stream_);
if (params.host_reserve_memory_bytes) {
RUNTIME_ASSERT_OP_ARG(params.host_reserve_memory_bytes > 0,
"cuda host memory does not support reserving all available memory (%lu), must specify a concrete size.",
params.host_reserve_memory_bytes);
TrackerAllocatorParams tracker_params;
tracker_params.real_allocator = host_allocator_ptr;
tracker_params.target_track_bytes = params.host_reserve_memory_bytes;
tracker_params.align_size = 32; // required by avx512
host_allocator_.reset(new TrackerAllocator(tracker_params));
} else {
host_allocator_.reset(host_allocator_ptr);
}
auto allocator_ptr = new Allocator<AllocatorType::CUDA>(device_id_);
allocator_ptr->setStream(stream_);
if (params.device_reserve_memory_bytes) {
size_t free_bytes, total_bytes;
check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes));
TrackerAllocatorParams tracker_params;
tracker_params.real_allocator = allocator_ptr;
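// A positive device_reserve_memory_bytes reserves exactly that many bytes; a negative
// value reserves all free memory minus |value| (free_bytes + value).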
tracker_params.target_track_bytes = params.device_reserve_memory_bytes > 0 ?
params.device_reserve_memory_bytes :
free_bytes + params.device_reserve_memory_bytes;
tracker_params.align_size = 16;
RTP_LLM_LOG_INFO("cuda device %d has %lu bytes free memory, trying to reserve %lu bytes.",
device_id_,
free_bytes,
tracker_params.target_track_bytes);
allocator_.reset(new TrackerAllocator(tracker_params));
syncAndCheck(); // sync to verify the tracker's cuda allocation succeeded
} else {
allocator_.reset(allocator_ptr);
}
// hijack torch cuda allocator
origin_torch_cuda_allocator_ = at::cuda::CUDACachingAllocator::allocator;
managed_torch_cuda_allocator_ = std::make_unique<TorchCudaAllocator>(this);
at::cuda::CUDACachingAllocator::allocator.store(managed_torch_cuda_allocator_.get());
cublas_algo_map_.reset(new cublasAlgoMap(GEMM_CONFIG));
cublas_mm_wrapper_.reset(new cublasMMWrapper(
cublas_handle_, cublaslt_handle_, stream_, cublas_algo_map_.get(),
&cublas_wrapper_mutex_, allocator_.get()));
// select mla type
if (params.mla_ops_type != MlaOpsType::AUTO) {
mla_ops_type = params.mla_ops_type;
} else {
mla_ops_type = device_prop_.major >= 9 ? MlaOpsType::FLASH_MLA : MlaOpsType::FLASH_INFER;
}
auto stable_scatter_add_env = std::getenv("ENABLE_STABLE_SCATTER_ADD");
if (stable_scatter_add_env && std::string(stable_scatter_add_env) == "ON") {
use_stable_scatter_add = true;
}
RTP_LLM_LOG_INFO("use_stable_scatter_add: %d", use_stable_scatter_add);
}
CudaDevice::~CudaDevice() {
// restore the original torch cuda allocator
if (origin_torch_cuda_allocator_) {
at::cuda::CUDACachingAllocator::allocator.store(origin_torch_cuda_allocator_);
origin_torch_cuda_allocator_ = nullptr;
}
curandstate_buf_.reset();
cublas_mm_wrapper_.reset();
check_cuda_error(cudaStreamDestroy(no_block_copy_stream_));
check_cuda_error(cublasDestroy(cublas_handle_));
check_cuda_error(cublasLtDestroy(cublaslt_handle_));
if (ffn_tp_nccl_param_ != tp_nccl_param_ && ffn_tp_nccl_param_.nccl_comm_) {
ncclCommDestroy(ffn_tp_nccl_param_.nccl_comm_);
}
if (tp_nccl_param_.nccl_comm_) {
ncclCommDestroy(tp_nccl_param_.nccl_comm_);
}
if (dp_tp_nccl_param_.nccl_comm_) {
ncclCommDestroy(dp_tp_nccl_param_.nccl_comm_);
}
cache_store_.reset();
}
void CudaDevice::init() {
DeviceBase::init();
RTP_LLM_LOG_INFO("cuda device init max batch size: %d\n", init_params_.max_batch_size);
curandstate_buf_ = allocateBuffer({init_params_.max_batch_size * sizeof(curandState_t)}, {"curandstate"});
if (init_params_.use_deepep_moe) {
RTP_LLM_CHECK_WITH_INFO(initDeepEPBuffer(), "init deepep buffer failed");
RTP_LLM_LOG_INFO("init deepep buffer success");
}
}
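// NCCL bootstrap: rank 0 creates the unique id and publishes it through a TCP store
// keyed by group_name; the other ranks fetch it before ncclCommInitRank.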
void CudaDevice::initNcclParam(size_t rank,
size_t world_size,
const std::string& ip,
size_t port,
const string& group_name,
NcclParam& nccl_param) {
nccl_param.rank_ = rank;
nccl_param.world_size_ = world_size;
auto tcpStore = createTcpStore(ip, port, world_size, rank);
const auto nccl_id = &(nccl_param.nccl_uid_);
if (rank == 0) {
RTP_LLM_LOG_INFO("rank %zu creates nccl uid in group %s.", rank, group_name.c_str());
NCCLCHECK(ncclGetUniqueId(nccl_id));
setUniqueId(nccl_id, group_name, tcpStore);
} else {
RTP_LLM_LOG_INFO("rank %zu gets nccl uid in group %s.", rank, group_name.c_str());
getUniqueId(nccl_id, group_name, tcpStore);
}
RTP_LLM_LOG_INFO("Initialize NCCL communicators [%s] rank %zu of %zu.", group_name.c_str(), rank, world_size);
NCCLCHECK(ncclGroupStart());
NCCLCHECK(ncclCommInitRank(&nccl_param.nccl_comm_, world_size, *nccl_id, rank));
NCCLCHECK(ncclGroupEnd());
}
void CudaDevice::syncAndCheck() {
syncCommunication();
cudaStreamSynchronize(stream_);
cudaStreamSynchronize(communication_stream_);
cudaStreamSynchronize(no_block_copy_stream_);
sync_check_cuda_error();
}
void CudaDevice::syncDeviceStream(DeviceStream stream) {
auto sync_stream = getStream(stream);
check_cuda_error(cudaStreamSynchronize(sync_stream));
}
void CudaDevice::syncCommunication(bool timeout) {
if (tp_nccl_param_.world_size_ > 1) {
RTP_LLM_LOG_DEBUG(
"Synchronize tp NCCL communicators rank %d of %d.", tp_nccl_param_.rank_, tp_nccl_param_.world_size_);
ftNcclStreamSynchronize(tp_nccl_param_, stream_, timeout);
}
if (dp_tp_nccl_param_.world_size_ > 1) {
RTP_LLM_LOG_DEBUG("Synchronize dp_tp NCCL communicators rank %d of %d.",
dp_tp_nccl_param_.rank_,
dp_tp_nccl_param_.world_size_);
ftNcclStreamSynchronize(dp_tp_nccl_param_, stream_, timeout);
}
if (ffn_tp_nccl_param_.world_size_ > 1 && ffn_tp_nccl_param_ != tp_nccl_param_) {
RTP_LLM_LOG_DEBUG("Synchronize ffn_tp NCCL communicators rank %d of %d.", ffn_tp_nccl_param_.rank_, ffn_tp_nccl_param_.world_size_);
ftNcclStreamSynchronize(ffn_tp_nccl_param_, stream_, timeout);
}
}
void CudaDevice::syncCommunication(ParallelMode mode, bool timeout) {
auto nccl_param = getNcclParam(mode);
auto stream = getCommStream(mode, false);
if (nccl_param.world_size_ > 1) {
RTP_LLM_LOG_DEBUG("Synchronize NCCL communicators rank %d of %d.", nccl_param.rank_, nccl_param.world_size_);
ftNcclStreamSynchronize(nccl_param, stream, timeout);
}
}
void CudaDevice::overlappedCommBarrier() {
// NOTE: once the overlapped communication and computation are done, make the compute
// stream wait so communication has finished before the next computation starts.
if (tp_nccl_param_.world_size_ * init_params_.dp_size * ffn_tp_nccl_param_.world_size_ > 1) {
cudaEvent_t event;
check_cuda_error(cudaEventCreate(&event));
check_cuda_error(cudaEventRecord(event, communication_stream_));
check_cuda_error(cudaStreamWaitEvent(stream_, event, 0));
check_cuda_error(cudaEventDestroy(event));
}
}
DeviceHookPtr CudaDevice::createCommHook() {
return std::make_unique<CudaCommHook>(stream_, communication_stream_);
}
void CudaDevice::overlappedComputeBarrier() {
// NOTE: once the overlapped communication and computation are done, make the communication
// stream wait so computation has finished before the next communication starts.
if (tp_nccl_param_.world_size_ * init_params_.dp_size * ffn_tp_nccl_param_.world_size_ > 1) {
cudaEvent_t event;
check_cuda_error(cudaEventCreate(&event));
check_cuda_error(cudaEventRecord(event, stream_));
check_cuda_error(cudaStreamWaitEvent(communication_stream_, event, 0));
check_cuda_error(cudaEventDestroy(event));
}
}
DeviceProperties CudaDevice::getDeviceProperties() {
static DeviceProperties* prop = nullptr;
if (prop == nullptr) {
prop = new DeviceProperties();
prop->type = DeviceType::Cuda;
prop->id = device_id_;
prop->tp_rank = init_params_.tp_rank;
prop->tp_size = init_params_.tp_size;
prop->dp_rank = init_params_.dp_rank;
prop->dp_size = init_params_.dp_size;
prop->enable_comm_overlap = init_params_.enable_comm_overlap;
prop->enable_layer_micro_batch = init_params_.enable_layer_micro_batch;
prop->enable_sp = init_params_.enable_sp;
prop->overlap_math_sm_count = init_params_.overlap_math_sm_count;
prop->overlap_comm_type = init_params_.overlap_comm_type;
prop->ffn_tp_size = init_params_.ffn_tp_size;
prop->ffn_tp_rank = init_params_.ffn_tp_rank;
prop->m_split = init_params_.m_split;
prop->use_deepep_moe = init_params_.use_deepep_moe;
prop->use_deepep_internode = init_params_.use_deepep_internode;
prop->use_deepep_low_latency = init_params_.use_deepep_low_latency;
prop->is_mtp = init_params_.is_mtp;
}
return *prop;
}
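// Runners are pooled by attention signature (dtype, mask type, head layout, scaling, alibi);
// a matching pooled runner is reused, otherwise a new one is built and cached.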
void CudaDevice::selectCuFMHARunner(const DevicePrepParams& params) {
use_fp8_fmha_ = useFp8Fmha(params);
DataType fmha_datatype = use_fp8_fmha_ ? DataType::TYPE_FP8_E4M3 : params.attn_dtype;
for (auto& runner : cufmha_runner_pool_) {
if (runner->checkSignature(fmha_datatype,
params.configs.mask_type,
params.configs.head_num,
params.configs.kv_head_num,
params.configs.size_per_head,
params.configs.q_scaling / params.configs.softmax_extra_scale,
params.has_alibi_slopes)) {
cufmha_runner_ = runner;
return;
}
}
cufmha_runner_pool_.emplace_back();
cufmha_runner_pool_.back().reset(
new cufmha(fmha_datatype,
params.configs.mask_type,
params.configs.head_num,
params.configs.kv_head_num,
params.configs.size_per_head,
params.configs.q_scaling / params.configs.softmax_extra_scale, // div scale for DeepSeek V2
params.has_alibi_slopes,
use_trtv1_fmha,
use_trtv2_fmha,
use_trtv2_fmha_paged,
use_open_source_fmha,
use_open_source_fmha_paged,
stream_));
cufmha_runner_ = cufmha_runner_pool_.back();
}
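// Pick the FMHA backend for this batch: fp32 always falls back to masked attention;
// prefill with a paged kv cache prefers paged TRT v2, then paged open source; dense
// prefill prefers TRT v2, then open source, then TRT v1. FlashInfer attention params
// are prepared separately for the decode and prefill sub-batches.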
DevicePrepOutput CudaDevice::prepareModelRun(const DevicePrepParams& params) {
DevicePrepOutput output;
fmha_type_ = FMHAType::NONE;
if (params.attn_dtype == DataType::TYPE_FP32) {
fmha_type_ = FMHAType::NONE;
output.need_mask = true;
} else if (params.context_batch_size) {
selectCuFMHARunner(params);
bool paged_kv_fmha =
params.diff_qkv_len && params.has_kv_cache && (params.configs.kv_cache_dtype == KvCacheDataType::BASE);
if (paged_kv_fmha) {
if (use_trtv2_fmha_paged && cufmha_runner_->trtV2FmhaPagedSupport()) {
fmha_type_ = FMHAType::PAGED_TRT_V2;
} else if (use_open_source_fmha_paged && cufmha_runner_->openSourceFmhaSupport()
&& params.configs.tokens_per_block % 256 == 0) {
fmha_type_ = FMHAType::PAGED_OPEN_SOURCE;
}
} else if (!params.diff_qkv_len) {
if (use_trtv2_fmha && cufmha_runner_->trtV2FmhaSupport()) {
fmha_type_ = FMHAType::TRT_V2;
} else if (use_open_source_fmha && cufmha_runner_->openSourceFmhaSupport()) {
fmha_type_ = FMHAType::OPEN_SOURCE;
} else if (use_trtv1_fmha && cufmha_runner_->trtV1FmhaSupport() && mla_ops_type == MlaOpsType::MHA) {
fmha_type_ = FMHAType::TRT_V1;
}
} else {
fmha_type_ = FMHAType::NONE;
}
output.need_mask = (fmha_type_ == FMHAType::NONE);
}
output.decode_flash_infer_attn_params = FlashInferAttnParams::prepare(
this,
params.configs,
nullptr,
params.sequence_lengths->slice(0, params.decoder_batch_size, false),
params.input_lengths->slice(0, params.decoder_batch_size, false),
params.kv_cache_block_id ? params.kv_cache_block_id->slice(0, params.decoder_batch_size, false) : nullptr,
params.kv_cache_block_id_device ? params.kv_cache_block_id_device->slice(0, params.decoder_batch_size, false) : nullptr,
params.attn_dtype);
output.prefill_flash_infer_attn_params = FlashInferAttnParams::prepare(
this,
params.configs,
params.prefix_lengths,
nullptr,
params.input_lengths->slice(params.decoder_batch_size, params.context_batch_size, false),
params.kv_cache_block_id ? params.kv_cache_block_id->slice(params.decoder_batch_size, params.context_batch_size, false) : nullptr,
params.kv_cache_block_id_device ? params.kv_cache_block_id_device->slice(params.decoder_batch_size, params.context_batch_size, false) : nullptr,
params.attn_dtype);
return output;
}
bool CudaDevice::useGroupGemm() const {
return use_group_gemm;
}
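// All DeviceStream values currently map to the main compute stream.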
cudaStream_t CudaDevice::getStream(DeviceStream stream) {
switch (stream) {
default:
return stream_;
}
}
void CudaDevice::bufMemset(Buffer& buf, int val, DeviceStream stream) {
if (buf.where() == MemoryType::MEMORY_CPU || buf.where() == MemoryType::MEMORY_CPU_PINNED) {
std::memset(buf.data(), val, buf.sizeBytes());
} else {
cudaStream_t cur_stream = getStream(stream);
check_cuda_error(cudaMemsetAsync(buf.data(), val, buf.sizeBytes(), cur_stream));
}
}
void CudaDevice::checkUseOpenSourceFMHA() {
if (!(is_sm8x() || is_sm90())) {
RTP_LLM_LOG_WARNING("opensource FMHA is disabled for sm %d", get_sm());
return;
}
char* fmha_env = std::getenv("ENABLE_OPENSOURCE_FMHA");
if (fmha_env && std::string(fmha_env) == "OFF") {
RTP_LLM_LOG_WARNING("opensource FMHA is disabled for by env");
return;
}
RTP_LLM_LOG_INFO("use opensource fmha");
use_open_source_fmha = true;
char* paged_fmha_env = std::getenv("ENABLE_PAGED_OPEN_SOURCE_FMHA");
if (paged_fmha_env && std::string(paged_fmha_env) == "OFF") {
RTP_LLM_LOG_INFO("Paged open source FMHA is disabled for by ENABLE_PAGED_TRT_FMHA=OFF env");
return;
}
if (init_params_.tokens_per_block % 256 != 0) {
RTP_LLM_LOG_INFO("Paged open source FMHA is disabled since tokens_per_block % 256 != 0");
return;
}
RTP_LLM_LOG_INFO("use opensource fmha paged");
use_open_source_fmha_paged = true;
}
void CudaDevice::checkUseTrtV1FMHA() {
if (!CompileConfig::use_old_trt_fmha) {
return;
}
char* fmha_env = std::getenv("ENABLE_TRTV1_FMHA");
if (fmha_env && std::string(fmha_env) == "OFF") {
RTP_LLM_LOG_WARNING("TRTV1 FMHA is not enbaled");
return;
}
RTP_LLM_LOG_INFO("use TRTV1 fmha");
use_trtv1_fmha = true;
}
void CudaDevice::checkUseTrtV2FMHA() {
if (!(is_sm8x() || is_sm90() || is_sm70())) {
RTP_LLM_LOG_WARNING("TRT FMHA is disabled for sm %d", get_sm());
return;
}
char* fmha_env = std::getenv("ENABLE_TRT_FMHA");
if (fmha_env && std::string(fmha_env) == "OFF") {
RTP_LLM_LOG_WARNING("TRT FMHA is disabled for by env");
return;
}
if (CompileConfig::cudart_version < 12000) {
RTP_LLM_LOG_WARNING("cudart version %d not support need >= 12000!", CompileConfig::cudart_version);
return;
}
RTP_LLM_LOG_INFO("use TRTV2 fmha");
use_trtv2_fmha = true;
if (!(is_sm8x() || is_sm90())) {
RTP_LLM_LOG_INFO("Paged TRT FMHA is disabled for sm %d", get_sm());
return;
}
char* paged_fmha_env = std::getenv("ENABLE_PAGED_TRT_FMHA");
if (paged_fmha_env && std::string(paged_fmha_env) == "OFF") {
RTP_LLM_LOG_INFO("Paged TRT FMHA is disabled for by ENABLE_PAGED_TRT_FMHA=OFF env");
return;
}
RTP_LLM_LOG_INFO("use TRTV2 fmha paged");
use_trtv2_fmha_paged = true;
}
void CudaDevice::checkSupportTrtFp8FMHA() {
int sm = get_sm();
if (sm < 90 || !use_trtv2_fmha) {
RTP_LLM_LOG_WARNING("sm is [%d], use_trtv2_fmha:[%d] not support fp8 fmha", sm, use_trtv2_fmha);
return;
}
RTP_LLM_LOG_INFO("support fp8 fmha");
support_trt_fp8_fmha = true;
}
bool CudaDevice::useFp8Fmha(const DevicePrepParams& params) const {
#ifdef ENABLE_FP8
if (support_trt_fp8_fmha && params.configs.kv_cache_dtype == KvCacheDataType::FP8) {
return true;
}
#endif
return false;
}
void CudaDevice::checkUseFlashinferSampleKernel() {
char* flashinfer_sample_env = std::getenv("ENABLE_FLASHINFER_SAMPLE_KERNEL");
if (flashinfer_sample_env && std::string(flashinfer_sample_env) == "OFF") {
RTP_LLM_LOG_WARNING("Flashinfer sample is disabled for by env");
return;
}
RTP_LLM_LOG_INFO("use Flashinfer sample kernel");
use_flashinfer_sample_kernel = true;
}
void CudaDevice::checkUseMultiBlockMode() {
if constexpr (CompileConfig::cudart_version < 11070) {
RTP_LLM_LOG_WARNING("MMHA multi_block_mode for cudart_version %d is disabled", CompileConfig::cudart_version);
use_multi_block_mode = false;
return;
}
char* multi_block_mode_env = std::getenv("ENABLE_MULTI_BLOCK_MODE");
if (multi_block_mode_env != nullptr && std::string(multi_block_mode_env) == "OFF") {
RTP_LLM_LOG_WARNING("MMHA multi_block_mode is disabled");
use_multi_block_mode = false;
return;
}
// every architecture that passes the checks above enables multi_block_mode
RTP_LLM_LOG_INFO("MMHA multi_block_mode is enabled");
use_multi_block_mode = true;
}
void CudaDevice::checkUseGroupGemm() {
use_group_gemm = is_sm8x();
}
MemoryStatus CudaDevice::getDeviceMemoryStatus() {
MemoryStatus status;
size_t total_bytes;
auto error = cudaMemGetInfo(&status.free_bytes, &total_bytes);
RTP_LLM_CHECK(error == cudaSuccess);
status.used_bytes = total_bytes - status.free_bytes;
return status;
}
nvinfer1::DataType nvinfer1DtypeConvert(rtp_llm::DataType dtype) {
switch (dtype) {
case rtp_llm::DataType::TYPE_FP16 : return nvinfer1::DataType::kHALF;
case rtp_llm::DataType::TYPE_BF16 : return nvinfer1::DataType::kBF16;
case rtp_llm::DataType::TYPE_FP32 : return nvinfer1::DataType::kFLOAT;
case rtp_llm::DataType::TYPE_QINT8 : return nvinfer1::DataType::kINT8;
case rtp_llm::DataType::TYPE_QINT4X2 : return nvinfer1::DataType::kINT4;
case rtp_llm::DataType::TYPE_QFP8_E4M3 : return nvinfer1::DataType::kFP8;
default: throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
}
DeviceEventPtr CudaDevice::createEvent() {
return std::make_unique<CudaEvent>(stream_);
}
CudaEvent::CudaEvent(cudaStream_t stream): stream_(stream) {
check_cuda_error(cudaEventCreate(&event_));
check_cuda_error(cudaEventRecord(event_, stream));
}
CudaEvent::~CudaEvent() {
check_cuda_error(cudaEventDestroy(event_));
}
void CudaEvent::synchronize() const {
check_cuda_error(cudaEventSynchronize(event_));
check_cuda_error(cudaStreamSynchronize(stream_));
sync_check_cuda_error();
cudaDeviceSynchronize();
}
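// CudaCommHook records an event on the communication stream at construction;
// hook_sync() later makes the main stream wait on that event, fencing the comm
// work without a full device sync.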
CudaCommHook::CudaCommHook(cudaStream_t main_stream, cudaStream_t comm_stream)
: main_stream_(main_stream), comm_stream_(comm_stream) {
check_cuda_error(cudaEventCreate(&hook_event_));
check_cuda_error(cudaEventRecord(hook_event_, comm_stream_));
}
CudaCommHook::~CudaCommHook() {
check_cuda_error(cudaEventDestroy(hook_event_));
}
void CudaCommHook::hook_sync() const {
check_cuda_error(cudaStreamWaitEvent(main_stream_, hook_event_, 0));
}
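// One-time allocation of the custom reduce-scatter / all-gather comm buffers, sized
// with ~10% headroom over max_batch_seq_len (m = max_batch_seq_len * 1.1).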
void CudaDevice::prepareCommBuffer(const PrepareCommBufferParams& params) {
if (attn_rs_comm_buffer_) {
return;
}
RTP_LLM_LOG_INFO("[PrepareCommBuffer] max_batch_seq_len %d, attn_rs_hidden %d, ffn_rs_hidden %d, attn_ag_hidden %d, ffn_ag_hidden %d, rs_output_type %d, ag_input_type %d, enable_per_token_scale %d, enable_ffn_tp %d",
params.max_batch_seq_len, params.attn_rs_hidden, params.ffn_rs_hidden, params.attn_ag_hidden, params.ffn_ag_hidden, params.rs_output_type, params.ag_input_type, params.enable_per_token_scale, params.enable_ffn_tp);
size_t m = params.max_batch_seq_len * 1.1;
std::vector<size_t> tp_ranks = fcNcclGatherRanks(tp_nccl_param_, stream_);
RTP_LLM_LOG_INFO("[PrepareCommBuffer] prepare attn_rs_comm_buffer_");
std::vector<size_t> attn_rs_buffer_shape = {m, params.attn_rs_hidden};
attn_rs_comm_buffer_ = initCommBuffer(attn_rs_buffer_shape, params.rs_output_type, tp_nccl_param_, tp_ranks, false, stream_);
RTP_LLM_LOG_INFO("[PrepareCommBuffer] prepare attn_ag_comm_buffer_");
std::vector<size_t> attn_ag_buffer_shape = {m, params.attn_ag_hidden};
attn_ag_comm_buffer_ = initCommBuffer(attn_ag_buffer_shape, params.ag_input_type, tp_nccl_param_, tp_ranks, true, stream_);
if (params.enable_per_token_scale) {
RTP_LLM_LOG_INFO("[PrepareCommBuffer] prepare attn_ag_scale_comm_buffer_");
std::vector<size_t> attn_ag_scale_shape = {m, 1};
attn_ag_scale_comm_buffer_ = initCommBuffer(attn_ag_scale_shape, DataType::TYPE_FP32, tp_nccl_param_, tp_ranks, true, stream_);
}
if (params.enable_ffn_tp) {
std::vector<size_t> ffn_tp_ranks = fcNcclGatherRanks(ffn_tp_nccl_param_, stream_);
RTP_LLM_LOG_INFO("[PrepareCommBuffer] prepare ffn_rs_comm_buffer_");
std::vector<size_t> ffn_rs_buffer_shape = {m, params.ffn_rs_hidden};
ffn_rs_comm_buffer_ = initCommBuffer(ffn_rs_buffer_shape, params.rs_output_type, ffn_tp_nccl_param_, ffn_tp_ranks, false, stream_);
RTP_LLM_LOG_INFO("[PrepareCommBuffer] prepare ffn_ag_comm_buffer_");
std::vector<size_t> ffn_ag_buffer_shape = {m, params.ffn_ag_hidden};
ffn_ag_comm_buffer_ = initCommBuffer(ffn_ag_buffer_shape, params.ag_input_type, ffn_tp_nccl_param_, ffn_tp_ranks, true, stream_);
RTP_LLM_LOG_INFO("[PrepareCommBuffer] prepare ffn_ag_scale_comm_buffer_");
if (params.enable_per_token_scale) {
std::vector<size_t> ffn_ag_scale_shape = {m, 1};
ffn_ag_scale_comm_buffer_ = initCommBuffer(ffn_ag_scale_shape, DataType::TYPE_FP32, ffn_tp_nccl_param_, ffn_tp_ranks, true, stream_);
}
}
}
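// EPLB: accumulate this layer's per-GPU expert load counters from the routed expert ids,
// so the balancer can observe actual traffic per physical expert.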
void CudaDevice::updateExpertGpuLoads(const MoeConfigs& moe_conf,
const OptionalExpertStats& expert_stats,
BufferPtr expert_ids) {
if (expert_stats.has_value() && expert_ids->size()) {
auto& stats = expert_stats.value();
launch_update_gpu_loads(expert_ids->data<int>(),
stats.getLayerGpuLoads(),
expert_ids->size(),
stats.phy_exp_num,
moe_conf.ep_rank,
moe_conf.ep_size,
stream_);
}
}
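// EPLB: remap logical expert ids to physical ones via the log2phy table while updating
// the logical-expert stats; only EplbBalanceMethod::EQUAL is currently implemented.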
void CudaDevice::balanceExperts(BufferPtr expert_ids,
const OptionalExpertStats& expert_stats,
const MoeConfigs& moe_conf,
const FfnLayerWeights& weights) {
if (expert_stats.has_value() && weights.log2phy) {
const auto& expert_stats_v = expert_stats.value();
int* log2phy = weights.log2phy->data<int>();
int* logic_expert_cnt = weights.logic_expert_cnt->data<int>();
switch (moe_conf.balance_method) {
case EplbBalanceMethod::EQUAL:
launch_equal_expert_balance(expert_ids->data<int>(),
expert_stats_v.getLayerLogStats(),
log2phy,
logic_expert_cnt,
expert_stats_v.log_exp_num,
expert_stats_v.phy_exp_num,
expert_ids->size(),
moe_conf.ep_rank,
stream_);
break;
default:
throw std::runtime_error("Unsupported balance method");
}
sync_check_cuda_error();
}
}
} // namespace rtp_llm