csrc/kernels/mla_combine.h (4 lines of code) (raw):

#pragma once #include "params.h" template<typename ElementT> void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);