in utils/pipeline_utils.py [0:0]
def _(q, k, v, **kwargs):
    """Fake (meta) implementation: allocate an uninitialized output shaped like ``q``.

    Only ``q`` is consulted. ``k``, ``v`` and ``**kwargs`` are accepted so the
    signature matches the op this function stands in for, but they do not
    influence the result.

    Returns:
        A single new, uninitialized tensor with the same shape, dtype and
        device as ``q``, laid out with standard contiguous strides. ``q`` is
        presumably (batch, seq_len, num_heads, head_dim) — TODO confirm
        against the real op.

    NOTE(review): the original comments describe a second output,
    ``softmax_lse``, of shape ``(q.size(0), q.size(2), q.size(1))`` and dtype
    ``torch.float32``, but only one tensor is returned here. If the real op
    also returns ``softmax_lse``, this fake must return it as well. Confirm
    against the real kernel's signature.
    """
    # Request contiguous strides at allocation time. ``empty_like`` defaults to
    # ``preserve_format``, so ``empty_like(q).contiguous()`` allocated a second
    # tensor whenever ``q`` was non-contiguous.
    return torch.empty_like(q, memory_format=torch.contiguous_format)