in maga_transformer/cpp/devices/cuda_impl/CudaFP8Moe.cc [17:243]
FfnLayerOutput CudaDevice::moeFfnFp8(const FfnLayerParams& params, const MoeGateSelectOutput& gate_outputs) {
#ifdef ENABLE_FP8
using T = __nv_bfloat16;
RUNTIME_ASSERT_OP_ARG(params.configs.moe_configs, "moe configs not set");
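// FP8 (block-quantized) MoE FFN: quantize the input, permute tokens by expert, run the
// grouped FP8 GEMMs for FC1 and FC2 with the activation in between, then un-permute and
// combine the expert outputs.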
BufferPtr hidden_fp8;
BufferPtr hidden_fp8_scales;
BufferPtr quantize_buffer;
printBufferData(params.input, "moeFfnFp8_input");
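// Reuse the FP8 payload and scales when the input is already a QBuffer; otherwise
// quantize the activations to FP8 E4M3 here.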
if (params.input.isQBuffer()) {
hidden_fp8 = reinterpret_cast<const QBuffer&>(params.input).kernelPtr();
hidden_fp8_scales = reinterpret_cast<const QBuffer&>(params.input).scalesPtr();
} else {
quantize_buffer = quantize({params.input, DataType::TYPE_QFP8_E4M3, 1, params.qscheme});
hidden_fp8 = std::dynamic_pointer_cast<QBuffer>(quantize_buffer)->kernelPtr();
hidden_fp8_scales = std::dynamic_pointer_cast<QBuffer>(quantize_buffer)->scalesPtr();
}
const auto& moe_conf = params.configs.moe_configs.value();
bool is_gated_activation = isGatedActivation(params.configs.activation_type);
const auto& weights = params.weights;
const auto token_num = hidden_fp8->shape()[0];
const auto hidden_size = hidden_fp8->shape()[1];
BufferPtr output = nullptr;
if (params.output) {
output = params.output;
} else {
output = allocateBuffer({DataType::TYPE_BF16, {token_num, hidden_size}});
}
if (token_num == 0) {
return {output};
}
const auto num_experts = moe_conf.expert_num + moe_conf.extra_expert_num;
const auto top_k = moe_conf.top_k;
const auto moe_inter_size = moe_conf.moe_inter_padding_size;
const size_t num_experts_per_node = num_experts / moe_conf.ep_size;
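// Routing bookkeeping: src_row_to_dst maps every (expert, token) slot to its row in the
// permuted layout; it is pre-filled with -1 so slots that are never written (e.g. experts
// not owned by this rank) stay marked as invalid.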
const auto src_row_to_dst = allocateBuffer({DataType::TYPE_INT32, {top_k, token_num}}, {"moe_src_to_dst"});
cudaMemsetAsync(src_row_to_dst->data(), -1, src_row_to_dst->sizeBytes(), stream_);
const auto source_rows = allocateBuffer({DataType::TYPE_INT32, {token_num, top_k}}, {"source_rows"});
const auto permuted_experts = allocateBuffer({DataType::TYPE_INT32, {top_k, token_num}}, {"permuted_experts"});
const auto permuted_rows = allocateBuffer({DataType::TYPE_INT32, {token_num, top_k}}, {"permuted_rows"});
const auto expert_first_token_offset =
allocateBuffer({DataType::TYPE_INT64, {num_experts_per_node + 1}}, {"expert_first_token_offset"});
trt::CubKeyValueSorter sorter(num_experts);
const size_t sorter_size =
pad_to_multiple_of_128(trt::CubKeyValueSorter::getWorkspaceSize(token_num * top_k, num_experts));
const auto sorter_ws = allocateBuffer({DataType::TYPE_BYTES, {sorter_size}}, {"sorter_ws"});
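// Experts owned by this expert-parallel rank: [start_expert, end_expert).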
int const start_expert = num_experts_per_node * moe_conf.ep_rank;
int const end_expert = start_expert + num_experts_per_node;
auto expert_for_source_row = gate_outputs.expert_ids;
auto expert_scales = gate_outputs.expert_scales;
printBufferData(*expert_for_source_row, "expert_for_source_row");
// This logic mirrors DeepEPDispatch and could potentially be fused with it.
if (expert_for_source_row->type() != DataType::TYPE_INT32) {
auto expert_idx_tensor = Buffer2torchTensor(expert_for_source_row, false);
expert_for_source_row = torchTensor2BufferWithDstType(expert_idx_tensor, torch::kInt32);
tensorrt_llm::kernels::genSourceRowRevert(
expert_for_source_row->data<int>(),
expert_for_source_row->shape()[0],
expert_for_source_row->shape()[1],
num_experts_per_node * moe_conf.ep_rank,
stream_);
}
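// Generate source-row ids, then sort the (token, expert) pairs by expert id so rows routed
// to the same expert become contiguous; expert_first_token_offset receives the per-expert
// row offsets from the scan.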
trt::genSourceRow(expert_for_source_row->data<int>(),
source_rows->data<int>(),
token_num,
top_k,
num_experts,
start_expert,
end_expert,
stream_);
printBufferData(*source_rows, "source_rows");
trt::sortAndScanSoftmaxOutput(expert_for_source_row->data<int>(),
source_rows->data<int>(),
permuted_experts->data<int>(),
permuted_rows->data<int>(),
expert_first_token_offset->data<int64_t>(),
token_num,
num_experts,
num_experts_per_node,
top_k,
sorter,
static_cast<void*>(sorter_ws->data()),
stream_);
printBufferData(*permuted_experts, "permuted_experts");
printBufferData(*permuted_rows, "permuted_rows");
printBufferData(*expert_first_token_offset, "expert_first_token_offset");
sync_check_cuda_error();
const auto expert_first_token_offset_host = clone({*expert_first_token_offset, AllocationType::HOST});
int64_t* expert_first_token_offset_host_ptr = expert_first_token_offset_host->data<int64_t>();
size_t total_padding_num = 0;
const auto permuted_src_row_to_dst =
allocateBuffer({DataType::TYPE_INT32, {token_num * top_k}, AllocationType::HOST}, {"permuted_src_row_to_dst"});
int* permuted_src_row_to_dst_ptr = permuted_src_row_to_dst->data<int>();
BufferPtr padding_group_index = allocateBuffer(
{DataType::TYPE_INT32, {pad_to_multiple_of_128(token_num) * num_experts_per_node}, AllocationType::HOST},
{"padding_group_index"});
int* padding_group_index_ptr = padding_group_index->data<int>();
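// Host-side pass over the experts: map each sorted row to its slot in the padded layout,
// pad every expert's row count up to a multiple of 128, and record which expert group
// each padded row belongs to.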
for (int i = 0; i < num_experts_per_node; ++i) {
size_t src_row_offset = expert_first_token_offset_host_ptr[i];
size_t num_row_now = expert_first_token_offset_host_ptr[i + 1] - expert_first_token_offset_host_ptr[i];
for (int j = 0; j < num_row_now; ++j) {
permuted_src_row_to_dst_ptr[src_row_offset + j] = total_padding_num + j;
}
size_t padding_size = pad_to_multiple_of_128(num_row_now);
for (int j = 0; j < padding_size; ++j) {
padding_group_index_ptr[total_padding_num + j] = i;
}
total_padding_num += padding_size;
}
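// Upload the host-built row mapping and expert group index to the device.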
BufferPtr permuted_src_row_to_dst_device = clone({*permuted_src_row_to_dst});
BufferPtr padding_group_index_device = clone({*padding_group_index});
cudaStreamSynchronize(stream_);
int64_t dest_num_rows = expert_first_token_offset_host_ptr[num_experts_per_node];
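// Scatter the FP8 activations and their per-128-column scales into the padded,
// expert-contiguous layout consumed by the grouped GEMM.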
BufferPtr permuted_padding_input =
allocateBuffer({DataType::TYPE_FP8_E4M3, {total_padding_num, hidden_size}}, {"permuted_padding_input"});
BufferPtr permuted_padding_input_fp8_scales = allocateBuffer(
{DataType::TYPE_FP32, {total_padding_num, hidden_size / 128}}, {"permuted_padding_input_fp8_scales"});
BufferPtr permuted_padding_scales =
allocateBuffer({DataType::TYPE_FP32, {total_padding_num}}, {"permuted_padding_scales"});
printBufferData(*hidden_fp8, "moe_hidden_fp8");
printBufferData(*hidden_fp8_scales, "moe_hidden_fp8_scales");
expandInputRowsKernelLauncherContiguous<__nv_fp8_e4m3>(hidden_fp8->data<__nv_fp8_e4m3>(),
hidden_fp8_scales->data<float>(),
permuted_padding_input->data<__nv_fp8_e4m3>(),
permuted_padding_input_fp8_scales->data<float>(),
expert_scales->data<float>(),
permuted_padding_scales->data<float>(),
permuted_rows->data<int>(),
permuted_src_row_to_dst_device->data<int>(),
src_row_to_dst->data<int>(),
token_num,
dest_num_rows,
hidden_size,
top_k,
stream_);
sync_check_cuda_error();
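// FC1 grouped FP8 GEMM (BF16 output); gated activations need twice the intermediate
// width since moe_gate_weight packs the gate and up projections.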
BufferPtr fc1_result;
if (is_gated_activation) {
fc1_result =
allocateBuffer({DataType::TYPE_BF16, {total_padding_num, (size_t)moe_inter_size * 2}}, {"fc1_result"});
} else {
fc1_result = allocateBuffer({DataType::TYPE_BF16, {total_padding_num, (size_t)moe_inter_size}}, {"fc1_result"});
}
BufferPtr permuted_padding_input_fp8(
new QBuffer(std::move(permuted_padding_input),
std::move(permuted_padding_input_fp8_scales),
BufferPtr(new Buffer(MemoryType::MEMORY_GPU, DataType::TYPE_INVALID, {0}, nullptr))));
printBufferData(*permuted_padding_input_fp8, "fc1_input_fp8");
printBufferData(*weights.moe_gate_weight->kernel, "moe_gate_weight");
printBufferData(*padding_group_index_device, "padding_group_index_device");
DeepGemmPlugin::groupedGemmFp8Contiguous(*permuted_padding_input_fp8,
*weights.moe_gate_weight->kernel,
*fc1_result,
padding_group_index_device->view(0, total_padding_num),
stream_);
printBufferData(*fc1_result, "fc1_result");
sync_check_cuda_error();
using GemmOutputType = __nv_bfloat16;
using ScaleBiasType = __nv_bfloat16;
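// Apply the activation (with optional bias), then re-quantize to FP8 E4M3 with
// per-128-column scales for the second GEMM.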
BufferPtr fc1_activation =
allocateBuffer({DataType::TYPE_FP8_E4M3, {total_padding_num, (size_t)moe_inter_size}}, {"fc1_activation"});
BufferPtr fc1_activation_fp8_scales = allocateBuffer(
{DataType::TYPE_FP32, {total_padding_num, (size_t)moe_inter_size / 128}}, {"fc1_activation_fp8_scales"});
doActivationContiguous<GemmOutputType, ScaleBiasType>(
fc1_activation->data<__nv_fp8_e4m3>(),
fc1_activation_fp8_scales->data<float>(),
static_cast<GemmOutputType const*>(fc1_result->data<T>()),
(ScaleBiasType*)OPTIONAL_BUFFER_GET_DATA_OR_NULLPTR(weights.moe_gate_weight->bias),
true,
permuted_src_row_to_dst_device->data<int>(),
dest_num_rows,
moe_inter_size,
params.configs.activation_type,
permuted_experts->data<int>(),
stream_);
fc1_result.reset();
sync_check_cuda_error();
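// FC2 (down projection): grouped FP8 GEMM back to hidden_size, still in the padded
// expert-contiguous layout.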
const auto fc2_result = allocateBuffer({DataType::TYPE_BF16, {total_padding_num, hidden_size}}, {"fc2_result"});
BufferPtr fc1_activation_fp8(
new QBuffer(std::move(fc1_activation),
std::move(fc1_activation_fp8_scales),
BufferPtr(new Buffer(MemoryType::MEMORY_GPU, DataType::TYPE_INVALID, {0}, nullptr))));
printBufferData(*fc1_activation_fp8, "fc1_activation_fp8");
DeepGemmPlugin::groupedGemmFp8Contiguous(*fc1_activation_fp8,
*weights.moe_down_weight->kernel,
*fc2_result,
padding_group_index_device->view(0, total_padding_num),
stream_);
printBufferData(*fc2_result, "fc2_result");
sync_check_cuda_error();
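// Un-permute and combine: scale each token's expert outputs by the routing weights and
// accumulate them into the final output.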
using OutputType = __nv_bfloat16;
trt::MOEParallelismConfig parallel_config(1, 0, moe_conf.ep_size, moe_conf.ep_rank);
trt::finalizeMoeRoutingKernelLauncher<OutputType, OutputType, GemmOutputType, ScaleBiasType>(
fc2_result->data<GemmOutputType>(),
output->data<OutputType>(),
(ScaleBiasType*)OPTIONAL_BUFFER_GET_DATA_OR_NULLPTR(weights.moe_down_weight->bias),
expert_scales->data<float>(),
src_row_to_dst->data<int>(),
expert_for_source_row->data<int>(),
token_num,
hidden_size,
top_k,
nullptr,
parallel_config,
trt::MOEExpertScaleNormalizationMode::NONE,
stream_);
printBufferData(*output, "moe_ffn_out");
sync_check_cuda_error();
return {output};
#else
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
return {nullptr};
#endif
}