# utils/pipeline_utils.py
from typing import List, Optional

import torch

def flash_attn_func(
    q: torch.Tensor,  # (batch, seqlen_q, nheads, headdim)
    k: torch.Tensor,  # (batch, seqlen_k, nheads_k, headdim)
    v: torch.Tensor,  # (batch, seqlen_k, nheads_k, headdim_v)
    softmax_scale: Optional[float] = None,  # defaults to 1 / sqrt(headdim) when None
    causal: bool = False,  # apply a lower-triangular (causal) attention mask
    # The flash-attn 3 interface takes tensors here, not floats:
    qv: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,  # FP8 dequantization scale for q
    k_descale: Optional[torch.Tensor] = None,  # FP8 dequantization scale for k
    v_descale: Optional[torch.Tensor] = None,  # FP8 dequantization scale for v
    window_size: Optional[List[int]] = None,  # (left, right) sliding-window sizes; None = full attention
    sink_token_length: int = 0,  # leading "sink" tokens kept visible regardless of the window
    softcap: float = 0.0,  # tanh soft-capping of attention logits; 0.0 disables
    num_splits: int = 1,  # split-KV factor for decoding; 1 = no split
    pack_gqa: Optional[bool] = None,  # bool, not float; None lets the kernel pick a heuristic
    deterministic: bool = False,  # use the deterministic (slower) backward pass
    sm_margin: int = 0,  # SMs to leave idle, e.g. for an overlapped communication kernel
) -> torch.Tensor:
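    """Multi-head attention with a flash-attn 3 style signature.

    The body below is a minimal pure-PyTorch reference sketch, an assumption
    rather than the fused kernel: it honors `softmax_scale`, `causal`, and
    `softcap`, assumes seqlen_q == seqlen_k and nheads == nheads_k, and
    ignores `qv`, `window_size`, `sink_token_length`, the FP8 descales, and
    the scheduling knobs (`num_splits`, `pack_gqa`, `deterministic`,
    `sm_margin`), which only matter to the real CUDA kernel.
    """
    scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5
    # flash-attn layout (batch, seqlen, nheads, headdim) -> (batch, nheads, seqlen, headdim)
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
    scores = torch.matmul(qt, kt.transpose(-2, -1)) * scale
    if softcap > 0.0:
        scores = softcap * torch.tanh(scores / softcap)  # logit soft-capping
    if causal:
        causal_mask = torch.triu(
            torch.ones(scores.shape[-2:], dtype=torch.bool, device=scores.device),
            diagonal=1,
        )
        scores = scores.masked_fill(causal_mask, float("-inf"))
    out = torch.matmul(torch.softmax(scores, dim=-1), vt)
    return out.transpose(1, 2)  # back to (batch, seqlen, nheads, headdim)


# Hypothetical smoke test (shapes assumed, not from the original file):
#   q = k = v = torch.randn(2, 128, 8, 64)
#   out = flash_attn_func(q, k, v, causal=True)  # out: (2, 128, 8, 64)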