in picotron/context_parallel/context_parallel.py
def ring_attention_backward(dO, Q, K, V, O, softmax_lse, sm_scale, is_causal):
    batch_size, nheads, seqlen, d = Q.shape
    # Recreate S and P from the saved log-sum-exp: P = exp(S - lse) = softmax(S)
    S = torch.matmul(Q, K.transpose(-2, -1)) * sm_scale
    if is_causal:
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=Q.device, dtype=torch.bool), diagonal=1)
        S = S.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), float('-inf'))
    P = torch.exp(S - softmax_lse.unsqueeze(-1))
    # Step 1: Compute dV = P^T @ dO
    dV = torch.matmul(P.transpose(-2, -1), dO)
    # Step 2: Compute dP = dO @ V^T
    dP = torch.matmul(dO, V.transpose(-2, -1))
    # Step 3: Compute D_i = sum_j dO_ij * O_ij (row-wise correction term of the softmax backward)
    D = torch.sum(dO * O, dim=-1, keepdim=True)
    # Step 4: Compute dS = P * (dP - D), the softmax Jacobian-vector product
    dS = P * (dP - D)
    # Re-apply the causal mask so masked positions carry no gradient
    if is_causal:
        dS = dS.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), 0)
    # Step 5: Compute dQ = dS @ K, scaled because S = (Q K^T) * sm_scale
    dQ = torch.matmul(dS, K) * sm_scale
    # Step 6: Compute dK = dS^T @ Q, with the same scaling
    dK = torch.matmul(dS.transpose(-2, -1), Q) * sm_scale
    return dQ, dK, dV
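
For reference, Steps 3 and 4 rely on the standard row-wise softmax backward identity (restated here, not quoted from picotron; indices run over the sequence dimension, batch and head axes suppressed). With P = softmax(S) and O = P V:

    dS_{ij} = P_{ij} \Bigl( dP_{ij} - \sum_k P_{ik}\, dP_{ik} \Bigr),
    \qquad
    \sum_k P_{ik}\, dP_{ik}
      = \sum_k P_{ik} \sum_m dO_{im} V_{km}
      = \sum_m dO_{im} O_{im}
      = D_i .

This is why D = torch.sum(dO * O, dim=-1, keepdim=True) can replace the usual sum_k P_ik dP_ik without materializing anything beyond what the code already holds.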
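
A minimal numerical check (a sketch, not part of picotron): run a plain PyTorch attention forward, let autograd produce reference gradients, and compare them with ring_attention_backward. The (batch, nheads, seqlen) layout of softmax_lse and the helper name naive_attention_forward are assumptions inferred from the code above.

import torch

def naive_attention_forward(Q, K, V, sm_scale, is_causal):
    # Reference forward: S = Q K^T * sm_scale, P = softmax(S), O = P V
    S = torch.matmul(Q, K.transpose(-2, -1)) * sm_scale
    if is_causal:
        seqlen = Q.shape[-2]
        mask = torch.triu(torch.ones(seqlen, seqlen, device=Q.device, dtype=torch.bool), diagonal=1)
        S = S.masked_fill(mask, float('-inf'))
    softmax_lse = torch.logsumexp(S, dim=-1)      # (batch, nheads, seqlen)
    P = torch.exp(S - softmax_lse.unsqueeze(-1))  # numerically stable softmax
    O = torch.matmul(P, V)
    return O, softmax_lse

torch.manual_seed(0)
batch, nheads, seqlen, d = 2, 4, 16, 8
sm_scale = d ** -0.5
Q = torch.randn(batch, nheads, seqlen, d, dtype=torch.float64, requires_grad=True)
K = torch.randn(batch, nheads, seqlen, d, dtype=torch.float64, requires_grad=True)
V = torch.randn(batch, nheads, seqlen, d, dtype=torch.float64, requires_grad=True)

O, softmax_lse = naive_attention_forward(Q, K, V, sm_scale, is_causal=True)
dO = torch.randn_like(O)
O.backward(dO)  # reference gradients end up in Q.grad, K.grad, V.grad

dQ, dK, dV = ring_attention_backward(dO, Q.detach(), K.detach(), V.detach(),
                                     O.detach(), softmax_lse.detach(), sm_scale, is_causal=True)
for manual, ref in ((dQ, Q.grad), (dK, K.grad), (dV, V.grad)):
    assert torch.allclose(manual, ref, atol=1e-10)

Running the manual backward on detached inputs keeps autograd out of the comparison; in float64 the two gradient paths agree to well below the tolerance used here.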