tzrec/ops/benchmarks/hstu_attention_bench.py (277 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List, Optional, Tuple

import click
import pandas as pd
import torch

# @manual=//triton:triton
import triton
from tzrec.ops import Kernel
from tzrec.ops.hstu_attention import delta_hstu_mha, hstu_mha
from tzrec.utils.test_util import generate_sparse_seq_len


def _apply_sampling(
    lengths: torch.Tensor,
    alpha: float,
    max_seq_len: int,
) -> torch.Tensor:
    """Randomly truncate long sequences to a length threshold.

    The threshold is ``int(max_seq_len ** (alpha / 2))``. A sequence longer
    than the threshold is replaced by the threshold with probability
    ``1 - max_seq_len**alpha / length**2``; all other lengths are returned
    unchanged.

    Args:
        lengths (torch.Tensor): per-sequence lengths.
        alpha (float): exponent controlling the threshold and the keep
            probability.
        max_seq_len (int): maximum sequence length used to derive the
            threshold.

    Returns:
        torch.Tensor: lengths after sampling, same shape as ``lengths``.
    """
    threshold = int(max_seq_len ** (alpha / 2))
    # Probability of keeping a sequence at its full length.
    no_sample_prob = (max_seq_len**alpha) / torch.pow(lengths, 2)
    users_to_sample = torch.logical_and(
        lengths > threshold,
        torch.rand_like(no_sample_prob) < 1 - no_sample_prob,
    )
    lengths = torch.where(users_to_sample, threshold, lengths)
    return lengths


def _get_kernel(provider: str) -> Kernel:
    """Map a benchmark provider name to the corresponding ``Kernel`` value.

    Args:
        provider (str): either ``"triton"`` or ``"pytorch"``.

    Returns:
        Kernel: ``Kernel.TRITON`` or ``Kernel.PYTORCH``.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    if provider == "triton":
        return Kernel.TRITON
    elif provider == "pytorch":
        return Kernel.PYTORCH
    else:
        raise ValueError(f"Unknown provider {provider}")


def _flops(
    batch_size: int,
    max_seqlen: int,
    attn_dim: int,
    hidden_dim: int,
    nheads: int,
    seq_offsets: torch.Tensor,
    mode: str = "fwd",
) -> float:
    """Estimate the FLOPs of jagged attention for the given batch.

    Args:
        batch_size (int): number of sequences described by ``seq_offsets``.
        max_seqlen (int): unused; kept for interface compatibility.
        attn_dim (int): per-head dimension of q/k.
        hidden_dim (int): per-head dimension of v.
        nheads (int): number of attention heads.
        seq_offsets (torch.Tensor): cumulative sequence offsets; entry
            ``i + 1`` minus entry ``i`` is the length of sequence ``i``.
        mode (str): one of ``"fwd"``, ``"bwd"`` or ``"fwd_bwd"``.

    Returns:
        float: estimated FLOP count for the selected mode.
    """
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    ratio = 2.0  # triangular masking
    # Fetch all lengths with a single device-to-host transfer instead of one
    # `.item()` synchronization per sequence.
    raw_lens = (seq_offsets[1 : batch_size + 1] - seq_offsets[:batch_size]).tolist()
    f1 = 0.0
    f2 = 0.0
    for raw_len in raw_lens:
        seq_len = int(raw_len)
        # (QK^T), dQ = d(QK^T)K, dK^T = Q^Td(QK^T)
        f1 += 2 * nheads * attn_dim * seq_len**2 // ratio
        # (QK^T)V, d(QK^T) = dOV^T, dV = (QK^T)^TdO,
        f2 += 2 * nheads * hidden_dim * seq_len**2 // ratio
    if mode == "fwd":
        return f1 + f2  # computes (QK^T) and (QK^T)V
    elif mode == "bwd":
        return 3 * f1 + 2 * f2  # computes (QK^T), dQ, dK, dV, d(QK^T)
    else:
        return 4 * f1 + 3 * f2


@click.command()
@click.option(
    "--batch-size",
    type=int,
    default=512,
)
@click.option("--heads", type=int, default=4)
@click.option("--attn-dim", type=int, default=128)
@click.option("--hidden-dim", type=int, default=128)
@click.option("--max-seq-len-log2", type=int, default=13)
@click.option("--data-type", type=str, default="bf16")
@click.option("--seq-sparsity", type=float, default=0.95)
@click.option("--has-delta-q", type=bool, default=False)
@click.option("--delta-size", type=int, default=256)
@click.option("--target-size", type=int, default=20)
@click.option("--bench-backward", type=bool, default=True)
@click.option("--bench-forward", type=bool, default=True)
@click.option("--bench-pytorch", type=bool, default=False)
@click.option("--report-flops", type=bool, default=False)
@click.option("--return-result", type=bool, default=False)
@click.option("--max-attn-len", type=int, default=0)
@click.option("--contextual-seq-len", type=int, default=0)
@click.option("--sampling-alpha", type=float, default=2.0)
def main(  # noqa: C901
    batch_size: int,
    heads: int,
    attn_dim: int,
    hidden_dim: int,
    max_seq_len_log2: int,
    data_type: str,
    seq_sparsity: float,
    has_delta_q: bool,
    delta_size: int,
    target_size: int,
    bench_backward: bool,
    bench_forward: bool,
    bench_pytorch: bool,
    report_flops: bool,
    return_result: bool,
    max_attn_len: int,
    contextual_seq_len: int,
    sampling_alpha: float,
) -> Optional[Tuple[List[triton.testing.Benchmark], List[pd.DataFrame]]]:
    """Benchmark HSTU attention kernels over a range of sequence lengths.

    Sequence lengths 2**8 .. 2**(max-seq-len-log2 - 1) are swept for the
    Triton kernel (and the PyTorch one with --bench-pytorch). Latency in ms
    is reported by default, or TFLOPS with --report-flops. Backward is
    skipped when --has-delta-q is set.
    """
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    if data_type == "fp32":
        dtype = torch.float32
    elif data_type == "fp16":
        dtype = torch.float16
    elif data_type == "bf16":
        dtype = torch.bfloat16
    else:
        raise ValueError(f"Unsupported data type: {data_type}.")

    line_vals = ["triton"]
    line_names = ["Triton"]
    styles = [("red", "-")]
    if bench_pytorch:
        line_vals.append("pytorch")
        line_names.append("PyTorch")
        styles.append(("green", "-"))

    # The delta-q path is only benchmarked in the forward direction.
    bench_backward = False if has_delta_q else bench_backward
    modes = []
    if bench_forward:
        modes.append("fwd")
    if bench_backward:
        modes.append("bwd")
    assert len(modes) > 0

    # The benchmark returns flops / ms / 1e9 (i.e. TFLOPS) when report_flops
    # is set, so label the y axis with the metric that is actually plotted.
    ylabel = "TFLOPS" if report_flops else "ms"

    configs: List[triton.testing.Benchmark] = [
        triton.testing.Benchmark(
            x_names=["seq_len"],
            x_vals=[2**i for i in range(8, max_seq_len_log2)],
            line_arg="provider",
            line_vals=line_vals,
            line_names=line_names,
            styles=styles,
            ylabel=ylabel,
            plot_name=(
                f"hstu-attn-b{batch_size}-h{heads}-d{attn_dim}-v{hidden_dim}"
                f"--sparsity{seq_sparsity}-{mode}-{dtype}-target{target_size}"
                f"-mattn{max_attn_len}-c{contextual_seq_len}"
                f"-sl_alpha{sampling_alpha}"
            ),
            args={
                "batch_size": batch_size,
                "heads": heads,
                "attn_dim": attn_dim,
                "hidden_dim": hidden_dim,
                "dtype": dtype,
                "mode": mode,
                "seq_sparsity": seq_sparsity,
                "has_delta_q": has_delta_q,
                "delta_size": delta_size,
                "target_size": target_size,
                "bench_backward": bench_backward,
                "report_flops": report_flops,
                "max_attn_len": max_attn_len,
                "contextual_seq_len": contextual_seq_len,
                "sampling_alpha": sampling_alpha,
            },
        )
        for mode in modes
    ]

    @triton.testing.perf_report(configs)
    def _bench_hstu_attention(
        batch_size: int,
        heads: int,
        seq_len: int,
        attn_dim: int,
        hidden_dim: int,
        mode: str,
        provider: str,
        dtype: torch.dtype,
        seq_sparsity: float,
        has_delta_q: bool,
        delta_size: int,
        target_size: int,
        bench_backward: bool,
        report_flops: bool,
        max_attn_len: int,
        contextual_seq_len: int,
        sampling_alpha: float,
    ) -> float:
        """Time one (seq_len, provider, mode) benchmark point.

        Returns:
            float: latency in ms as measured by ``triton.testing.do_bench``,
                or ``flops / ms / 1e9`` when ``report_flops`` is set.
        """
        assert mode in ["fwd", "bwd"]
        warmup = 25
        rep = 1000
        torch.manual_seed(1001)  # for reproducibility
        alpha = 1.0 / attn_dim
        causal = True
        cuda = torch.device("cuda")

        lengths = generate_sparse_seq_len(
            size=batch_size,
            max_seq_len=seq_len,
            sparsity=seq_sparsity,
            device=cuda,
        )
        lengths = _apply_sampling(lengths, sampling_alpha, max_seq_len=seq_len)
        if has_delta_q:
            # Every sequence is extended by delta_size positions, which are
            # also passed as the targets of that sequence.
            lengths = lengths + delta_size
            num_targets = torch.ones_like(lengths) * delta_size
            seq_len = seq_len + delta_size
        else:
            delta_size = 0
            num_targets = None
            if target_size != 0:
                num_targets = torch.randint(
                    1,
                    target_size + 1,
                    (batch_size,),
                    device=lengths.device,
                    dtype=lengths.dtype,
                )
                # A sequence cannot have more targets than elements.
                num_targets = torch.where(num_targets > lengths, lengths, num_targets)
        max_attn_len = max_attn_len if max_attn_len < seq_len else seq_len

        seq_offsets = torch.zeros((batch_size + 1,), dtype=torch.int64, device=cuda)
        seq_offsets[1:] = torch.cumsum(lengths, dim=0)
        L = int(seq_offsets[-1].item())

        # q, k and v are views into one jagged buffer of L total positions.
        x = torch.empty(
            (L, heads, attn_dim * 2 + hidden_dim),
            dtype=dtype,
            device=cuda,
        ).uniform_(-0.01, 0.01)
        q, k, v = torch.split(x, [attn_dim, attn_dim, hidden_dim], dim=-1)
        # Empty (zero rows) when has_delta_q is False because delta_size == 0.
        delta_q = torch.empty(
            (batch_size * delta_size, heads, attn_dim),
            dtype=dtype,
            device=cuda,
        ).uniform_(-0.1, 0.1)

        if bench_backward:
            q = q.requires_grad_(True)
            k = k.requires_grad_(True)
            v = v.requires_grad_(True)

        assert provider in ["triton", "pytorch"]
        if has_delta_q:
            fn = lambda: delta_hstu_mha(  # noqa: E731
                max_seq_len=seq_len,
                alpha=alpha,
                delta_q=delta_q,
                k=k,
                v=v,
                seq_offsets=seq_offsets,
                num_targets=num_targets,
                kernel=_get_kernel(provider),
            )
        else:
            fn = lambda: hstu_mha(  # noqa: E731
                max_seq_len=seq_len,
                alpha=alpha,
                q=q,
                k=k,
                v=v,
                seq_offsets=seq_offsets,
                causal=causal,
                dropout_pr=0.0,
                training=True,
                num_targets=num_targets,
                max_attn_len=max_attn_len,
                contextual_seq_len=contextual_seq_len,
                sort_by_length=True,
                kernel=_get_kernel(provider),
            )
        if mode == "bwd":
            # Run forward once, then time only the backward pass; the graph
            # is retained so backward can be replayed on every repetition.
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)  # noqa: E731

        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        if not report_flops:
            return ms

        all_flops = _flops(
            batch_size, seq_len, attn_dim, hidden_dim, heads, seq_offsets, mode
        )
        if has_delta_q:
            all_flops = all_flops / seq_len * delta_size
        return all_flops / ms / 1e9

    # USER is frequently unset inside containers / CI; fall back to a fixed
    # directory name instead of failing with KeyError after the benchmark ran.
    save_path = "/tmp/" + os.environ.get("USER", "hstu_attention_bench")
    df = _bench_hstu_attention.run(
        print_data=True,
        show_plots=False,
        save_path=save_path,
        return_df=return_result,
    )

    if return_result:
        return configs, df
    return None


if __name__ == "__main__":
    main()