in xformers/triton/k_layer_norm.py [0:0]
def backward(ctx, dy):
    y, var, weight, bias = ctx.saved_tensors

    # heuristics for the number of parallel reduction streams for DG/DB
    N = y.size(-1)
    GROUP_SIZE_M = 64
    if N <= 8192:
        GROUP_SIZE_M = 96
    if N <= 4096:
        GROUP_SIZE_M = 128
    if N <= 1024:
        GROUP_SIZE_M = 256
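    # GROUP_SIZE_M sets the number of rows in the intermediate DW/DB buffers below:
    # the M input rows are folded into that many accumulation groups, so more groups
    # presumably mean less lock contention at the cost of a larger intermediate
    # buffer (hence fewer groups for larger feature sizes).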

    # flatten the batch dimension, if any.
    # We're interested in 'samples' x norm_dimension
    y = y.reshape(-1, y.size(-1))
    M, N = y.size()

    # allocate output
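    # `locks` backs the spin-lock scheme used by _layer_norm_bwd_dx_fused when it
    # accumulates its per-row DW/DB contributions into the shared _dw/_db buffers:
    # presumably one lock word plus one write-counter word per group, hence the
    # 2 * GROUP_SIZE_M int32 entries.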
    locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device="cuda")
    t_args = {"dtype": y.dtype, "device": y.device}
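    # _dw/_db hold the per-group partial sums; dw/db receive the final per-feature
    # gradients once the second kernel has reduced over the GROUP_SIZE_M rows.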
    _dw = torch.empty((GROUP_SIZE_M, y.size(-1)), **t_args)
    _db = torch.empty((GROUP_SIZE_M, y.size(-1)), **t_args)
    dw = torch.empty((y.size(-1),), **t_args)
    db = torch.empty((y.size(-1),), **t_args)
    dy = dy.contiguous()
    dx = torch.empty_like(dy)

    # Check the tensor shapes and layouts:
    # the kernel assumes that dx and y have the same size and are contiguous
    assert dx.numel() == y.numel(), \
        "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

    # enqueue kernel using forward pass heuristics
    # also compute partial sums for DW and DB
    # fmt: off
    meta = {"BLOCK_SIZE_N": ctx.BLOCK_SIZE_N,
            "GROUP_SIZE_M": GROUP_SIZE_M,
            "num_warps": ctx.num_warps}
    if weight is None:
        _layer_norm_no_affine_bwd[(M,)](dx, dy, y, var, y.stride(0), N, **meta)
        return dx, None, None, None

    _layer_norm_bwd_dx_fused[(M,)](
        dx, dy, _dw, _db,
        y, weight, bias, var,
        locks,
        y.stride(0),
        N,
        **meta
    )
    # fmt: on

    def grid(meta):
        return [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
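    # -> one program per BLOCK_SIZE_N slice of the feature dimension; each program
    #    reduces all GROUP_SIZE_M partial rows for its columns.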

    # accumulate partial sums in separate kernel
    # fmt: off
    _layer_norm_bwd_dwdb[grid](
        _dw, _db, dw, db,
        GROUP_SIZE_M,
        N,
        BLOCK_SIZE_M=32,
        BLOCK_SIZE_N=128
    )
    # fmt: on

    dx = dx.reshape_as(dy)
    return dx, dw, db, None
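
# Usage sketch (illustrative only, not part of the file): this backward is never
# called by hand. Assuming the enclosing torch.autograd.Function (named `_LayerNorm`
# here for illustration) is applied as something like
#
#     y = _LayerNorm.apply(x, weight, bias, eps)
#     y.sum().backward()
#
# autograd dispatches to the method above and fills x.grad / weight.grad / bias.grad
# with the dx / dw / db computed by the Triton kernels.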