in xformers/triton/k_layer_norm.py [0:0]
def forward(ctx, x, weight, bias, eps):
    # allocate output
    y = torch.empty_like(x)

    # reshape input data into 2D tensor
    x_arg = x.reshape(-1, x.shape[-1])
    M, N = x_arg.shape

    # allocate per-row mean and reciprocal standard deviation; they'll be used in the backward pass
    mean = torch.empty((M,), dtype=torch.float32, device="cuda")
    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
    # Less than 64KB per feature: enqueue fused kernel
    # (65536 // element_size is the largest row length, in elements, that fits within that budget)
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_SIZE_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    if not x_arg.is_contiguous() or not y.is_contiguous():
        global _triton_registered_warnings
        if not _triton_registered_warnings:
            logging.warning(
                "Non-contiguous input tensor found. Making it contiguous,"
                + " but could have perf or trainer implications"
            )
            _triton_registered_warnings = True

        x_arg = x_arg.contiguous()
        y = y.contiguous()
    # heuristic for the number of warps: roughly one warp per 256 elements, clamped to [1, 8]
    num_warps = min(max(BLOCK_SIZE_N // 256, 1), 8)

    # enqueue kernel
    # fmt: off
    if weight is None:
        _layer_norm_non_affine_fw[(M,)](
            x_arg, y, mean, rstd,
            x_arg.stride(0),
            N,
            eps,
            num_warps=num_warps,
            BLOCK_SIZE_N=BLOCK_SIZE_N
        )
    else:
        _layer_norm_fw[(M,)](
            x_arg, y, weight, bias, mean, rstd,
            x_arg.stride(0),
            N,
            eps,
            num_warps=num_warps,
            BLOCK_SIZE_N=BLOCK_SIZE_N
        )
    # fmt: on
    ctx.save_for_backward(y, rstd, weight, bias)
    ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
    ctx.num_warps = num_warps
    ctx.eps = eps
    ctx.N = N

    return y.reshape_as(x)
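

# Reference sketch added for clarity; it is not part of the original xformers file.
# It shows, in plain PyTorch, what the fused Triton forward above computes per row,
# and can serve as a correctness check for the Triton path. The name
# `reference_layer_norm` is illustrative only.
import torch  # normally already imported at the top of this module


def reference_layer_norm(x, weight, bias, eps):
    # normalize each row of the flattened (M, N) view in float32,
    # mirroring the fp32 mean / rstd buffers used by the kernel
    x_arg = x.reshape(-1, x.shape[-1]).float()
    mean = x_arg.mean(dim=-1, keepdim=True)
    var = x_arg.var(dim=-1, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    y = (x_arg - mean) * rstd
    if weight is not None:
        y = y * weight + bias
    return y.reshape_as(x).to(x.dtype)


# Usage note: compare against the fused output with a loose tolerance when x is fp16, e.g.
# torch.testing.assert_close(triton_y, reference_layer_norm(x, weight, bias, 1e-5), rtol=1e-2, atol=1e-2)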