in maga_transformer/cpp/rocm/rocmFmhaWrapper.cc [20:289]
uint32_t rocmFmhaWrapper::runCKFmha(void* q,
void* k,
void* v,
void* output,
void* softmax_lse_,
size_t batch_size,
size_t seq_len,
void* seqstart_q,
void* seqstart_k,
void* lse_acc_buf,
void* linear_bias_slopes,
void* biasBuffer) {
// map parms from FT to CK
mode_enum mode = mode_enum::group;
auto data_type = getDataTypeStr(dtype_);
auto batch = static_cast<ck_tile::index_t>(batch_size);
auto nhead = static_cast<ck_tile::index_t>(head_num_);
auto nhead_k = static_cast<ck_tile::index_t>(kv_head_num_);
if (nhead_k < 0)
nhead_k = nhead;
if (nhead % nhead_k != 0) {
std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl;
return false;
}
ck_tile::index_t seqlen_q = seq_len;
ck_tile::index_t seqlen_k = seq_len;
if (seqlen_k < 0)
seqlen_k = seqlen_q;
// auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode_enum::batch,
// batch,
// std::to_string(seqlen_q),
// std::to_string(seqlen_k),
// "-1");
ck_tile::index_t hdim_q = size_per_head_;
ck_tile::index_t hdim_v = size_per_head_;
if (hdim_v < 0)
hdim_v = hdim_q;
// the output of add_fusedQKV_bias_transpose_kernel:
// This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and
// QKV split to 3 split buffer q, k, v and transpose them to [batch_size, head_num, seq_len, size_per_head].
// For q and k, also apply the rotary embedding.
bool i_perm = false; // if true, will be batch * nhead * seqlen * hdim
bool o_perm = false; // if false, will be batch * seqlen * nhead * hdim
float scale_s = 0.f;
if (scale_s == .0f)
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
auto squant = false;
float scale_p = 1.f;
float scale_o = 1.f;
// if (squant) {
// scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
// scale_p = dtype_max / range_p;
// // scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
// scale_o = range_p * range_v / range_o / dtype_max;
// }
bool is_v_rowmajor = true;
auto lse = softmax_lse_ ? true : false;
std::string msk_str;
if (mtype_ == AttentionMaskType::noMask) {
msk_str="0";
}
else if (mtype_ == AttentionMaskType::causalMask)
{
msk_str = "b";
// RTP_LLM_LOG_INFO("Using causal_bottom_right Mask");
}
else
{
RTP_LLM_LOG_ERROR("Mask type not supported");
}
bias_info bias = bias_info::decode(linear_bias_slopes ? "a" : "n");
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); // TODO: we don't need x/y anymore
float p_drop = 0.;
uint64_t drop_seed = 1.;
uint64_t drop_offset = 0.;
bool drop_prefs = false;
if (p_drop < 0.0f || p_drop > 1.0f) {
std::cerr << "The value of p_drop should be 0~1" << std::endl;
return false;
}
bool s_randval = false;
if(p_drop > 0.0f)
{
s_randval = true;
}
int num_splits = 1;
const ck_tile::index_t max_seqlen_q = seq_len; // max of all batch
const ck_tile::index_t max_seqlen_k = seq_len;
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
const ck_tile::index_t shape_seqlen_q = (mode == mode_enum::batch ? seq_len : seq_len);
const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seq_len : seq_len);
const ck_tile::index_t seqlen_knew = 0;
ck_tile::HostTensor<ck_tile::half_t> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q} :
std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
if(lse_acc_buf == nullptr)
{
//printf("[CK] size = %d\n", lse_acc_host.get_element_space_size_in_bytes());
return lse_acc_host.get_element_space_size_in_bytes();
}
//ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v,
data_type,
mode == mode_enum::group,
is_v_rowmajor,
mask.type,
bias.type,
lse,
p_drop > 0.0f,
squant};
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
// QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head]
const ck_tile::index_t stride_q = (i_perm ? hdim_q : (nhead + 2 * nhead_k) * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : (nhead + 2 * nhead_k) * hdim_q);
const ck_tile::index_t stride_knew = (i_perm ? hdim_q : (nhead + 2 * nhead_k) * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : (nhead + 2 * nhead_k) * hdim_v;
else
return i_perm ? shape_seqlen_k : (nhead + 2 * nhead_k) * shape_seqlen_k;
}();
const ck_tile::index_t stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : (nhead + 2 * nhead_k) * hdim_v;
else
return i_perm ? seqlen_knew : (nhead + 2 * nhead_k) * seqlen_knew;
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = hdim_v;
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}();
const ck_tile::index_t nhead_stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? seqlen_knew * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
}();
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
return fmha_fwd_args{q,
k,
v,
bias.type == bias_enum::alibi ? linear_bias_slopes : biasBuffer,
nullptr, // randval_buf.GetDeviceBuffer(),
//lse_acc_buf.GetDeviceBuffer(), // lse_acc_buf.GetDeviceBuffer(),
//lse_acc_buf, // lse_acc_buf.GetDeviceBuffer(),
//nullptr, // o_acc_buf.GetDeviceBuffer(),
softmax_lse_,
output,
seqstart_q,
seqstart_k,
nullptr, // seqlen_kpads
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
//num_splits,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias,
stride_randval,
//stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
//nhead_stride_lse_acc,
//nhead_stride_o_acc,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
//batch_stride_lse_acc,
//batch_stride_o_acc,
batch_stride_o,
//split_stride_lse_acc,
//split_stride_o_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_drop,
s_randval,
std::make_pair(drop_seed, drop_offset)};
}();
ck_tile::stream_config stream_config{
stream_, // stream_id_
false, // time_kernel_
0, // log_level_
0, // cold_niters_
1, // nrepeat_
// false //
};
float run_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
// std::cout << "\nrun_time for ck fmha_fwd: " << run_time << std::endl;
if (run_time < 0) {
CK_FAIL("fmha_fwd faild");
} else {
return 1;
}
}