in picotron/context_parallel/context_parallel.py
import torch

def ring_attention_forward(q, k, v, sm_scale, is_causal):
    # q, k, v: (batch_size, nheads, seqlen, head_dim)
    batch_size, nheads, seqlen, d = q.shape
    # Scaled attention scores for this KV block
    S = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    if is_causal:
        # Mask out keys that come after their query position
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=q.device, dtype=torch.bool), diagonal=1)
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, nheads, seqlen, seqlen)
        S.masked_fill_(causal_mask, float('-inf'))
    # Online softmax: numerically stable softmax over this block, keeping the
    # per-query log-sum-exp so partial outputs from different KV blocks can be
    # rescaled and combined later
    S_max = torch.max(S, dim=-1, keepdim=True)[0]
    exp_S = torch.exp(S - S_max)
    exp_sum = torch.sum(exp_S, dim=-1, keepdim=True)
    log_sum_exp = torch.log(exp_sum) + S_max
    P = exp_S / exp_sum
    O = torch.matmul(P, v)
    return O, log_sum_exp.squeeze(-1)
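
For reference, a minimal usage sketch of the block-level forward on dummy tensors; the shapes and the scale factor below are assumptions for illustration, not values taken from picotron's actual call site.

import torch

# Assumed toy shapes: one local Q block attending to one local KV block.
batch_size, nheads, seqlen, head_dim = 2, 4, 128, 64
q = torch.randn(batch_size, nheads, seqlen, head_dim)
k = torch.randn(batch_size, nheads, seqlen, head_dim)
v = torch.randn(batch_size, nheads, seqlen, head_dim)

O, lse = ring_attention_forward(q, k, v, sm_scale=head_dim ** -0.5, is_causal=True)
print(O.shape)    # (2, 4, 128, 64) -- partial attention output for this KV block
print(lse.shape)  # (2, 4, 128)     -- per-query log-sum-exp used to merge blocks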