tzrec/ops/triton/triton_addmm.py (281 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We use the addmm ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# thanks to their public work.

import re
from typing import List, Tuple

import torch
import triton
import triton.language as tl
from triton.runtime.autotuner import autotune as triton_autotune

# When True, the HIP (ROCm) branch of get_mm_configs() enumerates a larger
# autotuning search space instead of a single fixed configuration.
ENABLE_FULL_TURNING_SPACE = False


def _triton_version_at_least(major: int, minor: int) -> bool:
    """Return True if the installed Triton is at least ``major.minor``.

    The components are compared numerically. A plain string comparison such
    as ``triton.__version__ >= "3.2.0"`` is lexicographic and therefore ranks
    "3.10.0" (or "10.0.0") *below* "3.2.0".

    Args:
        major: required major version.
        minor: required minor version.

    Returns:
        Whether ``triton.__version__`` is >= ``major.minor``. If the version
        string does not start with ``<int>.<int>``, falls back to the
        lexicographic comparison against ``"{major}.{minor}.0"``.
    """
    version = triton.__version__
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        # Unparsable version string: keep the legacy string comparison.
        return version >= f"{major}.{minor}.0"
    return (int(match.group(1)), int(match.group(2))) >= (major, minor)


def get_mm_configs() -> List[triton.Config]:
    """Build the list of autotuning configurations for the addmm kernel.

    On HIP builds of torch (``torch.version.hip`` is truthy) the configs are
    the cartesian product of per-parameter ranges; the ranges hold several
    candidates when ``ENABLE_FULL_TURNING_SPACE`` is set and a single value
    each otherwise. On other builds a fixed list of eight configs is returned.

    Returns:
        A list of ``triton.Config`` objects.
    """
    if torch.version.hip:
        # num_stages depends on the Triton version: 2 from 3.2 on, else 0.
        num_stage_range = [2] if _triton_version_at_least(3, 2) else [0]
        if ENABLE_FULL_TURNING_SPACE:
            block_m_range = [32, 64, 128, 256]
            block_n_range = [32, 64, 128, 256]
            block_k_range = [32, 64]
            group_m_range = [4, 8]
            matrix_instr_nonkdim_range = [16]
            waves_per_eu_range = [0]
            kpack_range = [1, 2]
            num_warps_range = [4, 8]
        else:
            block_m_range = [256]
            block_n_range = [256]
            block_k_range = [32]
            group_m_range = [8]
            matrix_instr_nonkdim_range = [16]
            waves_per_eu_range = [0]
            kpack_range = [2]
            num_warps_range = [8]

        return [
            triton.Config(
                {
                    "BLOCK_M": block_m,
                    "BLOCK_N": block_n,
                    "BLOCK_K": block_k,
                    "GROUP_M": group_m,
                    "matrix_instr_nonkdim": matrix_instr_nonkdim,
                    "waves_per_eu": waves_per_eu,
                    "kpack": kpack,
                },
                num_stages=num_stages,
                num_warps=num_warps,
            )
            for block_m in block_m_range
            for block_n in block_n_range
            for block_k in block_k_range
            for group_m in group_m_range
            for matrix_instr_nonkdim in matrix_instr_nonkdim_range
            for waves_per_eu in waves_per_eu_range
            for kpack in kpack_range
            for num_stages in num_stage_range
            for num_warps in num_warps_range
        ]
    else:
        return [
            triton.Config(
                {
                    "BLOCK_M": 32,
                    "BLOCK_N": 64,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {
                    "BLOCK_M": 128,
                    "BLOCK_N": 256,
                    "BLOCK_K": 64,
                    "GROUP_M": 8,
                },
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {
                    "BLOCK_M": 64,
                    "BLOCK_N": 256,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 128,
                    "BLOCK_N": 128,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 128,
                    "BLOCK_N": 64,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 64,
                    "BLOCK_N": 128,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 128,
                    "BLOCK_N": 32,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 64,
                    "BLOCK_N": 32,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=5,
                num_warps=2,
            ),
        ]


# Tiled kernel computing z = x @ w + y.
#   x is addressed as (M, K) via stride_xm / stride_xk,
#   w as (K, N) via stride_wk / stride_wn,
#   z as (M, N) via stride_zm / stride_zn.
#   y is either a length-N vector (BROADCAST_Y=True, only stride_yn is used)
#   or an (M, N) matrix (BROADCAST_Y=False).
# Each program produces one BLOCK_M x BLOCK_N tile of z, accumulating in fp32.
# Autotuning is keyed on N and K only, so a change in M alone does not
# trigger a new tuning run.
@triton_autotune(
    configs=get_mm_configs(),
    key=["N", "K"],
)
@triton.jit
def _addmm_fwd(
    x_ptr,
    w_ptr,
    y_ptr,
    z_ptr,
    M,
    N,
    K,
    stride_xm,
    stride_xk,
    stride_wk,
    stride_wn,
    stride_ym,
    stride_yn,
    stride_zm,
    stride_zn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BROADCAST_Y: tl.constexpr,
):
    # Linearize the 2D launch grid into a single program index, then remap it
    # so that GROUP_M consecutive row-tiles are visited together before moving
    # on to the next group (grouped tile ordering).
    pid_0, pid_1 = tl.program_id(axis=0), tl.program_id(axis=1)
    pid = pid_0 * tl.num_programs(axis=1) + pid_1
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group may contain fewer than GROUP_M row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    offs_n = tl.arange(0, BLOCK_N)
    # Masks guarding the ragged tiles at the M / N edges.
    mask_m = (pid_m * BLOCK_M + offs_m)[:, None] < M
    mask_n = (pid_n * BLOCK_N + offs_n)[None, :] < N

    # Tile base offsets are computed in int64 before being multiplied by
    # the strides.
    x_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_xm
    x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    w_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_wn
    w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Out-of-range K elements are loaded as 0 so they do not contribute.
        mask_k = offs_k[None, :] < K - k * BLOCK_K
        x = tl.load(x_ptrs, mask=mask_k & mask_m, other=0.0)
        mask_k = offs_k[:, None] < K - k * BLOCK_K
        w = tl.load(w_ptrs, mask=mask_k & mask_n, other=0.0)
        accumulator += tl.dot(x, w, allow_tf32=ALLOW_TF32)
        x_ptrs += BLOCK_K * stride_xk
        w_ptrs += BLOCK_K * stride_wk

    z_mask = mask_m & mask_n
    if BROADCAST_Y:
        # y is a vector, broadcast to add to z
        y_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_yn
        y_ptrs = y_ptr + stride_yn * offs_n[None, :]
        y = tl.load(y_ptrs, mask=mask_n)
    else:
        y_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_ym
        y_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_yn
        y_ptrs = y_ptr + stride_ym * offs_m[:, None] + stride_yn * offs_n[None, :]
        y = tl.load(y_ptrs, mask=z_mask)
    # Add the bias in fp32, then cast to the output element type.
    z = (accumulator + y.to(tl.float32)).to(z_ptr.dtype.element_ty)
    z_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_zm
    z_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_zn
    z_ptrs = z_ptr + stride_zm * offs_m[:, None] + stride_zn * offs_n[None, :]
    tl.store(z_ptrs, z, mask=z_mask)


def triton_addmm_fwd(
    x: torch.Tensor,
    w: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Compute ``x @ w + y`` with the Triton kernel.

    Args:
        x: 2D tensor of shape (M, K).
        w: 2D tensor of shape (K, N).
        y: either a 1D tensor of shape (N,), broadcast over rows, or a 2D
            tensor of shape (M, N).

    Returns:
        A new tensor of shape (M, N) on ``x.device`` with dtype ``x.dtype``.
        If M or N is 0 the (empty) output is returned without launching the
        kernel.

    Raises:
        AssertionError: if the shapes of ``x``, ``w`` and ``y`` are
            incompatible.
    """
    M, K = x.shape
    KB, N = w.shape
    assert K == KB, f"incompatible dimensions {K}, {KB}"

    assert y.dim() in (1, 2), f"y must be 1D or 2D, got {y.dim()}D"
    is_y_1d = y.dim() == 1
    NY = y.shape[0] if is_y_1d else y.shape[1]
    assert N == NY, f"incompatible dimensions {N}, {NY}"
    if not is_y_1d:
        # The kernel indexes y with row indices up to M - 1; a y with fewer
        # rows would be read out of bounds.
        assert y.shape[0] == M, f"incompatible dimensions {M}, {y.shape[0]}"

    # Allocate output
    z = torch.empty((M, N), device=x.device, dtype=x.dtype)
    if M == 0 or N == 0:
        return z

    grid = lambda meta: (  # noqa E731
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    _addmm_fwd[grid](
        x,
        w,
        y,
        z,
        M,
        N,
        K,
        x.stride(0),
        x.stride(1),
        w.stride(0),
        w.stride(1),
        # For a 1D y the row stride is unused by the kernel; pass 0 and use
        # the only stride as the column stride.
        y.stride(0) if not is_y_1d else 0,
        y.stride(1) if not is_y_1d else y.stride(0),
        z.stride(0),
        z.stride(1),
        ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
        BROADCAST_Y=is_y_1d,
    )
    return z


def triton_addmm_bwd(
    x: torch.Tensor,
    w: torch.Tensor,
    dz: torch.Tensor,
    is_y_1d: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the gradients of ``z = x @ w + y`` given ``dz``.

    Uses ``torch.mm`` / ``torch.sum`` rather than a Triton kernel.

    Args:
        x: the forward ``x`` input.
        w: the forward ``w`` input.
        dz: gradient with respect to the output ``z``.
        is_y_1d: whether ``y`` was a 1D vector in the forward pass.

    Returns:
        A tuple ``(dx, dw, dy)``. ``dy`` is ``dz`` summed over dim 0 when
        ``is_y_1d`` is True, otherwise ``dz`` itself.
    """
    if is_y_1d:
        # y was broadcast over rows, so its gradient is the column-wise sum.
        dy = torch.sum(dz, dim=0)
    else:
        dy = dz
    dw = torch.mm(x.t(), dz)
    dx = torch.mm(dz, w.t())

    return dx, dw, dy


class _AddMmFunction(torch.autograd.Function):
    """Autograd function pairing the Triton forward with the torch backward."""

    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        x: torch.Tensor,
        w: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """Save ``x`` / ``w`` and the rank of ``y``, then run the forward op."""
        ctx.save_for_backward(x, w)
        ctx.is_y_1d = y.dim() == 1
        return triton_addmm_fwd(x, w, y)

    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx, dz: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(dx, dw, dy)`` in the order of the forward inputs."""
        (x, w) = ctx.saved_tensors
        return triton_addmm_bwd(x, w, dz, ctx.is_y_1d)


@torch.fx.wrap
def triton_addmm(
    input: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
) -> torch.Tensor:
    """Compute ``mat1 @ mat2 + input`` through ``_AddMmFunction``.

    The argument order follows ``(input, mat1, mat2)``; internally the
    arguments are re-ordered to ``(mat1, mat2, input)``.

    Args:
        input: the additive term, 1D of shape (N,) or 2D of shape (M, N).
        mat1: left matrix of shape (M, K).
        mat2: right matrix of shape (K, N).

    Returns:
        The (M, N) result tensor.
    """
    return _AddMmFunction.apply(mat1, mat2, input)