tzrec/ops/triton/triton_hstu_linear.py (1,186 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We use the hstu_linear ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# thanks to their public work.

"""Fused Triton kernels for the HSTU output path.

Provides ``dropout(norm(x) * u)`` with either a per-row layer norm or a
per-head group norm. Optionally the output is the concatenation
``[u, x, y]``. The module also provides the full HSTU output projection
``x + dropout(norm(attn) * u) @ output_weight``, with hand-written
forward/backward kernels.
"""

from typing import List, Optional, Tuple

import torch
import triton
import triton.language as tl
from triton.runtime.autotuner import autotune as triton_autotune

from tzrec.ops.triton.triton_addmm import triton_addmm_fwd
from tzrec.ops.utils import (
    switch_to_contiguous_if_needed,
)


# Forward of ``dropout(layer_norm(x) * u)``, one program per row.
#
# Each program normalises one row of X (D valid features, padded to the
# power-of-two BLOCK_D), applies the per-feature affine (W, B), multiplies by
# the matching row of U and, when TRAINING, applies dropout. The per-row mean
# and rstd are written out for the backward pass. With CONCAT_UX the output
# row is ``[u, x, y]`` (3 * D wide) and each part gets its own dropout mask.
@triton.jit
def _ln_mul_dropout_fwd(
    X,
    U,
    Y,
    W,
    B,
    Mean,
    Rstd,
    D,
    eps,
    seed,
    dropout_ratio,
    stride_x,
    stride_u,
    stride_y,
    BLOCK_D: tl.constexpr,
    TRAINING: tl.constexpr,
    CONCAT_UX: tl.constexpr,
):
    row = tl.program_id(0)
    X += row.to(tl.int64) * stride_x
    U += row.to(tl.int64) * stride_u
    Y += row.to(tl.int64) * stride_y
    cols = tl.arange(0, BLOCK_D)
    mask = cols < D

    # Compute mean (padded lanes load as 0 and do not affect the sum).
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    mean = tl.sum(x, axis=0) / D

    # Compute variance; padded lanes are zeroed so they do not contribute.
    _var = tl.zeros([BLOCK_D], dtype=tl.float32)
    x_mean = tl.where(mask, x - mean, 0.0)
    _var += x_mean * x_mean
    var = tl.sum(_var, axis=0) / D
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)

    # Normalize and apply linear transformation
    y = x_mean * rstd
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    b = tl.load(B + cols, mask=mask).to(tl.float32)
    y = y * w + b
    u = tl.load(U + cols, mask=mask, other=0.0).to(tl.float32)
    y = y * u

    if TRAINING:
        if CONCAT_UX:
            # Each row consumes offsets [0, 3 * D) relative to its base (u, x
            # and y masks). Rows are spaced 3 * BLOCK_D apart so neighbouring
            # rows never reuse each other's random numbers; spacing them only
            # BLOCK_D apart would correlate their masks. The backward kernel
            # must use exactly the same offsets.
            random_offsets = 3 * row * BLOCK_D + cols
            # apply dropout on u
            random_u = tl.rand(seed, random_offsets)
            u_keep = random_u > dropout_ratio
            u = tl.where(u_keep, u / (1.0 - dropout_ratio), 0.0)
            # apply dropout on x
            random_x = tl.rand(seed, random_offsets + D)
            x_keep = random_x > dropout_ratio
            x = tl.where(x_keep, x / (1.0 - dropout_ratio), 0.0)
            # apply dropout on y
            random_y = tl.rand(seed, random_offsets + 2 * D)
            y_keep = random_y > dropout_ratio
            y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0)
        else:
            random_offsets = row * BLOCK_D + cols
            random = tl.rand(seed, random_offsets)
            y_keep = random > dropout_ratio
            # write-back
            y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0)

    # Write output
    if CONCAT_UX:
        tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask)
        tl.store(Y + D + cols, x.to(Y.dtype.element_ty), mask=mask)
        tl.store(Y + 2 * D + cols, y.to(Y.dtype.element_ty), mask=mask)
    else:
        tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)


# Backward (dx, du and partial dw/db) of ``dropout(layer_norm(x) * u)``.
#
# The launch uses ``tile_num`` programs. Program ``pid`` processes rows pid,
# pid + tile_num, pid + 2 * tile_num, ... and accumulates its per-feature
# partial dW/dB in row ``pid`` of the (tile_num, D) buffers DW/DB. Those are
# reduced afterwards by _ln_mul_dropout_bwd_dwdb. Dropout masks are
# regenerated from ``seed`` with the same offsets as the forward kernel.
# With COMPUTE_Y the forward output is recomputed and written to Y as well.
@triton.jit
def _ln_mul_dropout_bwd_dx_du(
    DX,
    DU,
    DY,
    DW,
    DB,
    X,
    U,
    Y,
    W,
    B,
    Mean,
    Rstd,
    stride_dx,
    stride_du,
    stride_dy,
    stride_x,
    stride_u,
    stride_y,
    D,
    eps,
    seed,
    dropout_ratio,
    N,
    BLOCK_D: tl.constexpr,
    TRAINING: tl.constexpr,
    CONCAT_UX: tl.constexpr,
    COMPUTE_Y: tl.constexpr,
):
    pid = tl.program_id(0)
    tile_num = tl.num_programs(0)
    # Distribute N rows across tile_num programs; the first N % tile_num
    # programs take one extra row.
    rows_per_tile = N // tile_num
    if pid < N % tile_num:
        rows_per_tile += 1

    if rows_per_tile == 0:
        return

    cols = tl.arange(0, BLOCK_D)
    mask = cols < D

    row = pid
    X += row.to(tl.int64) * stride_x
    U += row.to(tl.int64) * stride_u
    if COMPUTE_Y:
        Y += row.to(tl.int64) * stride_y
    DY += row.to(tl.int64) * stride_dy
    DX += row.to(tl.int64) * stride_dx
    DU += row.to(tl.int64) * stride_du
    DW = DW + pid * D + cols
    DB = DB + pid * D + cols

    for idx in range(0, rows_per_tile):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        if CONCAT_UX:
            # Incoming gradient is [du, dx, dy] for the [u, x, y] output.
            du = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
            dx = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32)
            dy = tl.load(DY + 2 * D + cols, mask=mask, other=0).to(tl.float32)
        else:
            du = tl.zeros([BLOCK_D], dtype=tl.float32)
            dx = tl.zeros([BLOCK_D], dtype=tl.float32)
            dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)

        if TRAINING:
            if CONCAT_UX:
                # Must mirror the offsets used by _ln_mul_dropout_fwd.
                random_offsets = 3 * row * BLOCK_D + cols
                # apply dropout on du
                random_du = tl.rand(seed, random_offsets)
                du_keep = random_du > dropout_ratio
                du = tl.where(du_keep, du / (1.0 - dropout_ratio), 0.0)
                # apply dropout on dx
                random_dx = tl.rand(seed, random_offsets + D)
                dx_keep = random_dx > dropout_ratio
                dx = tl.where(dx_keep, dx / (1.0 - dropout_ratio), 0.0)
                # apply dropout on dy
                random_dy = tl.rand(seed, random_offsets + 2 * D)
                dy_keep = random_dy > dropout_ratio
                dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0)
            else:
                random_offsets = row * BLOCK_D + cols
                random = tl.rand(seed, random_offsets)
                dy_keep = random > dropout_ratio
                # write-back
                dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0)

        mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)

        # Compute du: y = ln(x) * u  =>  du += dy * ln(x)
        xhat = (x - mean) * rstd
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        b = tl.load(B + cols, mask=mask).to(tl.float32)
        ln = xhat * w + b
        du += dy * ln
        tl.store(DU + cols, du.to(DU.dtype.element_ty), mask=mask)
        u = tl.load(U + cols, mask=mask, other=0).to(tl.float32)
        # From here on dy is the gradient w.r.t. ln(x).
        dy = dy * u
        wdy = w * dy

        if COMPUTE_Y:
            y = ln * u
            if TRAINING:
                if CONCAT_UX:
                    u = tl.where(
                        du_keep,  # pyre-ignore [61]
                        u / (1.0 - dropout_ratio),
                        0.0,
                    )
                    x = tl.where(
                        dx_keep,  # pyre-ignore [61]
                        x / (1.0 - dropout_ratio),
                        0.0,
                    )
                    y = tl.where(
                        dy_keep,  # pyre-ignore [61]
                        y / (1.0 - dropout_ratio),
                        0.0,
                    )
                else:
                    y = tl.where(
                        dy_keep,  # pyre-ignore [61]
                        y / (1.0 - dropout_ratio),
                        0.0,
                    )
            if CONCAT_UX:
                tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask)
                tl.store(Y + D + cols, x.to(Y.dtype.element_ty), mask=mask)
                tl.store(Y + 2 * D + cols, y.to(Y.dtype.element_ty), mask=mask)
            else:
                tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)
            Y += tile_num.to(tl.int64) * stride_y

        # Compute dx (standard layer norm backward; padded lanes zeroed).
        xhat = tl.where(mask, xhat, 0.0)
        wdy = tl.where(mask, wdy, 0.0)
        c1 = tl.sum(xhat * wdy, axis=0) / D
        c2 = tl.sum(wdy, axis=0) / D
        dx += (wdy - (xhat * c1 + c2)) * rstd
        # Write dx
        tl.store(DX + cols, dx.to(DX.dtype.element_ty), mask=mask)

        # Accumulate partial sums for dw/db
        partial_dw = dy * xhat
        partial_db = dy
        # First store doesn't accumulate (the buffers are uninitialized).
        if idx > 0:
            partial_dw += tl.load(DW, mask=mask)
            partial_db += tl.load(DB, mask=mask)
        tl.store(DW, partial_dw, mask=mask)
        tl.store(DB, partial_db, mask=mask)

        # Advance to this program's next row.
        X += tile_num.to(tl.int64) * stride_x
        U += tile_num.to(tl.int64) * stride_u
        DY += tile_num.to(tl.int64) * stride_dy
        DX += tile_num.to(tl.int64) * stride_dx
        DU += tile_num.to(tl.int64) * stride_du
        row += tile_num


def _get_bwd_dwdb_configs() -> List[triton.Config]:
    """Autotune configs shared by the dW/dB reduction kernels.

    32-warp configs are only offered on non-ROCm builds. ``torch.version.hip``
    is None on CUDA builds. The previous check on ``torch.ops.hip`` was an
    always-truthy op namespace, so it silently dropped these configs
    everywhere.
    """
    warp_choices = [8, 16] + ([] if torch.version.hip else [32])
    return [
        triton.Config({"BLOCK_N": block_n}, num_warps=num_warps)
        for block_n in [32, 64, 128, 256]
        for num_warps in warp_choices
    ]


# Column-parallel reduction of the (N, D) partial dW/dB buffers into the
# final (D,) gradients. Each program owns BLOCK_D columns and walks the N
# rows in chunks of BLOCK_N.
@triton_autotune(
    configs=_get_bwd_dwdb_configs(),
    key=["D"],
)
@triton.jit
def _ln_mul_dropout_bwd_dwdb(
    DW,
    DB,
    FINAL_DW,
    FINAL_DB,
    N,
    D,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid = tl.program_id(0)
    cols = pid * BLOCK_D + tl.arange(0, BLOCK_D)
    dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32)
    db = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32)

    for i in range(0, N, BLOCK_N):
        rows = i + tl.arange(0, BLOCK_N)
        # pyre-fixme[16]: `int` has no attribute `__getitem__`.
        mask = (rows[:, None] < N) & (cols[None, :] < D)
        offs = rows[:, None] * D + cols[None, :]
        dw += tl.load(DW + offs, mask=mask, other=0.0)
        db += tl.load(DB + offs, mask=mask, other=0.0)

    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)
    tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D)
    tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.dtype.element_ty), mask=cols < D)


def triton_layer_norm_mul_dropout_fwd(
    x: torch.Tensor,
    u: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    dropout_ratio: float,
    training: bool,
    concat_ux: bool = False,
    seed: Optional[int] = None,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, int, int, int
]:  # y, mean, rstd, BLOCK_D, num_warps, seed
    """Launch the fused ``dropout(layer_norm(x) * u)`` forward kernel.

    Args:
        x: 2-D input of shape (N, D), normalised over its last dim.
        u: multiplicative gate with the same shape as ``x``.
        weight: layer norm scale of shape (D,).
        bias: layer norm shift of shape (D,).
        eps: epsilon added to the variance.
        dropout_ratio: dropout probability (used only when ``training``).
        training: whether dropout is applied.
        concat_ux: when True the output is ``[u, x, y]`` of shape (N, 3 * D).
        seed: Philox seed for dropout; drawn at random when None.

    Returns:
        ``(y, mean, rstd, BLOCK_D, num_warps, seed)``. ``mean``/``rstd`` are
        float32 of shape (N,). The launch parameters and seed are returned so
        that backward can replay the same dropout masks.

    Raises:
        RuntimeError: if a row does not fit into a single fused block.
    """
    assert x.dim() == 2
    assert x.shape == u.shape
    x = switch_to_contiguous_if_needed(x)
    u = switch_to_contiguous_if_needed(u)
    N, D = x.shape
    assert weight.dim() == 1
    assert bias.dim() == 1
    assert weight.numel() == D
    assert bias.numel() == D

    if concat_ux:
        y = torch.empty((N, 3 * D), dtype=x.dtype, device=x.device)
    else:
        y = torch.empty_like(x)
    mean = torch.empty((N,), dtype=torch.float32, device=x.device)
    rstd = torch.empty((N,), dtype=torch.float32, device=x.device)
    if N == 0:
        return y, mean, rstd, 0, 0, 0

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_D: int = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BLOCK_D:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    if seed is None:
        seed = torch.randint(low=0, high=2**62, size=(1,), dtype=torch.int64).item()
    num_warps: int = min(max(BLOCK_D // 256, 1), 8)
    # pyre-ignore[28]
    _ln_mul_dropout_fwd[(N,)](
        x,
        u,
        y,
        weight,
        bias,
        mean,
        rstd,
        D,
        eps,
        seed,
        dropout_ratio,
        x.stride(0),
        u.stride(0),
        y.stride(0),
        BLOCK_D=BLOCK_D,
        TRAINING=training,
        CONCAT_UX=concat_ux,
        num_warps=num_warps,
    )
    return y, mean, rstd, BLOCK_D, num_warps, seed  # pyre-ignore [7]


def triton_layer_norm_mul_dropout_bwd(
    dy: torch.Tensor,
    x: torch.Tensor,
    u: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    mean: torch.Tensor,
    rstd: torch.Tensor,
    BLOCK_D: int,
    num_warps: int,
    eps: float,
    training: bool,
    dropout_ratio: float,
    seed: Optional[int] = None,
    concat_ux: bool = False,
    compute_y: bool = False,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]
]:
    """Backward of :func:`triton_layer_norm_mul_dropout_fwd`.

    ``BLOCK_D``, ``num_warps`` and ``seed`` must be the values returned by the
    forward call so that the dropout masks are regenerated identically.

    Returns:
        ``(dx, du, dweight, dbias, y)``. ``y`` is the recomputed forward
        output when ``compute_y`` is True, otherwise None.
    """
    y = None
    N, D = x.shape
    if compute_y:
        if concat_ux:
            y = torch.empty((N, 3 * D), dtype=x.dtype, device=x.device)
        else:
            y = torch.empty_like(x)
    if N == 0:
        return (
            torch.zeros_like(x),
            torch.zeros_like(u),
            torch.zeros((D,), dtype=weight.dtype, device=x.device),
            torch.zeros((D,), dtype=weight.dtype, device=x.device),
            y,
        )

    # The kernel addresses rows via stride(0) but assumes unit stride along
    # the feature dim. Incoming gradients are not guaranteed to satisfy that
    # (e.g. expanded gradients of a reduction), so enforce it here.
    dy = switch_to_contiguous_if_needed(dy)
    x = switch_to_contiguous_if_needed(x)
    u = switch_to_contiguous_if_needed(u)
    dx = torch.empty_like(x)
    du = torch.empty_like(u)

    sms = torch.cuda.get_device_properties(x.device).multi_processor_count
    # Number of persistent programs; each writes one row of partial dW/dB.
    tile_num = max(1, min(sms * 64, N // 4))
    _dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device)
    _dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device)
    dweight = torch.empty((D,), dtype=weight.dtype, device=x.device)
    dbias = torch.empty((D,), dtype=weight.dtype, device=x.device)
    # pyre-ignore[28]
    _ln_mul_dropout_bwd_dx_du[(tile_num,)](
        dx,
        du,
        dy,
        _dweight,
        _dbias,
        x,
        u,
        y,
        weight,
        bias,
        mean,
        rstd,
        dx.stride(0),
        du.stride(0),
        dy.stride(0),
        x.stride(0),
        u.stride(0),
        y.stride(0) if compute_y else 0,  # pyre-ignore [16]
        D,
        eps,
        seed,
        dropout_ratio,
        N=N,
        BLOCK_D=BLOCK_D,
        TRAINING=training,
        CONCAT_UX=concat_ux,
        COMPUTE_Y=compute_y,
        num_warps=num_warps,
    )

    def grid(META):
        return (triton.cdiv(D, META["BLOCK_D"]),)

    # Column tile width for the dW/dB reduction (independent of the
    # forward BLOCK_D): aim for ~4 column tiles per SM, clamped to [4, 128].
    blocks = triton.next_power_of_2(sms * 4)
    dwdb_block_d = triton.next_power_of_2(triton.cdiv(D, blocks))
    dwdb_block_d = min(max(dwdb_block_d, 4), 128)
    _ln_mul_dropout_bwd_dwdb[grid](
        _dweight,
        _dbias,
        dweight,
        dbias,
        tile_num,
        D,
        BLOCK_D=dwdb_block_d,
    )
    return dx, du, dweight, dbias, y


class LayerNormMulDropoutFunction(torch.autograd.Function):
    """Autograd wrapper for the fused ``dropout(layer_norm(x) * u)`` kernels."""

    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        x: torch.Tensor,
        u: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        eps: float,
        dropout_ratio: float,
        training: bool,
        concat_ux: bool = False,
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        y, mean, rstd, BLOCK_D, num_warps, seed = triton_layer_norm_mul_dropout_fwd(
            x=x,
            u=u,
            weight=weight,
            bias=bias,
            eps=eps,
            dropout_ratio=dropout_ratio,
            training=training,
            concat_ux=concat_ux,
            seed=seed,
        )
        ctx.save_for_backward(x, u, weight, bias, mean, rstd)
        ctx.BLOCK_D = BLOCK_D
        ctx.num_warps = num_warps
        ctx.eps = eps
        # The resolved seed is kept so backward replays the same masks.
        ctx.seed = seed
        ctx.training = training
        ctx.concat_ux = concat_ux
        ctx.dropout_ratio = dropout_ratio
        return y

    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx, dy: torch.Tensor
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        None,
        None,
        None,
        None,
        None,
    ]:
        x, u, weight, bias, mean, rstd = ctx.saved_tensors
        dx, du, dweight, dbias, _ = triton_layer_norm_mul_dropout_bwd(
            dy=dy,
            x=x,
            u=u,
            weight=weight,
            bias=bias,
            mean=mean,
            rstd=rstd,
            BLOCK_D=ctx.BLOCK_D,
            num_warps=ctx.num_warps,
            eps=ctx.eps,
            training=ctx.training,
            dropout_ratio=ctx.dropout_ratio,
            seed=ctx.seed,
            concat_ux=ctx.concat_ux,
            compute_y=False,
        )
        return dx, du, dweight, dbias, None, None, None, None, None


# Forward of ``dropout(group_norm(x) * u)``, one program per row.
#
# A row of X is viewed as Heads groups of D contiguous features. Each group
# is normalised independently and scaled/shifted by a per-head scalar
# W[h]/B[h]. Per-(row, head) mean/rstd are stored at ``row * Heads + h``.
# With CONCAT_UX the output row is ``[u, x, y]`` (3 * Heads * D wide).
@triton.jit
def _group_norm_mul_dropout_fwd(
    X,
    U,
    Y,
    W,
    B,
    Mean,
    Rstd,
    D,
    Heads,
    eps,
    seed,
    dropout_ratio,
    stride_x,
    stride_u,
    stride_y,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    TRAINING: tl.constexpr,
    CONCAT_UX: tl.constexpr,
):
    row = tl.program_id(0)
    X += row.to(tl.int64) * stride_x
    U += row.to(tl.int64) * stride_u
    Y += row.to(tl.int64) * stride_y
    cols = tl.arange(0, BLOCK_D)
    heads = tl.arange(0, BLOCK_H)
    offsets = heads[:, None] * D + cols[None, :]
    mask_h = heads < Heads
    mask_c = cols < D
    mask = mask_c[None, :] & mask_h[:, None]

    # Compute per-head mean
    x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32)
    mean = tl.sum(x, axis=1) / D
    mean = tl.ravel(mean)

    # Compute per-head variance; padded lanes are zeroed.
    _var = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)
    x_mean = tl.where(mask, x - mean[:, None], 0.0)
    _var += x_mean * x_mean
    var = tl.sum(_var, axis=1) / D
    var = tl.ravel(var)
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Mean + row * Heads + heads, mean, mask=mask_h)
    tl.store(Rstd + row * Heads + heads, rstd, mask=mask_h)

    # Normalize and apply linear transformation
    y = x_mean * rstd[:, None]  # pyre-ignore [16]
    w = tl.load(W + heads, mask=mask_h).to(tl.float32)
    b = tl.load(B + heads, mask=mask_h).to(tl.float32)
    y = y * w[:, None] + b[:, None]
    u = tl.load(U + offsets, mask=mask, other=0.0).to(tl.float32)
    y = y * u

    if TRAINING:
        if CONCAT_UX:
            # Rows are spaced 3 * Heads * D apart so masks never overlap.
            random_offsets = row * 3 * D * Heads + offsets
            # apply dropout on u
            random_u = tl.rand(seed, random_offsets)
            u_keep = random_u > dropout_ratio
            u = tl.where(u_keep, u / (1.0 - dropout_ratio), 0.0)
            # apply dropout on x
            random_x = tl.rand(seed, random_offsets + Heads * D)
            x_keep = random_x > dropout_ratio
            x = tl.where(x_keep, x / (1.0 - dropout_ratio), 0.0)
            # apply dropout on y
            random_y = tl.rand(seed, random_offsets + 2 * Heads * D)
            y_keep = random_y > dropout_ratio
            y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0)
        else:
            random_offsets = row * D * Heads + offsets
            random = tl.rand(seed, random_offsets)
            y_keep = random > dropout_ratio
            # write-back
            y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0)

    # Write output
    if CONCAT_UX:
        tl.store(Y + offsets, u.to(Y.dtype.element_ty), mask=mask)
        tl.store(Y + Heads * D + offsets, x.to(Y.dtype.element_ty), mask=mask)
        tl.store(Y + 2 * Heads * D + offsets, y.to(Y.dtype.element_ty), mask=mask)
    else:
        tl.store(Y + offsets, y.to(Y.dtype.element_ty), mask=mask)


# Backward (dx, du and partial dw/db) of ``dropout(group_norm(x) * u)``.
#
# One program per row. Per-head partial dW/dB are atomically added into row
# ``row % GROUP_N`` of the (GROUP_N, Heads) buffers DW/DB. Those buffers are
# zero-initialised by the host and reduced by _group_norm_bwd_dwdb.
@triton.jit
def _group_norm_mul_dropout_bwd_dx_du(
    DX,
    DU,
    DY,
    DW,
    DB,
    X,
    U,
    Y,
    W,
    B,
    Mean,
    Rstd,
    stride_dx,
    stride_du,
    stride_dy,
    stride_x,
    stride_u,
    stride_y,
    D,
    Heads,
    eps,
    seed,
    dropout_ratio,
    GROUP_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    TRAINING: tl.constexpr,
    CONCAT_UX: tl.constexpr,
    COMPUTE_Y: tl.constexpr,
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_D)
    off_heads = tl.arange(0, BLOCK_H)
    mask_c = cols < D
    mask_h = off_heads < Heads
    mask = mask_c[None, :] & mask_h[:, None]
    X += row.to(tl.int64) * stride_x
    U += row.to(tl.int64) * stride_u
    DY += row.to(tl.int64) * stride_dy
    DX += row.to(tl.int64) * stride_dx
    DU += row.to(tl.int64) * stride_du
    offsets = off_heads[:, None] * D + cols[None, :]

    # Load data to SRAM
    x = tl.load(X + offsets, mask=mask, other=0).to(tl.float32)
    if CONCAT_UX:
        # Incoming gradient is [du, dx, dy] for the [u, x, y] output.
        du = tl.load(DY + offsets, mask=mask, other=0).to(tl.float32)
        dx = tl.load(DY + Heads * D + offsets, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + 2 * Heads * D + offsets, mask=mask, other=0).to(tl.float32)
    else:
        du = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)
        dx = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)
        dy = tl.load(DY + offsets, mask=mask, other=0).to(tl.float32)

    if TRAINING:
        if CONCAT_UX:
            # Must mirror the offsets used by _group_norm_mul_dropout_fwd.
            random_offsets = row * 3 * D * Heads + offsets
            # apply dropout on du
            random_du = tl.rand(seed, random_offsets)
            du_keep = random_du > dropout_ratio
            du = tl.where(du_keep, du / (1.0 - dropout_ratio), 0.0)
            # apply dropout on dx
            random_dx = tl.rand(seed, random_offsets + Heads * D)
            dx_keep = random_dx > dropout_ratio
            dx = tl.where(dx_keep, dx / (1.0 - dropout_ratio), 0.0)
            # apply dropout on dy
            random_dy = tl.rand(seed, random_offsets + 2 * Heads * D)
            dy_keep = random_dy > dropout_ratio
            dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0)
        else:
            random_offsets = row * D * Heads + offsets
            random = tl.rand(seed, random_offsets)
            dy_keep = random > dropout_ratio
            # write-back
            dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0)

    # Mean/Rstd hold exactly N * Heads values. Mask the padded heads
    # (BLOCK_H > Heads) so the last row does not read out of bounds.
    mean = tl.load(Mean + row * Heads + off_heads, mask=mask_h, other=0.0)
    rstd = tl.load(Rstd + row * Heads + off_heads, mask=mask_h, other=0.0)

    # Compute du: y = gn(x) * u  =>  du += dy * gn(x)
    xhat = (x - mean[:, None]) * rstd[:, None]
    w = tl.load(W + off_heads, mask=mask_h).to(tl.float32)
    b = tl.load(B + off_heads, mask=mask_h).to(tl.float32)
    ln = xhat * w[:, None] + b[:, None]
    du += dy * ln
    tl.store(DU + offsets, du.to(DU.dtype.element_ty), mask=mask)
    u = tl.load(U + offsets, mask=mask, other=0).to(tl.float32)
    # From here on dy is the gradient w.r.t. gn(x).
    dy = dy * u
    wdy = w[:, None] * dy

    if COMPUTE_Y:
        Y += row.to(tl.int64) * stride_y
        y = ln * u
        if TRAINING:
            if CONCAT_UX:
                u = tl.where(
                    du_keep,  # pyre-ignore [61]
                    u / (1.0 - dropout_ratio),
                    0.0,
                )
                x = tl.where(
                    dx_keep,  # pyre-ignore [61]
                    x / (1.0 - dropout_ratio),
                    0.0,
                )
                y = tl.where(
                    dy_keep,  # pyre-ignore [61]
                    y / (1.0 - dropout_ratio),
                    0.0,
                )
            else:
                y = tl.where(
                    dy_keep,  # pyre-ignore [61]
                    y / (1.0 - dropout_ratio),
                    0.0,
                )
        if CONCAT_UX:
            tl.store(Y + offsets, u.to(Y.dtype.element_ty), mask=mask)
            tl.store(Y + Heads * D + offsets, x.to(Y.dtype.element_ty), mask=mask)
            tl.store(Y + 2 * Heads * D + offsets, y.to(Y.dtype.element_ty), mask=mask)
        else:
            tl.store(Y + offsets, y.to(Y.dtype.element_ty), mask=mask)

    # Compute dx (per-head normalisation backward; padded lanes zeroed).
    xhat = tl.where(mask, xhat, 0.0)
    wdy = tl.where(mask, wdy, 0.0)
    c1 = tl.sum(xhat * wdy, axis=1) / D
    c2 = tl.sum(wdy, axis=1) / D
    dx += (wdy - (xhat * c1[:, None] + c2[:, None])) * rstd[:, None]
    # Write dx
    tl.store(DX + offsets, dx.to(DX.dtype.element_ty), mask=mask)

    # Offset locks and weights/biases gradient pointer for parallel reduction
    lock_id = row % GROUP_N
    DW = DW + lock_id * Heads + off_heads
    DB = DB + lock_id * Heads + off_heads
    # Accumulate partial sums for dw/db
    partial_dw = tl.sum(dy * xhat, axis=1)
    partial_dw = tl.ravel(partial_dw)
    partial_db = tl.sum(dy, axis=1)
    partial_db = tl.ravel(partial_db)
    tl.atomic_add(
        DW,
        partial_dw,
        mask=mask_h,
        sem="relaxed",
    )
    tl.atomic_add(
        DB,
        partial_db,
        mask=mask_h,
        sem="relaxed",
    )


def triton_group_norm_mul_dropout_fwd(
    x: torch.Tensor,
    u: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    dropout_ratio: float,
    training: bool,
    concat_ux: bool = False,
    num_heads: int = 1,
    linear_dim: int = -1,
    seed: Optional[int] = None,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, int
]:  # y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed
    """Launch the fused ``dropout(group_norm(x) * u)`` forward kernel.

    Args:
        x: 2-D input of shape (N, num_heads * linear_dim); each row is split
            into ``num_heads`` groups of ``linear_dim`` features.
        u: multiplicative gate with the same shape as ``x``.
        weight: per-head scale of shape (num_heads,).
        bias: per-head shift of shape (num_heads,).
        eps: epsilon added to the variance.
        dropout_ratio: dropout probability (used only when ``training``).
        training: whether dropout is applied.
        concat_ux: when True the output is ``[u, x, y]``
            (N, 3 * num_heads * linear_dim).
        num_heads: number of groups per row.
        linear_dim: features per group.
        seed: Philox seed for dropout; drawn at random when None.

    Returns:
        ``(y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed)``. ``mean`` and
        ``rstd`` are float32 of shape (N * num_heads,).

    Raises:
        RuntimeError: if a row does not fit into a single fused block.
    """
    assert x.dim() == 2
    assert x.shape == u.shape
    assert x.shape[1] == num_heads * linear_dim
    x = switch_to_contiguous_if_needed(x)
    u = switch_to_contiguous_if_needed(u)
    N, _ = x.shape
    assert weight.dim() == 1
    assert bias.dim() == 1
    assert weight.numel() == num_heads
    assert bias.numel() == num_heads

    if concat_ux:
        y = torch.empty((N, 3 * num_heads * linear_dim), dtype=x.dtype, device=x.device)
    else:
        y = torch.empty((N, num_heads * linear_dim), dtype=x.dtype, device=x.device)
    mean = torch.empty((N * num_heads,), dtype=torch.float32, device=x.device)
    rstd = torch.empty((N * num_heads,), dtype=torch.float32, device=x.device)
    if N == 0:
        return y, mean, rstd, 0, 0, 0, 0

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_D: int = triton.next_power_of_2(linear_dim)
    BLOCK_H: int = triton.next_power_of_2(num_heads)
    if BLOCK_D * BLOCK_H > MAX_FUSED_SIZE:
        raise RuntimeError(
            "This group norm doesn't support num_heads * linear_dim >= 64KB."
        )

    if seed is None:
        seed = torch.randint(low=0, high=2**62, size=(1,), dtype=torch.int64).item()
    num_warps: int = min(max(BLOCK_D * BLOCK_H // 256, 1), 8)
    # pyre-ignore[28]
    _group_norm_mul_dropout_fwd[(N,)](
        x,
        u,
        y,
        weight,
        bias,
        mean,
        rstd,
        linear_dim,
        num_heads,
        eps,
        seed,
        dropout_ratio,
        x.stride(0),
        u.stride(0),
        y.stride(0),
        BLOCK_D=BLOCK_D,
        BLOCK_H=BLOCK_H,
        TRAINING=training,
        CONCAT_UX=concat_ux,
        num_warps=num_warps,
    )
    return y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed  # pyre-ignore [7]


def triton_group_norm_mul_dropout_bwd(
    dy: torch.Tensor,
    x: torch.Tensor,
    u: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    mean: torch.Tensor,
    rstd: torch.Tensor,
    BLOCK_D: int,
    BLOCK_H: int,
    num_warps: int,
    eps: float,
    training: bool,
    dropout_ratio: float,
    seed: Optional[int] = None,
    concat_ux: bool = False,
    num_heads: int = 1,
    linear_dim: int = -1,
    compute_y: bool = False,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]
]:
    """Backward of :func:`triton_group_norm_mul_dropout_fwd`.

    ``BLOCK_D``, ``BLOCK_H``, ``num_warps`` and ``seed`` must be the values
    returned by the forward call so that dropout masks are replayed exactly.

    Returns:
        ``(dx, du, dweight, dbias, y)``. ``y`` is the recomputed forward
        output when ``compute_y`` is True, otherwise None.
    """
    y = None
    N, dim = x.shape
    if compute_y:
        if concat_ux:
            y = torch.empty(
                (N, 3 * num_heads * linear_dim), dtype=x.dtype, device=x.device
            )
        else:
            y = torch.empty((N, num_heads * linear_dim), dtype=x.dtype, device=x.device)
    if N == 0:
        return (
            torch.zeros_like(x),
            torch.zeros_like(u),
            torch.zeros_like(weight),
            torch.zeros_like(bias),
            y,
        )

    # The kernel addresses rows via stride(0) but assumes unit stride along
    # the feature dim. Incoming gradients are not guaranteed to satisfy that,
    # so enforce it here.
    dy = switch_to_contiguous_if_needed(dy)
    x = switch_to_contiguous_if_needed(x)
    u = switch_to_contiguous_if_needed(u)
    dx = torch.empty_like(x)
    du = torch.empty_like(u)

    # Number of accumulation slots for the atomic dW/dB partial sums: more
    # slots reduce atomic contention, at the cost of a larger final reduction.
    if dim <= 1024:
        GROUP_N = 256 * 8
    elif dim <= 4096:
        GROUP_N = 128 * 8
    elif dim <= 8192:
        GROUP_N = 96 * 8
    else:
        GROUP_N = 64 * 8
    GROUP_N = min(GROUP_N, N)
    # Zero-initialised because the kernel only ever atomically adds into them.
    _dweight = torch.zeros((GROUP_N, num_heads), dtype=torch.float32, device=x.device)
    _dbias = torch.zeros((GROUP_N, num_heads), dtype=torch.float32, device=x.device)
    dweight = torch.empty((num_heads,), dtype=weight.dtype, device=x.device)
    dbias = torch.empty((num_heads,), dtype=weight.dtype, device=x.device)
    # pyre-ignore[28]
    _group_norm_mul_dropout_bwd_dx_du[(N,)](
        dx,
        du,
        dy,
        _dweight,
        _dbias,
        x,
        u,
        y,
        weight,
        bias,
        mean,
        rstd,
        dx.stride(0),
        du.stride(0),
        dy.stride(0),
        x.stride(0),
        u.stride(0),
        y.stride(0) if compute_y else 0,  # pyre-ignore [16]
        linear_dim,
        num_heads,
        eps,
        seed,
        dropout_ratio,
        GROUP_N=GROUP_N,
        BLOCK_D=BLOCK_D,
        BLOCK_H=BLOCK_H,
        TRAINING=training,
        CONCAT_UX=concat_ux,
        COMPUTE_Y=compute_y,
        num_warps=num_warps,
    )
    _group_norm_bwd_dwdb[(num_heads,)](
        _dweight,
        _dbias,
        dweight,
        dbias,
        GROUP_N,
    )
    return dx, du, dweight, dbias, y


# Reduce the (N, num_heads) partial dW/dB buffers into (num_heads,) gradients.
# One program per head (column); the number of heads is the grid size.
@triton_autotune(
    configs=_get_bwd_dwdb_configs(),
    key=[],
)
@triton.jit
def _group_norm_bwd_dwdb(
    DW,
    DB,
    FINAL_DW,
    FINAL_DB,
    N,
    BLOCK_N: tl.constexpr,
):
    col = tl.program_id(0)
    num_heads = tl.num_programs(0)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    db = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for i in range(0, N, BLOCK_N):
        rows = i + tl.arange(0, BLOCK_N)
        mask = rows < N
        offs = rows * num_heads + col
        dw += tl.load(DW + offs, mask=mask, other=0.0)
        db += tl.load(DB + offs, mask=mask, other=0.0)

    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)
    tl.store(FINAL_DW + col, sum_dw.to(FINAL_DW.dtype.element_ty))
    tl.store(FINAL_DB + col, sum_db.to(FINAL_DB.dtype.element_ty))


class GroupNormMulDropoutFunction(torch.autograd.Function):
    """Autograd wrapper for the fused ``dropout(group_norm(x) * u)`` kernels."""

    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        x: torch.Tensor,
        u: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        eps: float,
        dropout_ratio: float,
        training: bool,
        concat_ux: bool = False,
        num_heads: int = 1,
        linear_dim: int = -1,
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed = (
            triton_group_norm_mul_dropout_fwd(
                x=x,
                u=u,
                weight=weight,
                bias=bias,
                eps=eps,
                dropout_ratio=dropout_ratio,
                training=training,
                concat_ux=concat_ux,
                num_heads=num_heads,
                linear_dim=linear_dim,
                seed=seed,
            )
        )
        ctx.save_for_backward(x, u, weight, bias, mean, rstd)
        ctx.BLOCK_D = BLOCK_D
        ctx.BLOCK_H = BLOCK_H
        ctx.num_warps = num_warps
        ctx.eps = eps
        # The resolved seed is kept so backward replays the same masks.
        ctx.seed = seed
        ctx.training = training
        ctx.concat_ux = concat_ux
        ctx.dropout_ratio = dropout_ratio
        ctx.num_heads = num_heads
        ctx.linear_dim = linear_dim
        return y

    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx, dy: torch.Tensor
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    ]:
        x, u, weight, bias, mean, rstd = ctx.saved_tensors
        dx, du, dweight, dbias, _ = triton_group_norm_mul_dropout_bwd(
            dy=dy,
            x=x,
            u=u,
            weight=weight,
            bias=bias,
            mean=mean,
            rstd=rstd,
            BLOCK_D=ctx.BLOCK_D,
            BLOCK_H=ctx.BLOCK_H,
            num_warps=ctx.num_warps,
            eps=ctx.eps,
            training=ctx.training,
            dropout_ratio=ctx.dropout_ratio,
            seed=ctx.seed,
            concat_ux=ctx.concat_ux,
            num_heads=ctx.num_heads,
            linear_dim=ctx.linear_dim,
            compute_y=False,
        )
        return (
            dx,
            du,
            dweight,
            dbias,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class HSTUComputeOutputFunction(torch.autograd.Function):
    """Autograd Function for ``x + dropout(norm(attn) * u) @ output_weight``.

    ``norm`` is a group norm when ``group_norm`` is True, otherwise a layer
    norm. With ``recompute_y_in_backward`` the intermediate ``y`` is not
    saved; it is regenerated by the backward norm kernel to save memory.
    """

    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        attn: torch.Tensor,
        u: torch.Tensor,
        x: torch.Tensor,
        norm_weight: torch.Tensor,
        norm_bias: torch.Tensor,
        output_weight: torch.Tensor,
        eps: float,
        dropout_ratio: float,
        training: bool,
        concat_ux: bool = False,
        group_norm: bool = False,
        num_heads: int = 1,
        linear_dim: int = -1,
        seed: Optional[int] = None,
        recompute_y_in_backward: bool = False,
    ) -> torch.Tensor:
        if group_norm:
            y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed = (
                triton_group_norm_mul_dropout_fwd(
                    x=attn,
                    u=u,
                    weight=norm_weight,
                    bias=norm_bias,
                    eps=eps,
                    dropout_ratio=dropout_ratio,
                    training=training,
                    concat_ux=concat_ux,
                    num_heads=num_heads,
                    linear_dim=linear_dim,
                    seed=seed,
                )
            )
            ctx.BLOCK_H = BLOCK_H
        else:
            y, mean, rstd, BLOCK_D, num_warps, seed = triton_layer_norm_mul_dropout_fwd(
                x=attn,
                u=u,
                weight=norm_weight,
                bias=norm_bias,
                eps=eps,
                dropout_ratio=dropout_ratio,
                training=training,
                concat_ux=concat_ux,
                seed=seed,
            )

        # NOTE: for AMD training, we go with torch.addmm instead of the triton
        # version before Triton on AMD achieves on-par perf with NV GPU.
        if torch.version.hip:
            out = torch.addmm(x, y, output_weight)
        else:
            # Presumably the same x + y @ output_weight as torch.addmm above.
            out = triton_addmm_fwd(x=y, w=output_weight, y=x)

        saved_tensors = [attn, u, norm_weight, norm_bias, mean, rstd, output_weight]
        if not recompute_y_in_backward:
            saved_tensors.append(y)
        ctx.save_for_backward(*saved_tensors)
        ctx.BLOCK_D = BLOCK_D
        ctx.num_warps = num_warps
        ctx.eps = eps
        ctx.seed = seed
        ctx.training = training
        ctx.concat_ux = concat_ux
        ctx.dropout_ratio = dropout_ratio
        ctx.num_heads = num_heads
        ctx.linear_dim = linear_dim
        ctx.group_norm = group_norm
        ctx.recompute_y_in_backward = recompute_y_in_backward
        return out

    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx, dout: torch.Tensor
    ) -> Tuple[
        torch.Tensor,  # dattn
        torch.Tensor,  # du
        torch.Tensor,  # dx
        torch.Tensor,  # d_norm_weight
        torch.Tensor,  # d_norm_bias
        torch.Tensor,  # d_output_weight
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    ]:
        attn, u, norm_weight, norm_bias, mean, rstd, output_weight = ctx.saved_tensors[
            :7
        ]
        # out = x + y @ output_weight  =>  dy = dout @ output_weight^T
        dy = torch.mm(dout, output_weight.t())

        if ctx.group_norm:
            dattn, du, d_norm_weight, d_norm_bias, y = (
                triton_group_norm_mul_dropout_bwd(
                    dy=dy,
                    x=attn,
                    u=u,
                    weight=norm_weight,
                    bias=norm_bias,
                    mean=mean,
                    rstd=rstd,
                    BLOCK_D=ctx.BLOCK_D,
                    BLOCK_H=ctx.BLOCK_H,
                    num_warps=ctx.num_warps,
                    eps=ctx.eps,
                    training=ctx.training,
                    dropout_ratio=ctx.dropout_ratio,
                    seed=ctx.seed,
                    concat_ux=ctx.concat_ux,
                    num_heads=ctx.num_heads,
                    linear_dim=ctx.linear_dim,
                    compute_y=ctx.recompute_y_in_backward,
                )
            )
        else:
            dattn, du, d_norm_weight, d_norm_bias, y = (
                triton_layer_norm_mul_dropout_bwd(
                    dy=dy,
                    x=attn,
                    u=u,
                    weight=norm_weight,
                    bias=norm_bias,
                    mean=mean,
                    rstd=rstd,
                    BLOCK_D=ctx.BLOCK_D,
                    num_warps=ctx.num_warps,
                    eps=ctx.eps,
                    training=ctx.training,
                    dropout_ratio=ctx.dropout_ratio,
                    seed=ctx.seed,
                    concat_ux=ctx.concat_ux,
                    compute_y=ctx.recompute_y_in_backward,
                )
            )
        if not ctx.recompute_y_in_backward:
            y = ctx.saved_tensors[7]
        d_output_weight = torch.mm(y.t(), dout)
        # The residual term x passes through unchanged, so its gradient is dout.
        return (
            dattn,
            du,
            dout,
            d_norm_weight,
            d_norm_bias,
            d_output_weight,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


@torch.fx.wrap
def triton_norm_mul_dropout(
    x: torch.Tensor,
    u: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    dropout_ratio: float,
    training: bool,
    concat_ux: bool = False,
    group_norm: bool = False,
    num_heads: int = 1,
    linear_dim: int = -1,
    seed: Optional[int] = None,
) -> torch.Tensor:
    """Differentiable ``dropout(norm(x) * u)``, using a group or layer norm.

    ``num_heads`` and ``linear_dim`` are only used when ``group_norm`` is True.
    """
    if group_norm:
        return GroupNormMulDropoutFunction.apply(
            x,
            u,
            weight,
            bias,
            eps,
            dropout_ratio,
            training,
            concat_ux,
            num_heads,
            linear_dim,
            seed,
        )
    else:
        return LayerNormMulDropoutFunction.apply(
            x, u, weight, bias, eps, dropout_ratio, training, concat_ux, seed
        )


@torch.fx.wrap
def triton_hstu_compute_output(
    attn: torch.Tensor,
    u: torch.Tensor,
    x: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    output_weight: torch.Tensor,
    eps: float,
    dropout_ratio: float,
    training: bool,
    concat_ux: bool = False,
    group_norm: bool = False,
    num_heads: int = 1,
    linear_dim: int = -1,
    seed: Optional[int] = None,
    recompute_y_in_backward: bool = False,
) -> torch.Tensor:
    """Differentiable ``x + dropout(norm(attn) * u) @ output_weight``."""
    return HSTUComputeOutputFunction.apply(
        attn,
        u,
        x,
        norm_weight,
        norm_bias,
        output_weight,
        eps,
        dropout_ratio,
        training,
        concat_ux,
        group_norm,
        num_heads,
        linear_dim,
        seed,
        recompute_y_in_backward,
    )