def backward()

in xformers/triton/dropout.py [0:0]


    def backward(ctx, grad_out):
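        # Tensors saved by the forward pass:
        # - seeds: per-tile Philox seeds, used to regenerate the exact same dropout mask
        # - bias: the optional additive bias (may be None)
        # - inputs: the pre-activation inputs, only needed when an activation gradient is requested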
        (seeds, bias, inputs) = ctx.saved_tensors

        # Soft-flatten a hypothetical 3rd dimension
        grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
        grad_in = torch.empty_like(grad_out_)

        M, N = grad_out_.shape

        # Optional inputs, used to compute the activation's contribution to the gradient
        assert inputs is not None or ctx.activation_grad is None

        if inputs is None:
            inputs = grad_out_
        elif inputs.ndim > 2:
            inputs = inputs.reshape(-1, N)

        # We split the problem into tiles:
        # - over M, there will be a follow-up reduction (for the bias gradient)
        # - over M, we go by 4 tiles at a time (a consequence of the random number generation)
        # - over N, we compromise between using as much memory parallelism as possible
        #   (fill the warps; there are 32 threads per warp, and 4 warps by default) and not
        #   going too big, because of register spilling
        N_BLOCKS_M = triton.cdiv(M, GROUP_M)
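        # One partial bias-gradient row is produced per group of GROUP_M rows; they are summed at the end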

        if ctx.trainable_bias:
            grad_bias = torch.empty(
                (
                    N_BLOCKS_M,
                    N,
                ),
                device=grad_in.device,
                dtype=grad_in.dtype,
            )

        else:
            grad_bias = grad_in  # will not be used

        def grid(meta):
            # NOTE: We use the Triton Philox random number generator, which optimally generates 4 blocks
            # for a given seed and offsets. "BLOCK_M" here describes the size of one of these blocks,
            # but we need to take this factor of 4 into account when scheduling all the kernels
            return (
                triton.cdiv(M, meta["BLOCK_M"] * 4),
                triton.cdiv(N, meta["BLOCK_N"]),
            )

        # fmt: off
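        # Launch the fused backward kernel: it regenerates the dropout mask from the saved seeds,
        # optionally applies the activation gradient, and writes one partial bias-gradient row per
        # group of M rows (reduced with torch.sum below)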
        k_dropout_bw[grid](
            grad_in, grad_bias, grad_out_,
            inputs, bias if bias is not None else inputs,
            seeds,
            grad_out_.stride(0), inputs.stride(0),
            M, N,
            ctx.p,
            grad_in.dtype == torch.float16,
            USE_BIAS=bias is not None,
            ACTIVATION_GRAD=ctx.activation_grad,
            TRAINABLE_BIAS=ctx.trainable_bias,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
        # fmt: on

        return (
            grad_in.reshape_as(grad_out),
            None,
            torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None,
            None,
            None,
            None,
        )
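
For context, here is a minimal sketch of how this backward path is typically exercised, assuming the fused `dropout` wrapper is exported from `xformers.triton`; the exact wrapper signature, import path, and the shapes below are illustrative assumptions, and a CUDA device with Triton is required.

    import torch
    from xformers.triton import dropout  # assumed public wrapper around this autograd Function

    # Illustrative shapes: any (..., N) layout works, since backward() soft-flattens leading dims
    x = torch.randn(2, 128, 256, device="cuda", dtype=torch.float16, requires_grad=True)
    bias = torch.zeros(256, device="cuda", dtype=torch.float16, requires_grad=True)

    y = dropout(x, p=0.1, bias=bias)  # forward saves the Philox seeds reused above
    y.sum().backward()                # invokes the backward() shown here

    assert x.grad.shape == x.shape        # grad_in.reshape_as(grad_out)
    assert bias.grad.shape == bias.shape  # torch.sum(grad_bias, dim=0)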