in kernels/fmha_api.cpp [38:122]
void set_params_fprop(FMHA_fprop_params ¶ms,
// sizes
const size_t b,
const size_t seqlen_q,
const size_t seqlen_k,
const size_t h,
const size_t d,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
at::Tensor out,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *o_tmp_d,
void *s_d,
void *softmax_lse_d,
float p_dropout,
float softmax_scale,
bool is_causal,
int num_splits) {
Data_type acc_type = DATA_TYPE_FP32;
Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;
// Reset the parameters
memset(¶ms, 0, sizeof(params));
params.is_bf16 = q.dtype() == torch::kBFloat16;
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
params.v_ptr = v.data_ptr();
params.q_row_stride_in_elts = q.stride(0);
params.k_row_stride_in_elts = k.stride(0);
params.v_row_stride_in_elts = v.stride(0);
params.q_head_stride_in_elts = q.stride(1);
params.k_head_stride_in_elts = k.stride(1);
params.v_head_stride_in_elts = v.stride(1);
params.o_ptr = out.data_ptr();
params.o_row_stride_in_elts = out.stride(0);
params.o_head_stride_in_elts = out.stride(1);
params.o_tmp_ptr = o_tmp_d;
params.o_tmp_row_stride_in_elts = h * d;
params.o_tmp_head_stride_in_elts = d;
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
// S = softmax(P)
params.s_ptr = s_d;
params.s_stride_in_bytes = get_size_in_bytes(b * h * seqlen_k, data_type);
// Softmax sum
params.softmax_lse_ptr = softmax_lse_d;
// Set the dimensions.
params.b = b;
params.h = h;
params.seqlen_q = seqlen_q;
params.seqlen_k = seqlen_k;
params.d = d;
// Set the different scale values.
// const float scale_bmm1 = 1.f / sqrtf(d);
const float scale_bmm1 = softmax_scale;
params.scale_bmm1f = scale_bmm1;
set_alpha(params.scale_bmm1, scale_bmm1, data_type);
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
// Convert p from float to int so we don't have to convert the random uint to float to compare.
// [Minor] We want to round down since when we do the comparison we use <= instead of <
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f;
TORCH_CHECK(p_dropout < 1.f);
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
params.is_causal = is_causal;
params.num_splits = num_splits;
}