# tzrec/ops/triton/triton_layer_norm.py
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We use the layer_norm ops from generative-recommenders as a starting point:
# https://github.com/facebookresearch/generative-recommenders
# Thanks to the authors for their public work.
from typing import List, Optional, Tuple
import torch
import triton
import triton.language as tl
from triton.runtime.autotuner import autotune as triton_autotune
from tzrec.ops.utils import switch_to_contiguous_if_needed
@triton.jit
def _layer_norm_fwd(
X,
Y,
Mean,
Rstd,
D,
eps,
stride_x,
stride_y,
TRAINING: tl.constexpr,
BLOCK_D: tl.constexpr,
COMPUTE_MEAN_AND_RSTD: tl.constexpr,
):
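    # One program normalizes one row of X (length D) without affine
    # parameters. When COMPUTE_MEAN_AND_RSTD is set, mean/rstd are computed
    # here (and cached to Mean/Rstd in TRAINING mode); otherwise they are
    # loaded from the provided Mean/Rstd buffers.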
row = tl.program_id(0)
X += row.to(tl.int64) * stride_x
Y += row.to(tl.int64) * stride_y
cols = tl.arange(0, BLOCK_D)
x = tl.load(X + cols, mask=cols < D, other=0.0).to(tl.float32)
if COMPUTE_MEAN_AND_RSTD:
mean = tl.sum(x, axis=0) / D
else:
mean = tl.load(Mean + row)
x_mean = tl.where(cols < D, x - mean, 0.0)
if COMPUTE_MEAN_AND_RSTD:
_var = tl.zeros([BLOCK_D], dtype=tl.float32)
_var += x_mean * x_mean
var = tl.sum(_var, axis=0) / D
rstd = 1 / tl.sqrt(var + eps)
if TRAINING:
tl.store(Mean + row, mean)
tl.store(Rstd + row, rstd)
else:
rstd = tl.load(Rstd + row)
# Normalize and apply linear transformation
mask = cols < D
y = x_mean * rstd
# Write output
tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)
@triton.jit
def _weighted_layer_norm_fwd(
X,
Y,
W,
B,
Mean,
Rstd,
D,
eps,
stride_x,
stride_y,
IS_SWISH: tl.constexpr,
TRAINING: tl.constexpr,
BLOCK_D: tl.constexpr,
COMPUTE_MEAN_AND_RSTD: tl.constexpr,
):
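    # Same per-row scheme as _layer_norm_fwd, but also applies the affine
    # transform y = xhat * W + B. When IS_SWISH is set, the output is gated
    # as sigmoid(y) * x.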
row = tl.program_id(0)
X += row.to(tl.int64) * stride_x
Y += row.to(tl.int64) * stride_y
cols = tl.arange(0, BLOCK_D)
x = tl.load(X + cols, mask=cols < D, other=0.0).to(tl.float32)
if COMPUTE_MEAN_AND_RSTD:
mean = tl.sum(x, axis=0) / D
else:
mean = tl.load(Mean + row)
x_mean = tl.where(cols < D, x - mean, 0.0)
if COMPUTE_MEAN_AND_RSTD:
_var = tl.zeros([BLOCK_D], dtype=tl.float32)
_var += x_mean * x_mean
var = tl.sum(_var, axis=0) / D
rstd = 1 / tl.sqrt(var + eps)
if TRAINING:
tl.store(Mean + row, mean)
tl.store(Rstd + row, rstd)
else:
rstd = tl.load(Rstd + row)
# Normalize and apply linear transformation
mask = cols < D
y = x_mean * rstd
w = tl.load(W + cols, mask=mask).to(tl.float32)
b = tl.load(B + cols, mask=mask).to(tl.float32)
y = y * w + b
if IS_SWISH:
y = tl.sigmoid(y) * x
# Write output
tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)
@triton.jit
def _layer_norm_bwd_dx(
DX,
DY,
X,
Mean,
Rstd,
stride_dx,
stride_dy,
stride_x,
D,
eps,
BLOCK_D: tl.constexpr,
):
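    # Per-row input gradient for the non-affine layer norm:
    #   xhat = (x - mean) * rstd
    #   dx = (dy - (xhat * mean(xhat * dy) + mean(dy))) * rstd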
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_D)
mask = cols < D
X += row.to(tl.int64) * stride_x
DY += row.to(tl.int64) * stride_dy
DX += row.to(tl.int64) * stride_dx
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = (x - mean) * rstd
xhat = tl.where(mask, xhat, 0.0)
dy = tl.where(mask, dy, 0.0)
c1 = tl.sum(xhat * dy, axis=0) / D
c2 = tl.sum(dy, axis=0) / D
dx = (dy - (xhat * c1 + c2)) * rstd
# Write dx
tl.store(DX + cols, dx, mask=mask)
@triton.jit
def _weighted_layer_norm_bwd_dx(
DX,
DY,
DW,
DB,
X,
W,
B,
Mean,
Rstd,
stride_dx,
stride_dy,
stride_x,
D,
eps,
IS_SWISH: tl.constexpr,
N,
BLOCK_D: tl.constexpr,
):
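    # Each program processes roughly N / tile_num rows (strided by tile_num).
    # It writes dx per row and accumulates per-program partial sums of dw/db
    # into row `pid` of the (tile_num, D) DW/DB buffers, which are reduced to
    # the final (D,) gradients by _layer_norm_bwd_dwdb afterwards.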
pid = tl.program_id(0)
tile_num = tl.num_programs(0)
rows_per_tile = N // tile_num
if pid < N % tile_num:
rows_per_tile += 1
cols = tl.arange(0, BLOCK_D)
mask = cols < D
row = pid
for idx in range(rows_per_tile):
x_ptrs = X + row.to(tl.int64) * stride_x
dy_ptrs = DY + row.to(tl.int64) * stride_dy
dx_ptrs = DX + row.to(tl.int64) * stride_dx
dw_ptrs = DW + pid.to(tl.int64) * D
dw_ptrs += cols
db_ptrs = DB + pid.to(tl.int64) * D
db_ptrs += cols
# Load data to SRAM
x = tl.load(x_ptrs + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(dy_ptrs + cols, mask=mask, other=0).to(tl.float32)
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = (x - mean) * rstd
w = tl.load(W + cols, mask=mask).to(tl.float32)
wdy = w * dy
xhat = tl.where(mask, xhat, 0.0)
wdy = tl.where(mask, wdy, 0.0)
sigmoid_layer_norm = None
if IS_SWISH:
b = tl.load(B + cols, mask=mask).to(tl.float32)
sigmoid_layer_norm = tl.sigmoid(xhat * w + b)
sigmoid_layer_norm = tl.where(mask, sigmoid_layer_norm, 0.0)
x_ = wdy * x * sigmoid_layer_norm * (1 - sigmoid_layer_norm)
x_ = tl.where(mask, x_, 0.0)
c1 = tl.sum(xhat * x_, axis=0) / D
c2 = tl.sum(x_, axis=0) / D
dx = (x_ - (xhat * c1 + c2)) * rstd
dx = dy * sigmoid_layer_norm + dx
else:
c1 = tl.sum(xhat * wdy, axis=0) / D
c2 = tl.sum(wdy, axis=0) / D
dx = (wdy - (xhat * c1 + c2)) * rstd
# Write dx
tl.store(dx_ptrs + cols, dx, mask=mask)
# Accumulate partial sums for dw/db
if IS_SWISH:
partial_dw = dy * x * xhat * sigmoid_layer_norm * (1 - sigmoid_layer_norm)
partial_db = dy * x * sigmoid_layer_norm * (1 - sigmoid_layer_norm)
else:
partial_dw = dy * xhat
partial_db = dy
# First store doesn't accumulate
if idx > 0:
partial_dw += tl.load(dw_ptrs, mask=mask)
partial_db += tl.load(db_ptrs, mask=mask)
tl.store(dw_ptrs, partial_dw, mask=mask)
tl.store(db_ptrs, partial_db, mask=mask)
row += tile_num
def _get_bwd_dwdb_configs() -> List[triton.Config]:
configs = []
for BLOCK_N in [32, 64, 128, 256]:
        for num_warps in [8, 16] + ([] if torch.version.hip else [32]):
configs.append(
triton.Config(
{"BLOCK_N": BLOCK_N},
num_warps=num_warps,
)
)
return configs
@triton_autotune(
configs=_get_bwd_dwdb_configs(),
key=["D"],
)
@triton.jit
def _layer_norm_bwd_dwdb(
DW,
DB,
FINAL_DW,
FINAL_DB,
N,
D,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
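    # Each program reduces a BLOCK_D-wide column slice of the (N, D) partial
    # dw/db buffers over all N rows (in chunks of BLOCK_N rows) and writes the
    # final (D,) weight/bias gradients.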
pid = tl.program_id(0)
cols = pid * BLOCK_D + tl.arange(0, BLOCK_D)
dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32)
db = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32)
for i in range(0, N, BLOCK_N):
rows = i + tl.arange(0, BLOCK_N)
# pyre-fixme[16]: `int` has no attribute `__getitem__`.
mask = (rows[:, None] < N) & (cols[None, :] < D)
offs = rows[:, None] * D + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.0)
db += tl.load(DB + offs, mask=mask, other=0.0)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D)
tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.dtype.element_ty), mask=cols < D)
def triton_weighted_layer_norm_fwd(
x: torch.Tensor,
weight: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
eps: float,
mean: Optional[torch.Tensor] = None,
rstd: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
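    """Layer norm forward over the last dim of a 2D input.

    If `weight` is given, `bias` must be given too and both must have D
    elements. When both `mean` and `rstd` are provided they are reused
    instead of being recomputed. Returns `(y, mean, rstd, BLOCK_D, num_warps)`,
    where the last two are the launch parameters to reuse in the backward.
    """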
assert x.dim() == 2, f"x.dim() == {x.dim()}, expected 2"
x = switch_to_contiguous_if_needed(x)
N, D = x.shape
learnable = weight is not None
if learnable:
assert bias is not None and weight is not None
assert weight.dim() == 1
assert bias.dim() == 1
assert weight.numel() == D
assert bias.numel() == D
y = torch.empty_like(x)
compute_mean_and_rstd = mean is None or rstd is None
if mean is None:
mean = torch.empty((N,), dtype=torch.float32, device=x.device)
if rstd is None:
rstd = torch.empty((N,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_D: int = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
if D > BLOCK_D:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
num_warps: int = min(max(BLOCK_D // 256, 1), 8)
if N == 0:
return y, mean, rstd, BLOCK_D, num_warps
if learnable:
# pyre-ignore[28]
_weighted_layer_norm_fwd[(N,)](
x,
y,
weight,
bias,
mean,
rstd,
D,
eps,
x.stride(0),
y.stride(0),
IS_SWISH=False,
TRAINING=True,
BLOCK_D=BLOCK_D,
COMPUTE_MEAN_AND_RSTD=compute_mean_and_rstd,
num_warps=num_warps,
)
else:
# pyre-ignore[28]
_layer_norm_fwd[(N,)](
x,
y,
mean,
rstd,
D,
eps,
x.stride(0),
y.stride(0),
TRAINING=True,
BLOCK_D=BLOCK_D,
COMPUTE_MEAN_AND_RSTD=compute_mean_and_rstd,
num_warps=num_warps,
)
return y, mean, rstd, BLOCK_D, num_warps
def triton_weighted_layer_norm_bwd(
dy: torch.Tensor,
x: torch.Tensor,
weight: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
mean: torch.Tensor,
rstd: torch.Tensor,
learnable: bool,
eps: float,
BLOCK_D: int,
num_warps: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
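    """Layer norm backward matching triton_weighted_layer_norm_fwd.

    `mean`, `rstd`, `BLOCK_D` and `num_warps` should come from the forward
    pass. Returns `(dx, dweight, dbias)`; the latter two are None when
    `learnable` is False.
    """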
if learnable:
assert weight is not None and bias is not None
N, D = x.shape
dx = torch.empty_like(x)
sms = torch.cuda.get_device_properties(x.device).multi_processor_count
tile_num = max(1, min(sms * 8, N // 4))
_dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device)
_dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device)
dweight = torch.empty((D,), dtype=weight.dtype, device=x.device)
dbias = torch.empty((D,), dtype=weight.dtype, device=x.device)
if N == 0:
dweight.zero_()
dbias.zero_()
return dx, dweight, dbias
# pyre-ignore[28]
_weighted_layer_norm_bwd_dx[(tile_num,)](
dx,
dy,
_dweight,
_dbias,
x,
weight,
bias,
mean,
rstd,
dx.stride(0),
dy.stride(0),
x.stride(0),
D,
eps,
IS_SWISH=False,
N=N,
BLOCK_D=BLOCK_D,
num_warps=num_warps,
)
def grid(META):
return (triton.cdiv(D, META["BLOCK_D"]),)
blocks = triton.next_power_of_2(sms * 4)
BLOCK_D = triton.next_power_of_2(triton.cdiv(D, blocks))
BLOCK_D = min(max(BLOCK_D, 4), 128)
_layer_norm_bwd_dwdb[grid](
_dweight,
_dbias,
dweight,
dbias,
tile_num,
D,
BLOCK_D=BLOCK_D,
)
return dx, dweight, dbias
else:
N, D = x.shape
dx = torch.empty_like(x)
if N == 0:
return dx, None, None
# pyre-ignore[28]
_layer_norm_bwd_dx[(N,)](
dx,
dy,
x,
mean,
rstd,
dx.stride(0),
dy.stride(0),
x.stride(0),
D,
eps,
BLOCK_D=BLOCK_D,
num_warps=num_warps,
)
return dx, None, None
class LayerNormFunction(torch.autograd.Function):
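    """Autograd function for (optionally affine) layer norm backed by the Triton kernels above."""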
@staticmethod
# pyre-ignore[14]
def forward(
ctx,
x: torch.Tensor,
weight: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
eps: float,
) -> torch.Tensor:
y, mean, rstd, BLOCK_D, num_warps = triton_weighted_layer_norm_fwd(
x=x,
weight=weight,
bias=bias,
eps=eps,
)
learnable = weight is not None
if learnable:
ctx.save_for_backward(x, weight, bias, mean, rstd)
else:
ctx.save_for_backward(x, mean, rstd)
ctx.BLOCK_D = BLOCK_D
ctx.num_warps = num_warps
ctx.eps = eps
ctx.learnable = learnable
return y
@staticmethod
# pyre-ignore[14]
def backward(
ctx, dy: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], None]:
if ctx.learnable:
x, weight, bias, mean, rstd = ctx.saved_tensors
else:
x, mean, rstd = ctx.saved_tensors
weight, bias = None, None
dx, dweight, dbias = triton_weighted_layer_norm_bwd(
dy=dy,
x=x,
weight=weight,
bias=bias,
mean=mean,
rstd=rstd,
learnable=ctx.learnable,
eps=ctx.eps,
BLOCK_D=ctx.BLOCK_D,
num_warps=ctx.num_warps,
)
return dx, dweight, dbias, None
@triton.jit
def _weighted_rms_norm_fwd(
X,
Y,
W,
Rstd,
D,
eps,
stride_x,
stride_y,
BLOCK_D: tl.constexpr,
):
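    # One program RMS-normalizes one row: rstd = 1 / sqrt(mean(x^2) + eps),
    # y = x * rstd * W. rstd is cached for the backward pass.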
row = tl.program_id(0)
X += row.to(tl.int64) * stride_x
Y += row.to(tl.int64) * stride_y
cols = tl.arange(0, BLOCK_D)
x = tl.load(X + cols, mask=cols < D, other=0.0).to(tl.float32)
# Compute variance
_var = tl.zeros([BLOCK_D], dtype=tl.float32)
x_mean = tl.where(cols < D, x, 0.0)
_var += x_mean * x_mean
var = tl.sum(_var, axis=0) / D
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < D
y = x_mean * rstd
w = tl.load(W + cols, mask=mask).to(tl.float32)
y = y * w
# Write output
tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)
@triton.jit
def _weighted_rms_norm_bwd_dx(
DX,
DY,
DW,
X,
W,
Rstd,
Lock,
stride_dx,
stride_dy,
stride_x,
D,
eps,
GROUP_N,
BLOCK_D: tl.constexpr,
):
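    # One program per row. dx is written directly; partial dw sums are
    # accumulated into one of GROUP_N rows of DW (chosen by row % GROUP_N).
    # A spin lock per group (with a first-write counter at Lock + GROUP_N)
    # serializes the read-modify-write of that row.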
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_D)
mask = cols < D
X += row.to(tl.int64) * stride_x
DY += row.to(tl.int64) * stride_dy
DX += row.to(tl.int64) * stride_dx
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = x * rstd
w = tl.load(W + cols, mask=mask).to(tl.float32)
wdy = w * dy
xhat = tl.where(mask, xhat, 0.0)
wdy = tl.where(mask, wdy, 0.0)
c1 = tl.sum(xhat * wdy, axis=0) / D
dx = (wdy - (xhat * c1)) * rstd
# Write dx
tl.store(DX + cols, dx, mask=mask)
# Offset locks and weights/biases gradient pointer for parallel reduction
lock_id = row % GROUP_N
Lock += lock_id
Count = Lock + GROUP_N
DW = DW + lock_id * D + cols
# Accumulate partial sums for dw/db
partial_dw = dy * xhat
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# First store doesn't accumulate
if count == 0:
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(DW, mask=mask)
tl.store(DW, partial_dw, mask=mask)
# Release the lock
tl.atomic_xchg(Lock, 0)
@triton_autotune(
configs=_get_bwd_dwdb_configs(),
key=["D"],
)
@triton.jit
def _rms_norm_bwd_dwdb(
DW,
FINAL_DW,
N,
D,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
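    # Column-blocked reduction of the (N, D) partial dw buffer to the final
    # (D,) weight gradient, analogous to _layer_norm_bwd_dwdb.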
pid = tl.program_id(0)
cols = pid * BLOCK_D + tl.arange(0, BLOCK_D)
dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32)
for i in range(0, N, BLOCK_N):
rows = i + tl.arange(0, BLOCK_N)
# pyre-fixme[16]: `int` has no attribute `__getitem__`.
mask = (rows[:, None] < N) & (cols[None, :] < D)
offs = rows[:, None] * D + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.0)
sum_dw = tl.sum(dw, axis=0)
tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D)
class RMSNormFunction(torch.autograd.Function):
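    """Autograd function for weighted RMS norm backed by the Triton kernels above."""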
@staticmethod
# pyre-ignore[14]
def forward(
ctx,
x: torch.Tensor,
weight: torch.Tensor,
eps: float,
) -> torch.Tensor:
assert x.dim() == 2
x = switch_to_contiguous_if_needed(x)
N, D = x.shape
assert weight.dim() == 1
assert weight.numel() == D
y = torch.empty_like(x)
rstd = torch.empty((N,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_D = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
if D > BLOCK_D:
            raise RuntimeError("This rms norm doesn't support feature dim >= 64KB.")
num_warps = min(max(BLOCK_D // 256, 1), 8)
# pyre-ignore[28]
_weighted_rms_norm_fwd[(N,)](
x,
y,
weight,
rstd,
D,
eps,
x.stride(0),
y.stride(0),
BLOCK_D=BLOCK_D,
num_warps=num_warps,
)
ctx.save_for_backward(x, weight, rstd)
ctx.BLOCK_D = BLOCK_D
ctx.num_warps = num_warps
ctx.eps = eps
return y
@staticmethod
# pyre-ignore[14]
def backward(
ctx, dy: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor], None]:
x, weight, rstd = ctx.saved_tensors
N, D = x.shape
dx = torch.empty_like(x)
if D <= 1024:
GROUP_N = 256 * 8
elif D <= 4096:
GROUP_N = 128 * 8
elif D <= 8192:
GROUP_N = 96 * 8
else:
GROUP_N = 64 * 8
GROUP_N = N if GROUP_N > N else GROUP_N
locks = torch.zeros(2 * GROUP_N, dtype=torch.int32, device=x.device)
_dweight = torch.empty((GROUP_N, D), dtype=torch.float32, device=x.device)
dweight = torch.empty((D,), dtype=weight.dtype, device=x.device)
# pyre-ignore[28]
_weighted_rms_norm_bwd_dx[(N,)](
dx,
dy,
_dweight,
x,
weight,
rstd,
locks,
dx.stride(0),
dy.stride(0),
x.stride(0),
D,
ctx.eps,
GROUP_N=GROUP_N,
BLOCK_D=ctx.BLOCK_D,
num_warps=ctx.num_warps,
)
def grid(META):
return (triton.cdiv(D, META["BLOCK_D"]),)
sms = torch.cuda.get_device_properties(x.device).multi_processor_count
blocks = triton.next_power_of_2(sms * 4)
BLOCK_D = triton.next_power_of_2(triton.cdiv(D, blocks))
BLOCK_D = min(max(BLOCK_D, 4), 128)
_rms_norm_bwd_dwdb[grid](
_dweight,
dweight,
GROUP_N,
D,
BLOCK_D=BLOCK_D,
)
return dx, dweight, None
class SwishLayerNormFunction(torch.autograd.Function):
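    """Autograd function for swish layer norm, sigmoid(layer_norm(x) * w + b) * x."""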
@staticmethod
# pyre-ignore[14]
def forward(
ctx,
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
) -> torch.Tensor:
assert x.dim() == 2, f"x.dim() == {x.dim()}, expected 2"
x = switch_to_contiguous_if_needed(x)
N, D = x.shape
assert bias is not None and weight is not None
assert weight.dim() == 1
assert bias.dim() == 1
assert weight.numel() == D
assert bias.numel() == D
y = torch.empty_like(x)
mean = torch.empty((N,), dtype=torch.float32, device=x.device)
rstd = torch.empty((N,), dtype=torch.float32, device=x.device)
BLOCK_D = triton.next_power_of_2(D)
num_warps = min(max(BLOCK_D // 256, 1), 8)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_D = BLOCK_D
ctx.num_warps = num_warps
ctx.eps = eps
if N == 0:
return y
# pyre-ignore[28]
_weighted_layer_norm_fwd[(N,)](
x,
y,
weight,
bias,
mean,
rstd,
D,
eps,
x.stride(0),
y.stride(0),
IS_SWISH=True,
TRAINING=True,
BLOCK_D=BLOCK_D,
COMPUTE_MEAN_AND_RSTD=True,
num_warps=num_warps,
)
return y
@staticmethod
# pyre-ignore[14]
def backward(
ctx, dy: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], None]:
x, weight, bias, mean, rstd = ctx.saved_tensors
N, D = x.shape
dx = torch.empty_like(x)
sms = torch.cuda.get_device_properties(x.device).multi_processor_count
tile_num = max(1, min(sms * 8, N // 4))
_dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device)
_dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device)
dweight = torch.empty((D,), dtype=weight.dtype, device=x.device)
dbias = torch.empty((D,), dtype=weight.dtype, device=x.device)
if N == 0:
dweight.zero_()
dbias.zero_()
return dx, dweight, dbias, None
# pyre-ignore[28]
_weighted_layer_norm_bwd_dx[(tile_num,)](
dx,
dy,
_dweight,
_dbias,
x,
weight,
bias,
mean,
rstd,
dx.stride(0),
dy.stride(0),
x.stride(0),
D,
ctx.eps,
IS_SWISH=True,
N=N,
BLOCK_D=ctx.BLOCK_D,
num_warps=ctx.num_warps,
)
def grid(META):
return (triton.cdiv(D, META["BLOCK_D"]),)
blocks = triton.next_power_of_2(sms * 4)
BLOCK_D = triton.next_power_of_2(triton.cdiv(D, blocks))
BLOCK_D = min(max(BLOCK_D, 4), 128)
_layer_norm_bwd_dwdb[grid](
_dweight,
_dbias,
dweight,
dbias,
tile_num,
D,
BLOCK_D=BLOCK_D,
)
return dx, dweight, dbias, None
@torch.fx.wrap
def triton_layer_norm(
x: torch.Tensor,
weight: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
eps: float,
) -> torch.Tensor:
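    """Layer norm over the last dim of a 2D input via Triton kernels.

    If `weight` is provided, `bias` must be provided as well (each with
    `x.shape[-1]` elements); with `weight=None` the non-affine kernel is used.
    """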
return LayerNormFunction.apply(x, weight, bias, eps)
@torch.fx.wrap
def triton_rms_norm(
x: torch.Tensor,
weight: Optional[torch.Tensor],
eps: float,
) -> torch.Tensor:
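    """RMS norm over the last dim of a 2D input via Triton kernels.

    Despite the Optional annotation, `weight` is required by the underlying
    autograd function and must have `x.shape[-1]` elements.
    """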
return RMSNormFunction.apply(x, weight, eps)
@torch.fx.wrap
def triton_swish_layer_norm(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
eps: float,
) -> torch.Tensor:
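    """Swish layer norm: sigmoid(layer_norm(x) * weight + bias) * x.

    `x` must be 2D and is normalized over its last dim; `normalized_shape`
    is accepted but not used by this implementation.
    """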
return SwishLayerNormFunction.apply(x, weight, bias, eps)
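
# The snippet below is a minimal, illustrative usage sketch of the public
# wrappers above, not part of the library API. It assumes a CUDA/ROCm device
# is available; shapes, dtypes and eps are arbitrary example values.
if __name__ == "__main__":
    if torch.cuda.is_available():
        N, D = 128, 256
        x = torch.randn(N, D, device="cuda", requires_grad=True)
        w = torch.ones(D, device="cuda", requires_grad=True)
        b = torch.zeros(D, device="cuda", requires_grad=True)
        # Affine layer norm with autograd support.
        y = triton_layer_norm(x, w, b, eps=1e-5)
        y.sum().backward()
        # RMS norm (weight only) and swish layer norm take the same inputs;
        # normalized_shape is not used by the implementation.
        y_rms = triton_rms_norm(x, w, eps=1e-5)
        y_swish = triton_swish_layer_norm(x, [D], w, b, eps=1e-5)
        print(y.shape, y_rms.shape, y_swish.shape)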