def flash_attn_func()

in utils/pipeline_utils.py


from typing import List, Optional

import torch


def flash_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    # qv and the FP8 descale factors are tensors in the upstream
    # FlashAttention-3 interface, not floats
    qv: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    window_size: Optional[List[int]] = None,
    sink_token_length: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    # pack_gqa is a flag selecting the packed-GQA code path, so bool, not float
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    sm_margin: int = 0,
):
    ...  # body elided in this excerpt
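
A minimal usage sketch, not part of the source file: it assumes the FlashAttention-3 extension is built, a CUDA device is available, and that the wrapper returns the attention output tensor (some FlashAttention-3 builds return an (out, softmax_lse) pair instead).

import torch

from utils.pipeline_utils import flash_attn_func

# Tensors in (batch, seqlen, num_heads, head_dim) layout, half precision on GPU.
q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)

# Causal self-attention; softmax_scale=None falls back to 1/sqrt(head_dim).
out = flash_attn_func(q, k, v, causal=True)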