def backward()

in xformers/triton/k_layer_norm.py


    def backward(ctx, dy):
        y, var, weight, bias = ctx.saved_tensors

        # heuristics for the number of parallel reduction streams used for dW/dB
        N = y.size(-1)
        GROUP_SIZE_M = 64
        if N <= 8192:
            GROUP_SIZE_M = 96
        if N <= 4096:
            GROUP_SIZE_M = 128
        if N <= 1024:
            GROUP_SIZE_M = 256
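        # (smaller rows make the (GROUP_SIZE_M, N) partial buffers cheap, so more
        # independent accumulation groups can be afforded, spreading lock contention)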

        # flatten the batch dimension, if any.
        # We're interested in 'samples' x norm_dimension
        y = y.reshape(-1, y.size(-1))
        M, N = y.size()
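        # note: dy and dx keep the original (unflattened) shape below; since dy is made
        # contiguous and has the same number of elements as y, the kernels can walk
        # them with the same row stride as y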

        # allocate output
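        # 'locks' serializes accumulation into the shared partial dW/dB buffers:
        # conceptually one spin lock plus one 'already initialized' flag per group,
        # hence 2 * GROUP_SIZE_M int32 entries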
        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=y.device)
        t_args = {"dtype": y.dtype, "device": y.device}
        _dw = torch.empty((GROUP_SIZE_M, y.size(-1)), **t_args)
        _db = torch.empty((GROUP_SIZE_M, y.size(-1)), **t_args)
        dw = torch.empty((y.size(-1),), **t_args)
        db = torch.empty((y.size(-1),), **t_args)
        dy = dy.contiguous()
        dx = torch.empty_like(dy)

        # Check the tensor shapes and layouts
        # we suppose in the kernel that they have the same size and are contiguous
        assert dx.numel() == y.numel(), \
            "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

        # enqueue kernel using forward pass heuristics
        # also compute partial sums for DW and DB

        # fmt: off
        meta = {"BLOCK_SIZE_N": ctx.BLOCK_SIZE_N,
                "GROUP_SIZE_M": GROUP_SIZE_M,
                "num_warps": ctx.num_warps}
        if weight is None:
            _layer_norm_no_affine_bwd[(M,)](dx, dy, y, var, y.stride(0), N, **meta)
            return dx, None, None, None

        _layer_norm_bwd_dx_fused[(M,)](
            dx, dy, _dw, _db,
            y, weight, bias, var,
            locks,
            y.stride(0),
            N,
            **meta
        )

        # fmt: on

        def grid(meta):
            return [triton.cdiv(N, meta["BLOCK_SIZE_N"])]

        # accumulate partial sums in separate kernel
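        # the grid launches one program per BLOCK_SIZE_N-wide slice of columns;
        # each program sums its columns across the GROUP_SIZE_M partial rows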
        # fmt: off
        _layer_norm_bwd_dwdb[grid](
            _dw, _db, dw, db,
            GROUP_SIZE_M,
            N,
            BLOCK_SIZE_M=32,
            BLOCK_SIZE_N=128
        )
        # fmt: on

        dx = dx.reshape_as(dy)
        return dx, dw, db, None
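
For reference, a minimal sketch of how this backward pass gets exercised end to end, assuming the public wrapper is exported as xformers.triton.FusedLayerNorm with nn.LayerNorm-style weight/bias parameters (the exact export and parameter names may differ across xformers versions), and that a CUDA device with triton is available:

    import torch
    from xformers.triton import FusedLayerNorm  # assumed export; check your xformers version

    torch.manual_seed(0)
    x = torch.randn(8, 384, 1024, device="cuda", requires_grad=True)

    ref = torch.nn.LayerNorm(1024).to("cuda")
    fused = FusedLayerNorm(1024).to("cuda")
    with torch.no_grad():  # align the affine parameters so gradients are comparable
        fused.weight.copy_(ref.weight)
        fused.bias.copy_(ref.bias)

    # forward + backward through the Triton kernels; backward() above produces dx, dw, db
    fused(x).sum().backward()
    dx_fused = x.grad.clone()

    # reference gradients from the eager PyTorch implementation
    x.grad = None
    ref(x).sum().backward()

    print(torch.allclose(dx_fused, x.grad, atol=1e-4, rtol=1e-4))
    print(torch.allclose(fused.weight.grad, ref.weight.grad, atol=1e-3, rtol=1e-3))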