in maga_transformer/cpp/devices/base_impl/FfnLayer.cc [15:200]
FfnLayerOutput DeviceBase::ffnLayer(const FfnLayerParams& params) {
RUNTIME_ASSERT_OP_ARG(!params.residual, "default FFN implementation does not support residual!");
BufferPtr output;
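// Two paths: a MoE FFN when a gating weight is present, otherwise a dense (gated or non-gated) FFN.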
if (params.weights.moe_gating_weight) {
RUNTIME_ASSERT_OP_ARG(params.configs.moe_configs, "moe configs not set");
auto moe_output = moeFfnLayer(params);
output = moe_output.hidden_states;
auto shared_expert_output = moeSharedExpert(params).hidden_states;
// For deep ep ll (low latency), the gather of the combine output is deferred until after the shared expert has run.
if (moe_output.moe_combine_output) {
moe_output.comm_barrier_hook->hook_sync();
moe_output = gatherCombineOutput(moe_output.moe_combine_output.value());
output = moe_output.hidden_states;
}
printBufferData(*output, "moe_out_after_barrier");
if (shared_expert_output) {
// Reuse layernorm's bias-add path to accumulate the shared expert output into the MoE output in place.
layernorm({
output, nullptr, nullopt, mayGetRef(shared_expert_output)
}).output;
}
} else {
BufferPtr up_output;
bool fuse_gate_up_weight = (params.weights.gate_up_weight != nullptr);
if (isGatedActivation(params.configs.activation_type)) {
BufferPtr ffn_input_ptr = nullptr;
RTP_LLM_LOG_DEBUG("enable_sp %d ffn_tp_size %d", params.enable_sp, init_params_.ffn_tp_size);
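// Sequence parallelism: each FFN-TP rank holds a slice of the tokens, so the input is
// all-gathered into a padded buffer of input_tokens * ffn_tp_size rows before the up projection.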
if (params.enable_sp && init_params_.ffn_tp_size > 1) {
BufferPtr ag_recv_buffer = nullptr;
size_t pad_token_num = params.input.shape()[0] * init_params_.ffn_tp_size;
if (params.qscheme == NoQuantize) {
ffn_input_ptr = params.input.slice(0, params.input.shape()[0]);
ag_recv_buffer = allocateBuffer({ffn_input_ptr->type(), {pad_token_num, ffn_input_ptr->shape()[1]}}, {"ag_recv_buffer"});
} else if (params.qscheme == Qint8PerToken){
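// Per-token int8 input: gather both the quantized kernel data and the per-token fp32 scales.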
ffn_input_ptr = reinterpret_cast<const QBuffer&>(params.input).qslice(0, params.input.shape()[0]);
BufferPtr kernel = allocateBuffer({ffn_input_ptr->type(), {pad_token_num, ffn_input_ptr->shape()[1]}}, {"ag_recv_buffer"});
BufferPtr scales = allocateBuffer({DataType::TYPE_FP32,
{pad_token_num},
AllocationType::DEVICE},
{"ag_recv_buffer_scale"});
ag_recv_buffer = BufferPtr(new QBuffer(
    std::move(kernel),
    std::move(scales),
    BufferPtr(new Buffer(MemoryType::MEMORY_GPU, DataType::TYPE_INVALID, {0}, nullptr))));
} else if (params.qscheme == Qfp8PerTensor){
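// Per-tensor fp8 input: a single scale is shared by all tokens, so the input's scale buffer is reused as-is.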
ffn_input_ptr = reinterpret_cast<const QBuffer&>(params.input).qslicePerTensor(0, params.input.shape()[0]);
BufferPtr kernel = allocateBuffer({ffn_input_ptr->type(), {pad_token_num, ffn_input_ptr->shape()[1]}}, {"ag_recv_buffer"});
BufferPtr scales = reinterpret_cast<const QBuffer&>(params.input).scalesPtr();
ag_recv_buffer = BufferPtr(new QBuffer(
    std::move(kernel),
    std::move(scales),
    BufferPtr(new Buffer(MemoryType::MEMORY_GPU, DataType::TYPE_INVALID, {0}, nullptr))));
} else {
throw OpException({OpErrorType::ERROR_UNIMPLEMENTED, "allGatherloraLinear qscheme type not supported"});
}
printBufferData(*ffn_input_ptr, "ffn_ag_input");
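// Run the up (or fused gate_up) projection through allGatherloraLinear, which also fills
// ag_recv_buffer with the gathered FFN input used by the separate gate projection below.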
GemmParams up_gemm_params = fuse_gate_up_weight
    ? GemmParams(*ag_recv_buffer, *(params.weights.gate_up_weight->kernel))
    : GemmParams(*ag_recv_buffer, *(params.weights.up_weight->kernel));
AllGatherLoraLinearOutput all_gather_output = allGatherloraLinear(
    {LoraLinearParams(up_gemm_params, params.lora_input.up_lora_input),
     ffn_input_ptr,
     ag_recv_buffer,
     params.qscheme,
     params.output->type(),
     ParallelMode::FFN_TP});
// syncAndCheck();
ffn_input_ptr = all_gather_output.all_gather_recv_buffer;
up_output = all_gather_output.output;
printBufferData(*ffn_input_ptr, "ffn_ag_inter_output");
printBufferData(*up_output, "ffn_ag_final_output");
} else {
printBufferData(params.input, "input");
GemmParams up_gemm_params = fuse_gate_up_weight
    ? GemmParams(params.input, *(params.weights.gate_up_weight->kernel))
    : GemmParams(params.input, *(params.weights.up_weight->kernel));
up_output = loraLinear(LoraLinearParams(up_gemm_params, params.lora_input.up_lora_input)).output;
printBufferData(*up_output, "ffn_up");
}
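// With separate gate/up weights, run the gate projection and apply the gated activation in place;
// with a fused gate_up weight, the combined output is split and activated below.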
if (!fuse_gate_up_weight) {
BufferPtr gate_output = nullptr;
if (params.enable_sp && init_params_.ffn_tp_size > 1) {
GemmParams gate_gemm_params = GemmParams(*ffn_input_ptr, *(params.weights.gate_weight->kernel));
gate_output = loraLinear(LoraLinearParams(gate_gemm_params, params.lora_input.gate_lora_input)).output;
} else {
GemmParams gate_gemm_params = GemmParams(params.input, *(params.weights.gate_weight->kernel));
gate_output = loraLinear(LoraLinearParams(gate_gemm_params, params.lora_input.gate_lora_input)).output;
}
printBufferData(*gate_output, "ffn_gate");
activation({params.configs.activation_type,
up_output,
mayGetRef(params.weights.up_weight->bias),
*gate_output,
std::nullopt,
mayGetRef(params.weights.act_scale)});
} else {
printBufferData(*up_output, "ffn_up_gate");
bool is_cuda = init_params_.device_type == DeviceType::Cuda;
if (is_cuda && (params.configs.activation_type == ActivationType::Swiglu ||
params.configs.activation_type == ActivationType::Silu ||
params.configs.activation_type == ActivationType::Gelu)) {
auto act_output = allocateBuffer({up_output->type(), {up_output->shape()[0], up_output->shape()[1] / 2}, AllocationType::DEVICE});
activation({params.configs.activation_type,
up_output,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
act_output,
true});
up_output = std::move(act_output);
} else {
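// Fallback: split the fused gate_up output into its two halves with torch and run the activation explicitly.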
torch::Tensor gate_up_output_torch_tensor = Buffer2torchTensor(up_output, false);
std::vector<torch::Tensor> split_tensors = torch::chunk(gate_up_output_torch_tensor, 2, -1);
torch::Tensor first_half = split_tensors[0].clone();
torch::Tensor second_half = split_tensors[1].clone();
up_output = torchTensor2Buffer(second_half);
BufferPtr gate_output = torchTensor2Buffer(first_half);
// up_output now holds only the up half after the split, so the activation output keeps its shape (no further halving).
auto act_output = allocateBuffer({up_output->type(), {up_output->shape()[0], up_output->shape()[1]}, AllocationType::DEVICE});
activation({params.configs.activation_type,
up_output,
std::nullopt,
*gate_output,
std::nullopt,
mayGetRef(params.weights.act_scale),
act_output});
up_output = std::move(act_output);
}
}
} else {
RTP_LLM_CHECK_WITH_INFO(!params.enable_sp, "enable_sp is not supported for non-gated activation");
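// Non-gated activation: fuse the up projection and the activation into a single loraLinearWithActivation call.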
auto up_gemm_params = GemmParams(params.input, *(params.weights.up_weight->kernel));
auto lora_linear_params = LoraLinearParams(up_gemm_params, params.lora_input.up_lora_input);
auto activation_params = ActivationParams(params.configs.activation_type,
nullptr,
mayGetRef(params.weights.up_weight->bias),
std::nullopt,
std::nullopt,
mayGetRef(params.weights.act_scale));
up_output = loraLinearWithActivation({lora_linear_params, activation_params});
}
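// Quantize the activation output for the down projection; NoQuantize and Qfp8PerTokenBlock schemes skip this step.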
if (params.qscheme != QScheme::NoQuantize && params.qscheme != QScheme::Qfp8PerTokenBlock) {
DataType quant_out_data_type =
    (params.qscheme == QScheme::Qfp8PerTensor || params.qscheme == QScheme::Qfp8PerTokenBlock)
        ? DataType::TYPE_FP8_E4M3
        : DataType::TYPE_INT8;
auto quant_params = QuantizeParams(
    *up_output,
    quant_out_data_type,
    1,
    params.qscheme,
    params.weights.smoother_weight
        ? (OptionalConstBufferRef)*(params.weights.smoother_weight->kernel)
        : std::nullopt,
    std::nullopt,
    params.weights.intermediate_weight2_static_scale_weight
        ? (OptionalConstBufferRef)*(params.weights.intermediate_weight2_static_scale_weight->kernel)
        : std::nullopt,
    params.weights.intermediate_weight2_static_scale_reciprocal_weight
        ? (OptionalConstBufferRef)*(params.weights.intermediate_weight2_static_scale_reciprocal_weight->kernel)
        : std::nullopt);
up_output = quantize(quant_params);
}
printBufferData(*up_output, "ffn_act");
if (params.enable_sp && init_params_.ffn_tp_size > 1) {
BufferPtr gemm_output = allocateBuffer(
    {params.output->type(), {up_output->shape()[0], params.weights.down_weight->kernel->shape()[1]}},
    {"ffn_rs_input"});
GemmParams down_gemm_params =
    GemmParams(*up_output, *(params.weights.down_weight->kernel), nullopt, gemm_output);
ReduceScatterLoraLinearOutput reduce_scatter_output = loraLinearReduceScatter(
    {LoraLinearParams(down_gemm_params, params.lora_input.down_lora_input),
     params.output,
     params.qscheme,
     params.output->type(),
     ParallelMode::FFN_TP});
// syncAndCheck();
gemm_output = reduce_scatter_output.reduce_scatter_recv_buffer;
output = reduce_scatter_output.output;
printBufferData(*gemm_output, "ffn_rs_inter_output");
printBufferData(*output, "ffn_rs_final_output");
} else {
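// Default path: the down projection writes directly into the caller-provided params.output.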
auto down_gemm_params = GemmParams(*(up_output), *(params.weights.down_weight->kernel), nullopt, params.output);
output = loraLinear(LoraLinearParams(down_gemm_params, params.lora_input.down_lora_input)).output;
}
}
printBufferData(*output, "ffn_out");
return FfnLayerOutput({std::move(output)});
}