def _()

in utils/pipeline_utils.py [0:0]


def _(q, k, v, **kwargs):
    """Return an uninitialized, contiguous tensor shaped like ``q``.

    The result has the same shape, dtype and device as ``q`` but is always
    laid out contiguously, even when ``q`` itself is strided/non-contiguous.
    It carries no meaningful data.

    Args:
        q: Query tensor whose shape/dtype/device the output mirrors.
            Assumed to be (batch, seq_len, num_heads, head_dim), per the
            original comment; TODO confirm.
        k: Unused here (presumably kept to mirror the real op's signature).
        v: Unused here (presumably kept to mirror the real op's signature).
        **kwargs: Unused here.

    Returns:
        A single tensor. Only the attention-output placeholder is produced.

    NOTE(review): an earlier revision also sketched a second output,
    ``softmax_lse`` of shape (batch, num_heads, seq_len) with dtype
    torch.float32, but it is not returned. If the real kernel's schema
    declares two outputs, this function must return both. Confirm against
    the op registration.
    """
    # Request contiguous layout directly. This avoids empty_like's default
    # preserve_format allocation followed by a second allocation + copy in
    # .contiguous() when q is non-contiguous.
    return torch.empty_like(q, memory_format=torch.contiguous_format)