in xformers/triton/dropout.py [0:0]
def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
# Soft-flatten a hypothetical 3rd dimension
x_ = x.reshape(-1, x.shape[-1]).contiguous()
y = torch.empty_like(x_)
M, N = x_.shape
assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N)
def grid(meta):
# NOTE: We use the Triton Philox random number generator, which generates 4 blocks of
# random numbers for a given seed and offset. "BLOCK_M" here describes the size of one
# of these blocks, so this factor of 4 has to be taken into account when scheduling the kernels
return (
triton.cdiv(M, meta["BLOCK_M"] * 4),
triton.cdiv(N, meta["BLOCK_N"]),
)
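# Purely illustrative example (the real BLOCK_M / BLOCK_N come from the module-level
# constants): with M=1024, N=512, BLOCK_M=32 and BLOCK_N=128, the launch grid is
# (cdiv(1024, 32 * 4), cdiv(512, 128)) = (8, 4)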
N_BLOCK_N = triton.cdiv(N, BLOCK_N)
# Generate one seed per column block (the RNG offsets then cover the rows)
# seeds are drawn in [0, 2**16), which fits comfortably in a positive int32
seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device).to(torch.int32)
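# When bias is None, x_ is passed below only as a placeholder pointer;
# USE_BIAS=False tells the kernel to never read it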
# fmt: off
k_dropout_fw[grid](
y, x_,
bias if bias is not None else x_,
seeds,
y.stride(0),
M, N,
p,  # dropout probability
x.dtype == torch.float16,  # flag telling the kernel whether the inputs are fp16
USE_BIAS=bias is not None,
ACTIVATION=activation,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
)
# fmt: on
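# The backward pass needs the forward inputs to recompute the activation gradient,
# so x is only kept around when an activation is used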
if activation is not None:
ctx.save_for_backward(seeds, bias, x)
else:
ctx.save_for_backward(seeds, bias, None)
ctx.trainable_bias = bias is not None and trainable_bias
ctx.activation_grad = activation_grad
ctx.p = p
return y.reshape_as(x)
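
For context, a minimal usage sketch written against the signature above. It assumes this forward() belongs to a torch.autograd.Function subclass (referred to as _dropout below; that class name and import path are not shown in this excerpt) and that a CUDA device with Triton is available.

import torch

from xformers.triton.dropout import _dropout  # assumed class name / import path

x = torch.randn(8, 256, 512, device="cuda", requires_grad=True)
bias = torch.zeros(512, device="cuda", requires_grad=True)

# Positional arguments mirror forward(ctx, x, p, bias, activation, activation_grad, trainable_bias)
y = _dropout.apply(x, 0.1, bias, None, None, True)
assert y.shape == x.shape  # the soft-flattened input is reshaped back on return
y.sum().backward()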