in maga_transformer/cpp/devices/rocm_impl/ROCmLayernorm.cc [24:240]
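// Fused layer normalization for the ROCm backend: optionally adds bias/residual
// inputs, applies LayerNorm or RMSNorm, and can emit a quantized (int8/fp8)
// output buffer depending on params.qscheme.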
LayernormOutput ROCmDevice::layernorm(const LayernormParams& params) {
BufferPtr input = params.input;
BufferPtr norm_output = input;
BufferPtr output = params.before_norm_output;
float* scales_ptr = nullptr;
int8_t* quant_output = nullptr;
const auto data_type = input->type();
const auto m = input->shape()[0];
const auto n = input->shape()[1];
auto norm_weight = params.norm_weight;
const auto gamma = norm_weight ? norm_weight->get().gamma->data() : nullptr;
const auto beta = (norm_weight && norm_weight->get().beta) ? norm_weight->get().beta->data() : nullptr;
const auto norm_type = params.norm_type;
const auto eps = params.eps;
const auto& weights = params.norm_weight;
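// Pick the output buffer: reuse the input for in-place runs, allocate a fresh
// buffer otherwise, or build a QBuffer (int8 kernel + per-token fp32 scales)
// when per-token int8 quantization is requested.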
if (!params.is_inplace && params.qscheme == QScheme::NoQuantize) {
norm_output = allocateBufferLike(*params.input);
} else if (params.qscheme == QScheme::Qint8PerToken) {
auto kernel = allocateBuffer({DataType::TYPE_INT8, input->shape(), AllocationType::DEVICE}, {"kernel"});
auto scales = allocateBuffer({DataType::TYPE_FP32, {input->shape()[0]}, AllocationType::DEVICE}, {"scales"}); // one fp32 scale per token (row)
norm_output = BufferPtr(new QBuffer(
std::move(kernel),
std::move(scales),
BufferPtr(new Buffer(MemoryType::MEMORY_GPU, DataType::TYPE_INVALID, {0}, nullptr))));
quant_output = std::dynamic_pointer_cast<QBuffer>(norm_output)->kernel().data<int8_t>();
scales_ptr = std::dynamic_pointer_cast<QBuffer>(norm_output)->scalesData<float>();
}
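// No gamma/beta supplied: only perform the (alpha-scaled) bias/residual add,
// without any normalization.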
if (!weights.has_value()) {
if (params.alpha != 0 || (norm_type == NormType::alphanorm)) {
const auto alpha = params.alpha;
DISPATCH_CUDA_FUNCTION_DATA_TYPE(data_type,
invokeAlphaAddBiasResidual,
norm_output->data(),
input->data(),
params.residual1 ? params.residual1.value().get().data() : nullptr,
params.bias ? params.bias.value().get().data() : nullptr,
params.alpha,
m,
n,
stream_);
sync_check_cuda_error();
return LayernormOutput({std::move(norm_output), nullptr});
} else if (params.bias.has_value() || params.residual1.has_value() || params.residual2.has_value()) {
DISPATCH_CUDA_FUNCTION_DATA_TYPE(data_type,
invokeAddBiasResidual,
output->data(),
input->data(),
params.residual1 ? params.residual1.value().get().data() : nullptr,
params.residual2 ? params.residual2.value().get().data() : nullptr,
params.bias.has_value() ? params.bias.value().get().data() : nullptr,
nullptr, // scale_inter
nullptr, // scale_out
m,
n,
stream_);
sync_check_cuda_error();
return LayernormOutput({std::move(norm_output), nullptr});
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
}
if (!(norm_type == NormType::layernorm || norm_type == NormType::rmsnorm)) {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
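// Quantized output element type for the dispatch below: fp8 for the per-tensor
// fp8 scheme, int8 otherwise.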
auto quant_data_type = (params.qscheme == QScheme::Qfp8PerTensor) ? DataType::TYPE_FP8_E4M3 : DataType::TYPE_INT8;
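// Residual and/or bias present: fuse the add with the normalization. The
// pre-norm sum is written to params.before_norm_output when provided,
// otherwise back into the input buffer.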
if (params.residual1.has_value() || params.bias.has_value()) {
if (params.norm_type == NormType::layernorm) {
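// Fast path: a fused 2-D layernorm kernel handles fp16 inputs with m > 32 and
// n <= 768; it takes no bias, so it is only used when no bias is present.
// Everything else falls back to the general fused kernel.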
if ((!params.bias.has_value()) && (data_type == DataType::TYPE_FP16 && m > 32 && n <= 768)) {
layernorm2d_fwd_traits traits{"fp16", "fp16", "fp32", "fp32", 0, 1, 0};
layernorm2d_fwd_args args{input->data(),
params.residual1.value().get().data(),
nullptr,
nullptr,
gamma,
beta,
norm_output->data(),
(params.before_norm_output == nullptr) ? input->data() :
params.before_norm_output->data(),
nullptr,
nullptr, // p_mean, unsupported yet
nullptr, // p_invStd, unsupported yet
static_cast<float>(eps),
static_cast<int32_t>(m),
static_cast<int32_t>(n),
static_cast<int32_t>(n), // x row stride
static_cast<int32_t>(n), // x residual row stride
static_cast<int32_t>(n), // y row stride
static_cast<int32_t>(n)}; // y residual row stride
layernorm2d_fwd(traits, args, {stream_, false, 0, 0, 1});
} else {
DISPATCH_CUDA_FUNCTION_DATA_TYPE(
data_type,
invokeGeneralAddBiasResidualLayerNorm,
// add_bias_output.data(),
(params.before_norm_output == nullptr) ? input->data() : params.before_norm_output->data(),
norm_output->data(),
input->data(),
params.bias ? params.bias.value().get().data() : nullptr,
params.residual1 ? params.residual1.value().get().data() : nullptr,
gamma,
beta,
eps,
m,
n,
stream_,
true, // use_diff_of_squares
nullptr, // scale
scales_ptr, // dynamic_scale
quant_output, // out_quant
params.return_normed_output);
}
sync_check_cuda_error();
return LayernormOutput({norm_output, params.before_norm_output});
} else if (params.norm_type == NormType::rmsnorm) {
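// Fused bias/residual add + RMSNorm, with optional quantized output.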
DISPATCH_CUDA_FUNCTION_COMPUTE_QUANT_TYPES(
data_type,
quant_data_type,
invokeAddBiasResidualRmsNorm,
(params.before_norm_output == nullptr) ? input->data() : params.before_norm_output->data(), // pre-norm sum output
norm_output->data(),
input->data(),
params.bias ? params.bias.value().get().data() : nullptr,
params.residual1 ? params.residual1.value().get().data() : nullptr,
params.residual2 ? params.residual2.value().get().data() : nullptr,
gamma,
beta,
eps,
m,
n,
stream_,
nullptr, // scale
scales_ptr, // dynamic_scale
quant_output // out_quant
);
sync_check_cuda_error();
return LayernormOutput({norm_output, params.before_norm_output});
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
} else {
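// No residual or bias: plain LayerNorm or RMSNorm over the input.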
if (params.norm_type == NormType::layernorm) {
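// Same fp16 fast-path heuristic as above, without residual fusion.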
if (data_type == DataType::TYPE_FP16 && m > 32 && n <= 768) {
layernorm2d_fwd_traits traits{"fp16", "fp16", "fp32", "fp32", 0, 0, 0};
layernorm2d_fwd_args args{input->data(),
nullptr,
nullptr,
nullptr,
gamma,
beta,
norm_output->data(),
nullptr,
nullptr,
nullptr, // p_mean, unsupported yet
nullptr, // p_invStd, unsupported yet
static_cast<float>(eps),
static_cast<int32_t>(m),
static_cast<int32_t>(n),
static_cast<int32_t>(n), // x row stride
static_cast<int32_t>(n), // x residual row stride
static_cast<int32_t>(n), // y row stride
static_cast<int32_t>(n)}; // y residual row stride
layernorm2d_fwd(traits, args, {stream_, false, 0, 0, 1});
} else {
DISPATCH_CUDA_FUNCTION_DATA_TYPE(data_type,
invokeGeneralLayerNorm,
nullptr,
norm_output->data(),
input->data(),
gamma,
beta,
eps,
m,
n,
stream_,
true, // use_diff_of_squares
nullptr, // scale
scales_ptr, // dynamic_scale
quant_output, // out_quant
params.return_normed_output);
}
sync_check_cuda_error();
return LayernormOutput({norm_output, params.before_norm_output});
} else if (params.norm_type == NormType::rmsnorm) {
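// Plain RMSNorm, with optional quantized output.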
DISPATCH_CUDA_FUNCTION_COMPUTE_QUANT_TYPES(data_type,
quant_data_type,
invokeGeneralRmsNorm,
norm_output->data(),
input->data(),
gamma,
beta,
eps,
m,
n,
stream_,
nullptr, // scale
scales_ptr, // dynamic_scale
quant_output // out_quant
);
sync_check_cuda_error();
return LayernormOutput({norm_output, params.before_norm_output});
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
}
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}