in src/nanotron/nn/ring_attention_lucidrain.py [0:0]
def backward(ctx, do):
    (
        causal,
        softmax_scale,
        mask,
        bucket_size,
        ring_reduce_col,
        max_ring_passes,
        num_lookback_buckets,
        striped_ring_attn,
        q_head_groups,
        softclamp_qk_sim,
        softclamp_value,
        dtype,
        ring_pg,
    ) = ctx.args

    q, k, v, o, lse = ctx.saved_tensors

    ring_size = ring_pg.size() if ring_pg else get_world_size()

    do = do.type(o.dtype)
    device = q.device

    if causal:
        mask = None

    per_machine_seq_size = k.shape[-3]

    ring_pass_fn = all_ring_pass if ring_reduce_col else null_ring_pass

    # accumulate dq in fp32 for numerical stability
    dq = torch.zeros(q.shape, device=device, dtype=torch.float32)
    dk = torch.zeros_like(k, device=device)
    dv = torch.zeros_like(v, device=device)

    # k and v will be 16-bit, and dk and dv should be safe to accumulate in the same dtype
    assert k.dtype == v.dtype

    kv_and_dkv = torch.stack((k, v, dk, dv))
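
    # stacking k/v with their running dk/dv lets the shard and its partially accumulated gradients
    # travel around the ring as one buffer: each rank accumulates into whichever shard it currently
    # holds, and the final ring_pass at the end returns the gradients to their owning rank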

    # receive buffers, to be alternated with sent buffer
    receive_kv_and_dkv = None
    receive_mask = None

    # caching the delta (do * o for backwards pass) across ring reduce
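    # delta depends only on the local query block's output and its gradient, so it is computed on
    # the first pass and reused for every incoming k/v shard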
    delta = None

    for (ring_rank, _), ((kv_and_dkv, mask), (receive_kv_and_dkv, receive_mask)) in ring_pass_fn(
        kv_and_dkv,
        mask,
        receive_buffers=(receive_kv_and_dkv, receive_mask),
        max_iters=max_ring_passes,
        ring_size=ring_size,
        ring_pg=ring_pg,
    ):
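        # each iteration works on the k/v shard this rank currently holds; ring_pass_fn rotates the
        # buffers to the next rank between iterations (null_ring_pass yields only the local shard)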
        k, v, dk, dv = kv_and_dkv

        # account for grouped query attention
        k, v = (repeat(t, "... h d -> ... (g h) d", g=q_head_groups) for t in (k, v))

        # translate key padding mask to bias
        bias = None

        if exists(mask):
            bias = torch.where(mask, 0.0, float("-inf"))
            # bias = rearrange(bias, "b j -> b 1 1 j")

        # determine whether to apply the causal mask
        # depends on whether it is striped attention, as well as current machine rank vs ring rank
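        # with plain causal attention and contiguous sharding, shards from ranks ahead of this one
        # lie entirely in the future of the local queries, so those ring steps are skipped and only
        # the diagonal block needs the causal mask; with striped attention the tokens are
        # interleaved across ranks, so every shard contributes and only the treatment of the
        # diagonal changes with the relative ring position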
        cur_ring_rank = get_rank() if ring_pg is None else dist.get_rank(ring_pg)

        if causal and striped_ring_attn:
            need_accum = True
            block_causal = True
            causal_mask_diagonal = cur_ring_rank < ring_rank
        elif causal:
            need_accum = cur_ring_rank >= ring_rank
            block_causal = cur_ring_rank == ring_rank
            causal_mask_diagonal = False
        else:
            need_accum = True
            block_causal = False
            causal_mask_diagonal = False

        # use flash attention backwards kernel to calculate dq, dk, dv and accumulate
        if need_accum:
            ring_dq = torch.empty(q.shape, device=device, dtype=torch.float32)
            ring_dk = torch.empty_like(k)
            ring_dv = torch.empty_like(v)

            with torch.inference_mode():
                delta = flash_attn_backward(
                    do,
                    q,
                    k,
                    v,
                    o,
                    lse,
                    ring_dq,
                    ring_dk,
                    ring_dv,
                    delta=delta,
                    bias=bias,
                    causal=block_causal,
                    causal_mask_diagonal=causal_mask_diagonal,
                    softmax_scale=softmax_scale,
                    softclamp_qk_sim=softclamp_qk_sim,
                    softclamp_value=softclamp_value,
                )

            # account for grouped query attention
            ring_dk = reduce(ring_dk, "... (g h) d -> ... h d", g=q_head_groups, reduction="sum")
            ring_dv = reduce(ring_dv, "... (g h) d -> ... h d", g=q_head_groups, reduction="sum")

            dq.add_(ring_dq)
            dk.add_(ring_dk)
            dv.add_(ring_dv)

        if not ring_reduce_col:
            continue
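
        # the traveling dk/dv hold gradient contributions for whichever shard this rank is holding;
        # rotating them the remaining ring_size - max_ring_passes + 1 hops completes the circle, so
        # each rank ends up with the gradients for its own k/v shard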
        dkv = kv_and_dkv[2:]

        max_ring_passes = default(max_ring_passes, ring_size)
        dkv = ring_pass(ring_size - max_ring_passes + 1, dkv, ring_pg=ring_pg)

        dk, dv = dkv

    dq, dk, dv = (t.to(dtype) for t in (dq, dk, dv))

    return dq, dk, dv, None, None, None, None, None, None, None, None, None