aiops/ContraAD/model/attend.py

from functools import partial, wraps
from collections import namedtuple

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from packaging import version
from einops import rearrange, repeat

# constants

EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def once(fn):
    # wrap fn so it only ever executes on the first call
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner

print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        *,
        dropout = 0.,
        heads = None,
        scale = None,
        flash = False,
        causal = False
    ):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.causal = causal

        # flash attention

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
        major, minor = device_properties.major, device_properties.minor

        if (major, minor) == (8, 0):
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        elif (major, minor) == (9, 0):
            print_once('H100 GPU detected, using flash attention')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        elif (major, minor) == (8, 6):
            # print_once('3090 detected, using flash attention')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 scaled_dot_product_attention dispatches to the flash /
        # math / mem-efficient kernels according to which backends are enabled

        with torch.backends.cuda.sdp_kernel(
            enable_flash = config.enable_flash,
            enable_math = config.enable_math,
            enable_mem_efficient = config.enable_mem_efficient
        ):
            out = F.scaled_dot_product_attention(
                q, k, v,
                is_causal = self.causal,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash:
            return self.flash_attn(q, k, v)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        # capture dtype before masking so the post-softmax cast works on
        # both the causal and non-causal paths

        dtype = sim.dtype

        if self.causal:
            i, j = sim.shape[-2:]
            mask_value = -torch.finfo(dtype).max
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = sim.softmax(dim = -1)
        attn = attn.type(dtype)
        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        return out
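
# usage sketch

# A minimal, hypothetical demo (an assumption for illustration, not part of
# the ContraAD repo) showing how Attend is typically driven. Tensors are
# shaped (batch, heads, seq_len, dim_head), matching the einstein notation
# in forward(); flash = False keeps it on the pure-einsum path so it runs
# on CPU.

if __name__ == '__main__':
    attend = Attend(dropout = 0.1, causal = True, flash = False)
    attend.eval()  # disable attention dropout for a deterministic pass

    q = torch.randn(2, 8, 128, 64)  # (b, h, n, d)
    k = torch.randn(2, 8, 128, 64)
    v = torch.randn(2, 8, 128, 64)

    out = attend(q, k, v)
    print(out.shape)  # torch.Size([2, 8, 128, 64])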