in picotron/context_parallel/context_parallel.py [0:0]
def forward(ctx, q, k, v, sm_scale, is_causal):
    comm = ContextCommunicate("comm")
    #TODO(fmom): add flex attention
    #TODO(fmom): add flash attention
    #TODO(fmom): Find a better way to save these tensors without cloning
    k_og = k.clone()
    v_og = v.clone()
    out, lse = None, None
    next_k, next_v = None, None

    for step in range(comm.world_size):
        # Pre-post the async ring exchange of the next K/V block so the
        # communication overlaps with this step's attention compute.
        if step + 1 != comm.world_size:
            next_k = comm.send_recv(k)
            next_v = comm.send_recv(v)
            comm.commit()

        # With causal masking, only K/V blocks that originated on ranks <= the
        # current rank contribute; the local block (step == 0) is the only one
        # that needs the causal mask applied inside the kernel.
        if not is_causal or step <= comm.rank:
            block_out, block_lse = ring_attention_forward(
                q, k, v, sm_scale, is_causal and step == 0
            )
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)

        # Wait for the exchange to complete, then rotate to the received block.
        if step + 1 != comm.world_size:
            comm.wait()
            k = next_k
            v = next_v

    out = out.to(q.dtype)
    ctx.save_for_backward(q, k_og, v_og, out, lse.squeeze(-1))
    ctx.sm_scale = sm_scale
    ctx.is_causal = is_causal
    return out
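
The per-block merge relies on update_out_and_lse, which is not shown in this excerpt. Below is a minimal sketch of how such a log-sum-exp merge is typically written, assuming block_out has shape (batch, seq, heads, dim) and block_lse has shape (batch, heads, seq) as returned by a flash-attention-style kernel; the function body and shape handling are assumptions for illustration, not the repository's actual code.

import torch

def update_out_and_lse(out, lse, block_out, block_lse):
    # Sketch only. Assumed layout: block_out is (batch, seq, heads, dim),
    # block_lse is (batch, heads, seq); reshape lse to (batch, seq, heads, 1)
    # so it broadcasts against out (consistent with lse.squeeze(-1) above).
    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(-1)
    if out is None:
        # First block: nothing to merge yet.
        return block_out, block_lse
    # Numerically stable log(exp(lse) + exp(block_lse)).
    new_lse = lse + torch.log1p(torch.exp(block_lse - lse))
    # Rescale both partial outputs to the new normalizer and accumulate.
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return out, new_lse

With a merge of this form, each rank keeps a single running (out, lse) pair and folds in one K/V block per ring step, so memory stays flat while the final result has exactly the same softmax normalization as attending over the full sequence at once.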