in picotron/context_parallel/context_parallel.py
@staticmethod
def backward(ctx, dout, *args):
    q, k, v, out, softmax_lse = ctx.saved_tensors
    sm_scale = ctx.sm_scale
    is_causal = ctx.is_causal

    # Two independent rings: kv_comm rotates the K/V blocks forward around
    # the ring, while d_kv_comm rotates the accumulated dK/dV gradients.
    kv_comm = ContextCommunicate("kv_comm")
    d_kv_comm = ContextCommunicate("d_kv_comm")

    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None

    for step in range(kv_comm.world_size):
        # Prefetch the next K/V block asynchronously while computing
        # gradients on the current one.
        if step + 1 != kv_comm.world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step <= kv_comm.rank or not is_causal:
            # Under causal masking, only the diagonal block (step 0, i.e.
            # this rank's own K/V) applies the causal mask; later visible
            # blocks attend fully.
            bwd_causal = is_causal and step == 0
            block_dq_buffer, block_dk_buffer, block_dv_buffer = ring_attention_backward(
                dout, q, k, v, out, softmax_lse, sm_scale, bwd_causal
            )

            if dq is None:
                # First contribution: upcast to float32 so the running
                # sums stay numerically stable.
                dq = block_dq_buffer.to(torch.float32)
                dk = block_dk_buffer.to(torch.float32)
                dv = block_dv_buffer.to(torch.float32)
            else:
                # dQ stays resident on this rank; dK/dV belong to the K/V
                # block currently held, so fold in the partial gradients
                # arriving from the previous rank before sending them on.
                dq += block_dq_buffer
                d_kv_comm.wait()
                dk = block_dk_buffer + next_dk
                dv = block_dv_buffer + next_dv
        elif step != 0:
            # This rank skipped the block (causal mask), but must still
            # relay the traveling dK/dV to the next rank in the ring.
            d_kv_comm.wait()
            dk = next_dk
            dv = next_dv

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        # Pass the accumulated dK/dV one hop around the gradient ring.
        next_dk = d_kv_comm.send_recv(dk)
        next_dv = d_kv_comm.send_recv(dv)
        d_kv_comm.commit()

    d_kv_comm.wait()
    # After world_size hops, next_dk/next_dv hold the complete gradients
    # for this rank's original K/V block.
    return dq, next_dk, next_dv, None, None
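
For context, a ContextCommunicate object implements a ring exchange: each rank asynchronously sends its block to the next rank while receiving the previous rank's block, so communication overlaps with the per-block gradient computation. Below is a minimal sketch of that pattern using raw torch.distributed primitives; the RingCommunicate name and its internals are illustrative assumptions that mirror the send_recv/commit/wait usage above, not picotron's actual implementation.

# Illustrative sketch of a ring send/recv helper, assuming an initialized
# torch.distributed process group. Internals are assumptions, not
# picotron's ContextCommunicate implementation.
import torch
import torch.distributed as dist

class RingCommunicate:
    def __init__(self, name: str):
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size
        self._ops = []       # queued P2P operations
        self._reqs = None    # in-flight requests

    def send_recv(self, tensor: torch.Tensor) -> torch.Tensor:
        # Queue a send to the next rank and a recv from the previous one;
        # return the buffer that will hold the received block.
        recv_buffer = torch.empty_like(tensor)
        self._ops.append(dist.P2POp(dist.isend, tensor.contiguous(), self.send_rank))
        self._ops.append(dist.P2POp(dist.irecv, recv_buffer, self.recv_rank))
        return recv_buffer

    def commit(self):
        # Launch all queued operations asynchronously.
        self._reqs = dist.batch_isend_irecv(self._ops)
        self._ops = []

    def wait(self):
        # Block until the in-flight exchange completes.
        for req in self._reqs:
            req.wait()
        self._reqs = None

Keeping two separate instances (one for K/V, one for dK/dV) keeps the two rings' requests independent, so waiting on the gradient ring never stalls the prefetch of the next K/V block.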