modules/SwissArmyTransformer/sat/model/position_embedding/triton_rotary.py (181 lines of code) (raw):

from typing import Optional, Union

import torch
import triton
import triton.language as tl


# Rotary-embedding kernel. One program instance handles BLOCK_M consecutive
# sequence rows of a single (batch, head) pair: grid axis 0 tiles the sequence,
# axis 1 is the batch index and axis 2 is the head index (see `apply_rotary`).
#
# BLOCK_M is chosen by a heuristic in `apply_rotary` instead of triton.autotune.
# CACHE_KEY_SEQLEN is not read by the kernel body (it is seqlen // 128, kept as
# a compilation-cache key). CU_SEQLENS and IS_VARLEN are likewise not read here:
# variable-length (packed) batches are not implemented.
@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    POSITIONS,
    CU_SEQLENS,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    nheads,
    rotary_dim,
    seqlen_ro,
    CACHE_KEY_SEQLEN,
    # strides
    stride_out_batch,
    stride_out_nheads,
    stride_out_seqlen,
    stride_out_headdim,
    stride_x_batch,
    stride_x_nheads,
    stride_x_seqlen,
    stride_x_headdim,
    stride_p_batch,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
    OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
    # Position ids are read with unit stride along the sequence dimension;
    # apply_rotary makes sure position_ids.stride(-1) == 1.
    POS = POSITIONS + pid_batch * stride_p_batch

    if pid_m * BLOCK_M >= seqlen:
        return
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Row of the cos/sin tables for each sequence row: position id + offset.
    rm_cs = tl.load(POS + rm, mask=rm < seqlen, other=0)
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm_cs + SEQLEN_OFFSETS
    else:
        rm_cs = rm_cs + tl.load(SEQLEN_OFFSETS + pid_batch)
    rk = tl.arange(0, BLOCK_K)
    rk_half = tl.arange(0, BLOCK_K // 2)

    # Note: rows with rm_cs >= seqlen_ro load cos=1, sin=0 (masked defaults),
    # so they are written back unrotated.
    if not INTERLEAVED:
        # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
        X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        cos = tl.load(
            COS,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half),
            other=1.0,
        )
        sin = tl.load(
            SIN,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half),
            other=0.0,
        )
        x0 = tl.load(
            X,
            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
            other=0.0,
        )
        x1 = tl.load(
            X + rotary_dim_half * stride_x_headdim,
            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
            other=0.0,
        )
        if CONJUGATE:
            sin = -sin
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        # write back result
        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
        tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
        tl.store(
            OUT + rotary_dim_half * stride_out_headdim,
            o1,
            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
        )
    else:
        # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
        # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
        # Loading x0 will be fast but x1 will be slow.
        # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
        # Then we do the calculation and use tl.where to pick out the right outputs for the even
        # and for the odd indices.
        rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
        rk_repeat = tl.arange(0, BLOCK_K) // 2
        X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
        X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        cos = tl.load(
            COS,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
            other=1.0,
        )
        sin = tl.load(
            SIN,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
            other=0.0,
        )
        x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0)
        x1 = tl.load(X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0)
        if CONJUGATE:
            sin = -sin
        x0_cos = x0 * cos
        x1_sin = x1 * sin
        out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
        tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved: bool = False,
    inplace: bool = False,
    conjugate: bool = False,
) -> torch.Tensor:
    """Apply rotary position embeddings to ``x`` with a Triton kernel.

    Only the first ``rotary_dim = 2 * cos.shape[1]`` channels of the head
    dimension are rotated; the remaining channels are copied through unchanged.
    For sequence row ``s`` of batch ``b`` the cos/sin row used is
    ``position_ids[b, s] + seqlen_offsets`` (``+ seqlen_offsets[b]`` when it is
    a tensor). Rows whose index is ``>= seqlen_ro`` are left unrotated (the
    kernel substitutes cos=1, sin=0 for them).

    Arguments:
        x: (batch, nheads, seqlen, headdim). Packed variable-length input is
            not supported.
        cos: (seqlen_ro, rotary_dim / 2), same dtype as ``x``.
        sin: (seqlen_ro, rotary_dim / 2), same dtype as ``x``.
        position_ids: integer tensor of shape (batch, seqlen), or (1, seqlen)
            to share the same positions across the whole batch.
        seqlen_offsets: integer or integer tensor of size (batch,), added to
            every position id.
        cu_seqlens: accepted for API compatibility but ignored; the kernel is
            launched with IS_VARLEN=False and never reads it.
        max_seqlen: optional; when given, only used for sanity checks against
            ``seqlen_ro``.
        interleaved: if True, rotate adjacent channel pairs (0, 1), (2, 3), ...;
            otherwise pair channel ``i`` with channel ``i + rotary_dim / 2``.
        inplace: if True, write the result into ``x`` and return ``x``.
        conjugate: if True, negate ``sin`` (rotate by the opposite angle).

    Returns:
        y: (batch, nheads, seqlen, headdim); ``x`` itself when ``inplace``.
    """
    batch, nheads, seqlen, headdim = x.shape
    seqlen_ro, rotary_dim = cos.shape
    batch_p, seqlen_p = position_ids.shape
    if batch_p == 1 and batch != 1:
        # Share one row of positions across the batch without copying:
        # expand() yields a batch stride of 0, which the kernel honours.
        position_ids = position_ids.expand(batch, seqlen_p)
        batch_p = batch
    assert batch_p == batch and seqlen_p == seqlen
    if position_ids.stride(-1) != 1:
        # The kernel reads position ids with unit stride along the sequence.
        position_ids = position_ids.contiguous()
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    # max_seqlen is optional; comparing against None would raise TypeError.
    if max_seqlen is not None:
        assert seqlen_ro >= max_seqlen, "seqlen_ro must be >= max_seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    # The kernel indexes cos/sin as dense row-major (row * rotary_dim_half + col).
    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    elif max_seqlen is not None:
        assert seqlen_offsets + max_seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        # The kernel only writes the rotary channels; copy the rest through.
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)

    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(x.device.index):
        rotary_kernel[grid](
            output,  # data ptrs
            x,
            cos,
            sin,
            position_ids,
            cu_seqlens,
            seqlen_offsets,
            seqlen,  # shapes
            nheads,
            rotary_dim,
            seqlen_ro,
            seqlen // 128,  # key for triton cache (limit number of compilations)
            output.stride(0),  # batch_strides
            output.stride(-3),  # nheads_stride
            output.stride(-2),  # seqlen_stride
            output.stride(-1),  # headdim_stride
            x.stride(0),  # batch_strides
            x.stride(-3),  # nheads stride
            x.stride(-2),  # seqlen stride
            x.stride(-1),  # headdim stride
            position_ids.stride(0),  # batch_strides
            BLOCK_K,
            isinstance(seqlen_offsets, torch.Tensor),
            False,
            interleaved,
            conjugate,
            BLOCK_M,
        )
    return output