def backward()

in src/nanotron/nn/ring_attention_lucidrain.py [0:0]


    def backward(ctx, do):
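        # ring attention backward: dq is accumulated locally in float32, while k and v are
        # stacked together with their gradient buffers and rotated around the ring, so each
        # rank adds its contribution to dk/dv for every key/value block it visits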

        (
            causal,
            softmax_scale,
            mask,
            bucket_size,
            ring_reduce_col,
            max_ring_passes,
            num_lookback_buckets,
            striped_ring_attn,
            q_head_groups,
            softclamp_qk_sim,
            softclamp_value,
            dtype,
            ring_pg,
        ) = ctx.args

        q, k, v, o, lse = ctx.saved_tensors
        ring_size = ring_pg.size() if ring_pg else get_world_size()

        do = do.type(o.dtype)

        device = q.device

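        # with causal attention the explicit key padding mask is not used; masking is
        # handled by the per-block causal logic inside the ring loop below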
        if causal:
            mask = None

        ring_pass_fn = all_ring_pass if ring_reduce_col else null_ring_pass

        dq = torch.zeros(q.shape, device=device, dtype=torch.float32)
        dk = torch.zeros_like(k, device=device)
        dv = torch.zeros_like(v, device=device)

        # k and v are 16-bit here, and dk and dv should accumulate safely in the same dtype

        assert k.dtype == v.dtype

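        # stack k and v together with their gradient accumulators so a single ring pass
        # communicates the activations and the partially accumulated dk/dv at once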
        kv_and_dkv = torch.stack((k, v, dk, dv))

        # receive buffers, to be alternated with sent buffer

        receive_kv_and_dkv = None
        receive_mask = None

        # caching the delta (do * o for backwards pass) across ring reduce
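        # delta depends only on the local o and do, so it is computed for the first block
        # and reused on every subsequent ring pass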

        delta = None

        for (ring_rank, _), ((kv_and_dkv, mask), (receive_kv_and_dkv, receive_mask)) in ring_pass_fn(
            kv_and_dkv,
            mask,
            receive_buffers=(receive_kv_and_dkv, receive_mask),
            max_iters=max_ring_passes,
            ring_size=ring_size,
            ring_pg=ring_pg,
        ):

            k, v, dk, dv = kv_and_dkv

            # account for grouped query attention

            k, v = (repeat(t, "... h d -> ... (g h) d", g=q_head_groups) for t in (k, v))

            # translate key padding mask to bias

            bias = None

            if exists(mask):
                bias = torch.where(mask, 0.0, float("-inf"))
                # bias = rearrange(bias, "b j -> b 1 1 j")

            # determine whether to do causal mask or not
            # depends on whether it is striped attention, as well as current machine rank vs ring rank
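            # need_accum: do this rank's queries attend to the incoming key block at all?
            # block_causal: apply the causal mask within this (query block, key block) pair?
            # causal_mask_diagonal: for striped attention, also mask the diagonal when the
            # block comes from a rank that sits ahead of this one in the ring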

            cur_rank = get_rank() if ring_pg is None else dist.get_rank(ring_pg)

            if causal and striped_ring_attn:
                need_accum = True
                block_causal = True
                causal_mask_diagonal = cur_rank < ring_rank
            elif causal:
                need_accum = cur_rank >= ring_rank
                block_causal = cur_rank == ring_rank
                causal_mask_diagonal = False
            else:
                need_accum = True
                block_causal = False
                causal_mask_diagonal = False

            # use flash attention backwards kernel to calculate dq, dk, dv and accumulate

            if need_accum:
                ring_dq = torch.empty(q.shape, device=device, dtype=torch.float32)
                ring_dk = torch.empty_like(k)
                ring_dv = torch.empty_like(v)

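                # flash_attn_backward fills ring_dq / ring_dk / ring_dv in place and
                # returns delta so later blocks can reuse it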
                with torch.inference_mode():
                    delta = flash_attn_backward(
                        do,
                        q,
                        k,
                        v,
                        o,
                        lse,
                        ring_dq,
                        ring_dk,
                        ring_dv,
                        delta=delta,
                        bias=bias,
                        causal=block_causal,
                        causal_mask_diagonal=causal_mask_diagonal,
                        softmax_scale=softmax_scale,
                        softclamp_qk_sim=softclamp_qk_sim,
                        softclamp_value=softclamp_value,
                    )

                # account for grouped query attention
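                # (ring_dk / ring_dv cover the expanded query heads; sum them back down to
                # the original key/value heads)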

                ring_dk = reduce(ring_dk, "... (g h) d -> ... h d", g=q_head_groups, reduction="sum")
                ring_dv = reduce(ring_dv, "... (g h) d -> ... h d", g=q_head_groups, reduction="sum")

                dq.add_(ring_dq)
                dk.add_(ring_dk)
                dv.add_(ring_dv)
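                # dq stays on this rank; dk and dv are views into kv_and_dkv, so these sums
                # travel with the k/v block on the next ring pass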

            if not ring_reduce_col:
                continue

            dkv = kv_and_dkv[2:]

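            # rotate dk/dv the remaining hops around the ring so they end up back on the
            # rank that owns the corresponding k/v shard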
            max_ring_passes = default(max_ring_passes, ring_size)
            dkv = ring_pass(ring_size - max_ring_passes + 1, dkv, ring_pg=ring_pg)

            dk, dv = dkv

        dq, dk, dv = (t.to(dtype) for t in (dq, dk, dv))

        return dq, dk, dv, None, None, None, None, None, None, None, None, None
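
Note on delta: in the standard FlashAttention backward formulation, delta is the row-wise sum of do * o. It depends only on the local output and its incoming gradient, not on whichever k/v block is currently in flight, which is why the code above caches it across ring passes. A minimal sketch of that quantity in plain PyTorch (compute_delta is a hypothetical helper for illustration, not the kernel called above):

    import torch

    def compute_delta(do: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        # delta_i = sum_d dO[i, d] * O[i, d], accumulated in float32 for stability
        return (do.float() * o.float()).sum(dim=-1)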