# bitsandbytes/triton/quantize_columnwise_and_transpose.py (63 lines of code, raw source)

import math

import torch

from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():

    def quantize_columnwise_and_transpose(x: torch.Tensor):
        """Fallback used when Triton is not available; always returns ``None``."""
        return None

else:
    import triton
    import triton.language as tl

    # This kernel does fused columnwise quantization and transpose.
    #
    # Each program instance handles one column ``pid`` of a dense row-major
    # (M, N) input:
    #   * gathers x[:, pid] (flat offsets ``row * N + pid``),
    #   * computes the column's absolute maximum,
    #   * scales the column by 127 / absmax and rounds with libdevice llrint,
    #   * writes the result as row ``pid`` of the (N, M) output
    #     (flat offsets ``pid * M + row``), i.e. transposed,
    #   * stores the absmax in ``output_maxs[pid]``.
    # ``P2`` is M rounded up to a power of two because ``tl.arange`` needs a
    # power-of-two extent; lanes with index >= M are masked out.
    # ``n_elements`` is not read by the kernel body. It only serves as the
    # autotune key. ``BLOCK_SIZE`` is likewise only consumed by the launch grid.
    # M, N and P2 are constexpr, so a new kernel is compiled per distinct shape.

    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_stages=16),
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=16, num_warps=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=["n_elements"],
    )
    @triton.jit
    def _quantize_columnwise_and_transpose(
        x_ptr,
        output_ptr,
        output_maxs,
        n_elements,
        M: tl.constexpr,
        N: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
        P2: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid
        p2_arange = tl.arange(0, P2)
        # Only the first M lanes correspond to real rows.
        p2_arange_mask = p2_arange < M
        # Step down column `pid` with a row stride of N (row-major layout).
        arange = p2_arange * N
        offsets = block_start + arange
        x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
        abs_x = tl.abs(x)
        # Masked-out lanes contribute 0 so they cannot affect the maximum.
        max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
        # NOTE(review): an all-zero column gives max_val == 0 and therefore 0/0
        # here. The resulting int8 value is not pinned down by this code.
        # NOTE(review): `tl.libdevice` is not present in newer Triton releases
        # (libdevice functions were relocated). Confirm against the pinned version.
        output = tl.libdevice.llrint(127.0 * (x / max_val))

        # Write the quantized column contiguously as row `pid` of the output.
        new_start = pid * M
        new_offsets = new_start + p2_arange
        tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
        tl.store(output_maxs + pid, max_val)

    def quantize_columnwise_and_transpose(x: torch.Tensor):
        """Quantize each column of a 2-D CUDA tensor to int8 and transpose it.

        Args:
            x: 2-D tensor of shape (M, N) on a CUDA device. Non-contiguous
                inputs are copied to a contiguous buffer first.

        Returns:
            A tuple ``(output, output_maxs)``:

            * ``output``: int8 tensor of shape (N, M). Row ``j`` holds column
              ``j`` of ``x`` scaled by ``127 / absmax_j`` and rounded to the
              nearest integer.
            * ``output_maxs``: float16 tensor of shape (N,) holding each
              column's absolute maximum. It is 0 for every column when M == 0.

        Raises:
            ValueError: if ``x`` is not 2-D (from unpacking ``x.shape``).
            AssertionError: if ``x`` is not on a CUDA device.
        """
        M, N = x.shape
        assert x.is_cuda, "quantize_columnwise_and_transpose expects a CUDA tensor"
        # The kernel addresses x as a dense row-major (M, N) buffer with a row
        # stride of N. Strided views (e.g. a transpose) must be materialized,
        # otherwise the kernel would read the wrong elements. This is a no-op
        # for tensors that are already contiguous.
        x = x.contiguous()

        output = torch.empty(N, M, device=x.device, dtype=torch.int8)
        output_maxs = torch.empty(N, device=x.device, dtype=torch.float16)

        # Nothing to launch for empty input. math.log2(0) below would also raise.
        if M == 0 or N == 0:
            return output, output_maxs.zero_()

        # tl.arange requires a power-of-two extent, so round M up.
        P2 = int(2 ** (math.ceil(math.log2(M))))

        n_elements = output.numel()

        def grid(meta):
            # BLOCK_SIZE is M, so this evaluates to N programs: one per column.
            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        _quantize_columnwise_and_transpose[grid](
            x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2
        )
        return output, output_maxs