// candle-flash-attn-v1/kernels/flash_api.cu
#include "fmha.h"
#include "fmha_utils.h"
// Forwards the launch to the head-dimension-specific forward implementation.
//
// The head dimension `params.d` is bucketed as (.., 32], (32, 64], (64, 128]
// and the matching run_fmha_fwd_hdim{32,64,128} function is called with the
// unmodified launch parameters.
//
// NOTE(review): a head dimension above 128 matches no bucket, so this function
// returns without calling any of the run_fmha_fwd_hdim* functions and without
// reporting an error. Callers are presumably expected to validate `d`
// beforehand -- confirm against the call sites.
void run_fmha_fwd(Launch_params<FMHA_fprop_params> &launch_params) {
    const auto head_dim = launch_params.params.d;

    if (head_dim <= 32) {
        run_fmha_fwd_hdim32(launch_params);
        return;
    }
    if (head_dim <= 64) {
        run_fmha_fwd_hdim64(launch_params);
        return;
    }
    if (head_dim <= 128) {
        run_fmha_fwd_hdim128(launch_params);
        return;
    }
    // head_dim > 128: intentionally falls through with no work done (see note).
}
// C-linkage entry point that fills a Launch_params<FMHA_fprop_params> from
// plain scalar/pointer arguments and hands it to run_fmha_fwd.
//
// Parameters:
//   q_ptr, k_ptr, v_ptr      query / key / value buffers (stored as-is).
//   o_ptr, o_tmp_ptr         output buffer and temporary output buffer.
//   softmax_lse_ptr          buffer stored in params.softmax_lse_ptr.
//   cu_seqlens_q_ptr,
//   cu_seqlens_k_ptr         int32 arrays stored in params.cu_seqlens_{q,k}
//                            (presumably cumulative sequence lengths --
//                            TODO confirm against the kernels).
//   *_row_stride,
//   *_head_stride            strides expressed in elements, not bytes.
//   b, h, d                  stored in params.b, params.h, params.d;
//                            d is the value run_fmha_fwd dispatches on.
//   softmax_scale            scale stored in params.scale_bmm1f and, via
//                            set_alpha, in params.scale_bmm1.
//   seqlen_q, seqlen_k       stored in params.seqlen_q / params.seqlen_k.
//   is_causal                stored in params.is_causal.
//   is_bf16                  nonzero selects DATA_TYPE_BF16, zero DATA_TYPE_FP16.
//   multi_processor_count    stored in launch_params.multi_processor_count.
//   num_splits               stored in params.num_splits.
//
// Dropout is disabled (is_dropout = false, keep probability 1) and the softmax
// matrix is not returned (return_softmax = false). The launch uses stream 0.
// No status is returned; this function performs no validation of its inputs.
extern "C" void run_mha(
    void *q_ptr,
    void *k_ptr,
    void *v_ptr,
    void *o_ptr,
    void *o_tmp_ptr,
    void *softmax_lse_ptr,
    int32_t *cu_seqlens_q_ptr,
    int32_t *cu_seqlens_k_ptr,
    uint32_t q_row_stride,
    uint32_t k_row_stride,
    uint32_t v_row_stride,
    uint32_t o_row_stride,
    uint32_t o_tmp_row_stride,
    uint32_t q_head_stride,
    uint32_t k_head_stride,
    uint32_t v_head_stride,
    uint32_t o_head_stride,
    uint32_t o_tmp_head_stride,
    uint32_t b,
    uint32_t h,
    uint32_t d,
    float softmax_scale,
    uint32_t seqlen_q,
    uint32_t seqlen_k,
    int is_causal,
    int is_bf16,
    int32_t multi_processor_count,
    int32_t num_splits
) {
    // Element type used when encoding the scale factors below.
    Data_type data_type = !is_bf16 ? DATA_TYPE_FP16 : DATA_TYPE_BF16;

    Launch_params<FMHA_fprop_params> launch_params;
    launch_params.elts_per_thread = 0;
    launch_params.multi_processor_count = multi_processor_count;
    launch_params.stream = 0;
    launch_params.is_dropout = false;
    launch_params.return_softmax = false;

    // Alias the embedded kernel parameters so the assignments below write
    // directly into launch_params.
    FMHA_fprop_params &params = launch_params.params;

    // Set the pointers and strides.
    params.q_ptr = q_ptr;
    params.k_ptr = k_ptr;
    params.v_ptr = v_ptr;
    params.o_ptr = o_ptr;
    params.o_tmp_ptr = o_tmp_ptr;
    params.softmax_lse_ptr = softmax_lse_ptr;

    // All strides are in elements, not bytes.
    params.q_row_stride_in_elts = q_row_stride;
    params.k_row_stride_in_elts = k_row_stride;
    params.v_row_stride_in_elts = v_row_stride;
    params.o_row_stride_in_elts = o_row_stride;
    params.o_tmp_row_stride_in_elts = o_tmp_row_stride;
    params.q_head_stride_in_elts = q_head_stride;
    params.k_head_stride_in_elts = k_head_stride;
    params.v_head_stride_in_elts = v_head_stride;
    params.o_head_stride_in_elts = o_head_stride;
    params.o_tmp_head_stride_in_elts = o_tmp_head_stride;

    // Set the dimensions.
    params.h = h;
    params.b = b;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.d = d;

    // Set the different scale values.
    const float scale_bmm1 = softmax_scale;
    params.scale_bmm1f = scale_bmm1;
    set_alpha(params.scale_bmm1, scale_bmm1, data_type);

    // Dropout is off: the keep probability is fixed at 1, so the derived
    // integer thresholds saturate and the rescale factor is 1.
    params.p_dropout = 1.; // probability to keep
    params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
    params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f;
    set_alpha(params.scale_dropout, params.rp_dropout, data_type);

    params.is_bf16 = is_bf16;
    params.is_causal = is_causal;
    params.cu_seqlens_q = cu_seqlens_q_ptr;
    params.cu_seqlens_k = cu_seqlens_k_ptr;
    params.num_splits = num_splits;

    run_fmha_fwd(launch_params);
}