in xformers/triton/dropout.py [0:0]
def backward(ctx, grad_out):
    (seeds, bias, inputs) = ctx.saved_tensors

    # Soft-flatten a hypothetical 3rd dimension
    grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
    grad_in = torch.empty_like(grad_out_)

    M, N = grad_out_.shape
    # Optional inputs to compute the activation contribution to the gradient
    assert inputs is not None or ctx.activation_grad is None

    if inputs is None:
        inputs = grad_out_
    elif inputs.ndim > 2:
        inputs = inputs.reshape(-1, N)
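
    # When there is no activation gradient to compute, the kernel presumably does not use the
    # values of `inputs`; substituting grad_out_ just gives it a valid tensor of matching shape.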

    # We split the problem in tiles:
    # - over M there will be a follow-up reduction
    # - over M, we go by 4 tiles at a time (a consequence of the random number generation)
    # - over N we compromise between using as much memory parallelism as possible
    #   (fill the warps; there are 32 threads per warp, and 4 warps by default) and not going
    #   too big, because of register spilling
    N_BLOCKS_M = triton.cdiv(M, GROUP_M)
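
    # grad_bias is a buffer of partial sums: one row per group of GROUP_M input rows, reduced
    # with torch.sum(grad_bias, dim=0) once the kernel has run.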
    if ctx.trainable_bias:
        grad_bias = torch.empty(
            (
                N_BLOCKS_M,
                N,
            ),
            device=grad_in.device,
            dtype=grad_in.dtype,
        )
    else:
        grad_bias = grad_in  # will not be used
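        # TRAINABLE_BIAS is False in this case, so this tensor is only a placeholder argument
        # for the kernel and is never written to.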

    def grid(meta):
        # NOTE: We use the Triton Philox random number generator, which optimally generates 4
        # blocks for a given seed and offsets. "BLOCK_M" here describes the size of one of these
        # blocks, but we need to take this factor of 4 into account when scheduling all the kernels
        return (
            triton.cdiv(M, meta["BLOCK_M"] * 4),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
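
    # The kernel reuses the saved seeds to reproduce the dropout pattern of the forward pass,
    # writes the input gradient into grad_in and, if TRAINABLE_BIAS, accumulates the per-group
    # bias gradient partial sums into grad_bias.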
    # fmt: off
    k_dropout_bw[grid](
        grad_in, grad_bias, grad_out_,
        inputs, bias if bias is not None else inputs,
        seeds,
        grad_out_.stride(0), inputs.stride(0),
        M, N,
        ctx.p,
        grad_in.dtype == torch.float16,
        USE_BIAS=bias is not None,
        ACTIVATION_GRAD=ctx.activation_grad,
        TRAINABLE_BIAS=ctx.trainable_bias,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )
    # fmt: on
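
    # One gradient per argument of forward(): the third slot carries the bias gradient,
    # completed here by reducing the partial sums over the group dimension, while the None
    # entries correspond to forward() arguments that do not require a gradient.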
    return (
        grad_in.reshape_as(grad_out),
        None,
        torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None,
        None,
        None,
        None,
    )
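

# ------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the file above): a plain-PyTorch picture of the two-stage
# bias-gradient reduction, assuming each group of GROUP_M rows contributes one row of partial
# sums to grad_bias. The name `bias_grad_reference`, the `group_m` parameter and its default
# value are made up for the example; the per-element dropout/activation scaling done inside
# the fused kernel is deliberately left out. Relies on the module-level `import torch`.
def bias_grad_reference(grad_out_: torch.Tensor, group_m: int = 16) -> torch.Tensor:
    M, N = grad_out_.shape
    n_blocks_m = (M + group_m - 1) // group_m  # same as triton.cdiv(M, GROUP_M)
    partials = torch.zeros(n_blocks_m, N, dtype=grad_out_.dtype, device=grad_out_.device)
    for i in range(n_blocks_m):
        # Each group of rows is reduced locally, producing one row of partial sums.
        partials[i] = grad_out_[i * group_m : (i + 1) * group_m].sum(dim=0)
    # Follow-up reduction over the group dimension, matching torch.sum(grad_bias, dim=0).
    return partials.sum(dim=0)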