aiops/ContraAD/model/PointAttention.py (292 lines of code) (raw):

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import pack, rearrange, reduce, repeat, unpack
from einops.layers.torch import Rearrange, Reduce
from rotary_embedding_torch import RotaryEmbedding
from torch import Tensor, einsum, nn
from torch.nn import Module, ModuleList

from .attend import Attend
from .RevIN import RevIN


def normalize(x, method='z-score'):
    """Normalize ``x`` along dimension 1 independently for every sample and channel.

    Args:
        x: tensor of shape (batch, size, channel). A tensor with fewer than three
            dims gets a trailing channel axis added, and that axis is removed again
            on return.
        method: ``'min-max'`` (scale to [0, 1]), ``'z-score'`` (ABSOLUTE z-score,
            i.e. ``|x - mean| / std``) or ``'softmax'`` (softmax over dim 1).

    Returns:
        Tensor with the same shape as the input. Note that any result whose
        channel axis has size 1 is squeezed to drop that axis.

    Raises:
        ValueError: if ``method`` is not one of the three names above.
    """
    if len(x.shape) < 3:
        x = x.unsqueeze(dim=-1)
    _, w, c = x.shape

    if method == 'min-max':
        min_vals, _ = torch.min(x, dim=1)  # batch, channel
        max_vals, _ = torch.max(x, dim=1)  # batch, channel
        min_vals = repeat(min_vals, 'b c -> b w c', w=w)
        max_vals = repeat(max_vals, 'b c -> b w c', w=w)
        # 1e-8 guards against a constant series (max == min).
        x = (x - min_vals) / (max_vals - min_vals + 1e-8)
    elif method == 'z-score':
        # torch.std is the unbiased estimator, so a length-1 axis yields NaN.
        mean_vals = repeat(torch.mean(x, dim=1), 'b c -> b w c', w=w)
        std_vals = repeat(torch.std(x, dim=1), 'b c -> b w c', w=w)
        x = torch.abs((x - mean_vals) / (std_vals + 1e-8))
    elif method == 'softmax':
        x = F.softmax(x, dim=1)
    else:
        raise ValueError('Unknown normalization method')

    if c == 1:
        x = x.squeeze(dim=-1)
    return x


def exists(v):
    """Return True when ``v`` is not None."""
    return v is not None


def default(v, d):
    """Return ``v`` unless it is None, in which case return ``d``."""
    return v if exists(v) else d


def pack_one(t, pattern):
    """Pack the single tensor ``t`` with ``einops.pack`` using ``pattern``."""
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    """Undo :func:`pack_one`: unpack with ``einops.unpack`` and return the single tensor."""
    return unpack(t, ps, pattern)[0]


def identity(t, *args, **kwargs):
    """Return ``t`` unchanged, ignoring any extra arguments."""
    return t


def divisible_by(num, den):
    """Return True when ``num`` is an exact multiple of ``den``."""
    return (num % den) == 0


def cast_tuple(t):
    """Wrap ``t`` in a 1-tuple unless it already is a tuple."""
    return (t,) if not isinstance(t, tuple) else t


class Attention(Module):
    """Multi-head self-attention with optional rotary embeddings and a SiLU value gate.

    Input is (batch, seq, dim); the output is projected back to ``dim``.
    """

    def __init__(
        self,
        dim,
        dim_head=32,
        heads=4,
        dropout=0.0,
        causal=False,
        flash=True,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        # Not used in this class; Attend below is constructed without it.
        self.scale = dim_head**-0.5
        dim_inner = dim_head * heads

        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads),
        )
        # Per-head multiplicative gate computed from the input and applied to
        # the attention output.
        self.to_v_gates = nn.Sequential(
            nn.Linear(dim, dim_inner, bias=False),
            nn.SiLU(),
            Rearrange("b n (h d) -> b h n d", h=heads),
        )
        self.attend = Attend(flash=flash, dropout=dropout, causal=causal)
        self.to_out = nn.Sequential(
            Rearrange("b h n d -> b n (h d)"),
            nn.Linear(dim_inner, dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Self-attend over the sequence axis of ``x`` (batch, seq, dim)."""
        q, k, v = self.to_qkv(x)
        if exists(self.rotary_emb):
            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
        out = self.attend(q, k, v)
        out = out * self.to_v_gates(x)
        return self.to_out(out)


# feedforward


class GEGLU(Module):
    """Split the last dim in two halves ``(x, gate)`` and return ``x * gelu(gate)``."""

    def forward(self, x):
        x, gate = rearrange(x, "... (r d) -> r ... d", r=2)
        return x * F.gelu(gate)


def FeedForward(dim, mult=4, dropout=0.0):
    """GEGLU feed-forward: ``dim -> 2 * inner -> (GEGLU) inner -> dim``.

    ``inner = int(dim * mult * 2 / 3)``, so the parameter count is about that of a
    plain ``mult``-times expansion MLP.
    """
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim),
    )


# transformer block


class TransformerBlock(Module):
    """Post-norm transformer block: attention + residual + LayerNorm, then FF + residual + LayerNorm."""

    def __init__(
        self,
        *,
        dim,
        causal=False,
        dim_head=32,
        heads=8,
        ff_mult=4,
        flash_attn=True,
        attn_dropout=0.0,
        ff_dropout=0.0,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        self.rotary_emb = rotary_emb

        self.attn = Attention(
            flash=flash_attn,
            rotary_emb=rotary_emb,
            causal=causal,
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            dropout=attn_dropout,
        )
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
        self.attn_norm = nn.LayerNorm(dim)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, x, rotary_emb: Optional[RotaryEmbedding] = None):
        """Apply the block to ``x``. ``rotary_emb`` is accepted for API compatibility but ignored."""
        x = self.attn(x) + x
        x = self.attn_norm(x)
        x = self.ff(x) + x
        x = self.ff_norm(x)
        return x


def slide_window(x, hop=4):
    """Extract a centred window of ``2 * hop + 1`` points around every position.

    Args:
        x: tensor of shape (batch, channel, win).
        hop: half-width of each window.

    Returns:
        Tensor of shape (batch, channel, win, 2 * hop + 1). The last axis is
        edge-padded ("replicate") by ``hop`` on both sides and unfolded with
        step 1, giving exactly one window per original position.
    """
    x_pad = F.pad(x, (hop, hop), "replicate")
    return x_pad.unfold(2, hop * 2 + 1, 1)


class PatchAttention(nn.Module):
    """Encode a window with one shared attention stack at two scales.

    ``forward`` takes ``input_x`` of shape (batch, win, channel). It returns the
    sum of two branches:

    * a local branch, which encodes the centred ``2 * hop + 1`` patch around
      every time step and reduces it over the patch axis with ``reduction``;
    * a global branch, which encodes the whole window as one sequence.

    Both branches share the ``intra_in`` projection and the ``intra_layer``
    stack. ``use_RN`` is accepted but currently unused.
    """

    def __init__(
        self,
        win_size: int,
        channel: int,
        depth: int = 4,
        dim: int = 512,
        dim_head: int = 128,
        heads: int = 4,
        attn_dropout: float = 0.2,
        ff_mult: int = 4,
        ff_dropout: float = 0.2,
        hop: int = 4,
        reduction='mean',
        use_RN=False,
        flash_attn=True,
    ):
        super().__init__()
        self.win_size = win_size
        self.channel = channel
        self.hop = hop
        self.intra_layer = nn.ModuleList([])
        self.reduction = reduction
        # One rotary embedding instance is shared by every layer.
        rotary_emb = RotaryEmbedding(dim_head)
        for _ in range(depth):
            self.intra_layer.append(
                ModuleList(
                    [
                        Attention(
                            dim,
                            dim_head=dim_head,
                            heads=heads,
                            dropout=attn_dropout,
                            flash=flash_attn,
                            rotary_emb=rotary_emb,
                        ),
                        nn.LayerNorm(dim),
                        FeedForward(dim, mult=ff_mult, dropout=ff_dropout),
                        nn.LayerNorm(dim),
                    ]
                )
            )

        patch_size = 2 * hop + 1
        self.intra_revin = RevIN(patch_size)
        self.intra_in = nn.Sequential(  # (b r) p c -> (b r) p dim
            nn.Linear(channel, dim),
            nn.LayerNorm(dim),
        )
        self.intra_seq_revin = RevIN(win_size)

    def _encode(self, h):
        """Run ``h`` through every (attention, norm, feed-forward, norm) post-norm layer."""
        for attn, attn_post_norm, ff, ff_post_norm in self.intra_layer:
            h = attn_post_norm(attn(h) + h)
            h = ff_post_norm(ff(h) + h)
        return h

    def forward(self, input_x):
        """Return the sum of the local-patch and whole-window encodings of ``input_x``."""
        b, _, _ = input_x.shape

        # Local branch: one patch per time step (slide_window yields win windows).
        x = rearrange(input_x, 'b w c -> b c w')
        x = slide_window(x, self.hop)  # batch, channel, rolling_num, patch_size
        x = rearrange(x, 'b c r p -> (b r) p c')
        # TODO: check revin dim
        intra_x, _ = self.intra_revin(x)
        intra_x = self._encode(self.intra_in(intra_x))
        intra_x = rearrange(intra_x, '(b r) p d -> b r p d', b=b)
        intra_x = reduce(intra_x, 'b r p d -> b r d', self.reduction)

        # Global branch: the whole window, encoded by the same layers.
        intra_x_seq, _ = self.intra_seq_revin(input_x)
        intra_x_seq = self._encode(self.intra_in(intra_x_seq))

        return intra_x + intra_x_seq


class FeatureDistance(Module):
    """Sum of Euclidean distances over all unordered pairs of rows, per batch item."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # x shape (batch, num_variates, embedding_dim)
        dis = torch.cdist(x, x)
        # The distance matrix is symmetric with a zero diagonal, so halving the
        # full sum counts each unordered pair once.
        return dis.sum(2).sum(1) / 2


# Modes understood by cal_metric.
_METRIC_MODES = ('z-score_mae', 'z_score_mse', 'z_score_clamp', 'distance')


def _distance_score(x, soft=True, soft_mode='min'):
    """Return each point's summed Euclidean distance to every point in its window.

    ``x`` is expected to be (batch, win, dim); the result is (batch, win). When
    ``soft`` is set and ``soft_mode`` is ``'sum'`` or ``'min'``, the scores are
    divided by the per-sample total or minimum and then passed through
    :func:`normalize` (absolute z-score). Any other ``soft_mode`` leaves the raw
    distances untouched.
    """
    dis = torch.cdist(x, x).sum(2)
    if not soft:
        return dis
    if soft_mode == 'sum':
        val = dis.sum(dim=1)
    elif soft_mode == 'min':
        val, _ = dis.min(dim=1)
    else:
        return dis
    val = repeat(val, "b -> b w", w=dis.size(1))
    return normalize(dis / val)  # batch,win


def cal_metric(x, z_score, mode='z-score', soft=True, soft_mode='min', model_mode='train'):
    """Compute a per-point distance score and, in training mode, a loss on it.

    Args:
        x: point embeddings, expected (batch, win, dim).
        z_score: target scores with the same shape as the distance score
            (batch, win). Only used by the ``z-score*``/``z_score*`` modes.
        mode: one of ``'z-score_mae'`` (L1 to ``z_score``), ``'z_score_mse'``
            (MSE to ``z_score``), ``'z_score_clamp'`` (``dis`` where it exceeds
            ``z_score``, otherwise ``z_score - dis``, summed per sample and
            averaged) or ``'distance'`` (summed score averaged over the batch).
            The default ``'z-score'`` is NOT a supported mode; pass one explicitly.
        soft, soft_mode: see :func:`_distance_score`.
        model_mode: ``'train'`` returns ``(loss, dis)``; anything else returns
            only ``dis``.

    Raises:
        ValueError: if ``mode`` is not a supported mode. Previously this case
            silently returned ``None``.
    """
    if mode not in _METRIC_MODES:
        raise ValueError(
            f"Unknown metric mode {mode!r}; expected one of {_METRIC_MODES}"
        )

    dis = _distance_score(x, soft=soft, soft_mode=soft_mode)
    if model_mode != 'train':
        return dis

    if mode == 'z-score_mae':
        loss = F.l1_loss(dis, z_score, reduction='mean')
    elif mode == 'z_score_mse':
        loss = F.mse_loss(dis, z_score, reduction='mean')
    elif mode == 'z_score_clamp':
        loss = torch.where(dis > z_score, dis, z_score - dis).sum(dim=1).mean()
    else:  # 'distance'
        loss = dis.sum(dim=1).mean()
    return loss, dis


class PointHingeLoss(Module):
    """Loss module wrapping :func:`cal_metric` in training mode; returns ``(loss, metric)``."""

    def __init__(self, mode='distance', soft=True, soft_mode='min'):
        super().__init__()
        self.mode = mode
        self.soft = soft
        self.soft_mode = soft_mode

    def forward(self, x, z_score):
        loss, metric = cal_metric(
            x=x, z_score=z_score, mode=self.mode, soft=self.soft, soft_mode=self.soft_mode
        )
        return loss, metric


if __name__ == "__main__":
    x = torch.randn(2, 30, 3)
    model = PatchAttention(30, 3, 4, 512, 128, 4, 0.2, 4, 0.2, 4, 'mean', False, True)
    intra_x = model(x)
    print(intra_x.shape)