# aiops/Pathformer_ICLR2024/layers/Layer.py
import math
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np  # needed for the -np.inf masking in ScaledDotProductAttention
from layers.Embedding import *

class Transformer_Layer(nn.Module):
def __init__(self, device, d_model, d_ff, num_nodes, patch_nums, patch_size, dynamic, factorized, layer_number):
super(Transformer_Layer, self).__init__()
self.device = device
self.d_model = d_model
self.num_nodes = num_nodes
self.dynamic = dynamic
self.patch_nums = patch_nums
self.patch_size = patch_size
self.layer_number = layer_number
##intra_patch_attention
self.intra_embeddings = nn.Parameter(torch.rand(self.patch_nums, 1, 1, self.num_nodes, 16),
requires_grad=True)
self.embeddings_generator = nn.ModuleList([nn.Sequential(*[
nn.Linear(16, self.d_model)]) for _ in range(self.patch_nums)])
self.intra_d_model = self.d_model
self.intra_patch_attention = Intra_Patch_Attention(self.intra_d_model, factorized=factorized)
self.weights_generator_distinct = WeightGenerator(self.intra_d_model, self.intra_d_model, mem_dim=16, num_nodes=num_nodes,
factorized=factorized, number_of_weights=2)
self.weights_generator_shared = WeightGenerator(self.intra_d_model, self.intra_d_model, mem_dim=None, num_nodes=num_nodes,
factorized=False, number_of_weights=2)
self.intra_Linear = nn.Linear(self.patch_nums, self.patch_nums*self.patch_size)
##inter_patch_attention
self.stride = patch_size
# patch_num = int((context_window - cut_size) / self.stride + 1)
self.inter_d_model = self.d_model * self.patch_size
##inter_embedding
self.emb_linear = nn.Linear(self.inter_d_model, self.inter_d_model)
# Positional encoding
self.W_pos = positional_encoding(pe='zeros', learn_pe=True, q_len=self.patch_nums, d_model=self.inter_d_model)
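        # Note: n_heads = d_model, so each head gets d_k = inter_d_model // n_heads = patch_size.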
n_heads = self.d_model
d_k = self.inter_d_model // n_heads
d_v = self.inter_d_model // n_heads
self.inter_patch_attention = Inter_Patch_Attention(self.inter_d_model, self.inter_d_model, n_heads, d_k, d_v, attn_dropout=0,
proj_dropout=0.1, res_attention=False)
##Normalization
self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(self.d_model), Transpose(1,2))
self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(self.d_model), Transpose(1,2))
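        # Note: norm_attn and norm_ffn are defined here but not applied in forward() below.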
##FFN
self.d_ff = d_ff
self.dropout = nn.Dropout(0.1)
self.ff = nn.Sequential(nn.Linear(self.d_model, self.d_ff, bias=True),
nn.GELU(),
nn.Dropout(0.2),
nn.Linear(self.d_ff, self.d_model, bias=True))
def forward(self, x):
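        # x: [batch_size, patch_nums * patch_size, num_nodes, d_model]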
new_x = x
batch_size = x.size(0)
intra_out_concat = None
weights_shared, biases_shared = self.weights_generator_shared()
weights_distinct, biases_distinct = self.weights_generator_distinct()
####intra Attention#####
for i in range(self.patch_nums):
t = x[:, i * self.patch_size:(i + 1) * self.patch_size, :, :]
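            # A learned per-patch embedding serves as the attention query for this patch.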
intra_emb = self.embeddings_generator[i](self.intra_embeddings[i]).expand(batch_size, -1, -1, -1)
t = torch.cat([intra_emb, t], dim=1)
out, attention = self.intra_patch_attention(intra_emb, t, t, weights_distinct, biases_distinct, weights_shared,
biases_shared)
            if intra_out_concat is None:
intra_out_concat = out
else:
intra_out_concat = torch.cat([intra_out_concat, out], dim=1)
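        # Map per-patch summaries [B, patch_nums, N, D] back to full length T = patch_nums * patch_size.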
intra_out_concat = intra_out_concat.permute(0,3,2,1)
intra_out_concat = self.intra_Linear(intra_out_concat)
intra_out_concat = intra_out_concat.permute(0,3,2,1)
####inter Attention######
x = x.unfold(dimension=1, size=self.patch_size, step=self.stride) # [b x patch_num x nvar x dim x patch_len]
x = x.permute(0, 2, 1, 3, 4) # [b x nvar x patch_num x dim x patch_len ]
b, nvar, patch_num, dim, patch_len = x.shape
x = torch.reshape(x, (
x.shape[0] * x.shape[1], x.shape[2], x.shape[3] * x.shape[-1])) # [b*nvar, patch_num, dim*patch_len]
x = self.emb_linear(x)
x = self.dropout(x + self.W_pos)
        inter_out, attention = self.inter_patch_attention(Q=x, K=x, V=x)  # [b*nvar, patch_num, d_model*patch_size]
inter_out = torch.reshape(inter_out, (b, nvar, inter_out.shape[-2], inter_out.shape[-1]))
inter_out = torch.reshape(inter_out, (b, nvar, inter_out.shape[-2], self.patch_size, self.d_model))
inter_out = torch.reshape(inter_out, (b, self.patch_size*self.patch_nums, nvar, self.d_model)) #[b, temporal, nvar, dim]
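        # Residual sum of the input with the intra- and inter-patch branches.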
out = new_x + intra_out_concat + inter_out
##FFN
out = self.dropout(out)
out = self.ff(out) + out
return out, attention
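
def _demo_transformer_layer():
    # A minimal smoke test for Transformer_Layer -- a sketch, not part of the original file.
    # Assumes layers.Embedding provides positional_encoding as imported above; the
    # hyperparameter values below are illustrative, not taken from the paper.
    d_model, patch_nums, patch_size, num_nodes = 16, 4, 6, 3
    layer = Transformer_Layer(device='cpu', d_model=d_model, d_ff=64, num_nodes=num_nodes,
                              patch_nums=patch_nums, patch_size=patch_size,
                              dynamic=False, factorized=True, layer_number=1)
    # Input: [batch, patch_nums * patch_size, num_nodes, d_model]
    x = torch.randn(2, patch_nums * patch_size, num_nodes, d_model)
    out, attn = layer(x)
    assert out.shape == x.shape  # the layer is shape-preserving
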
class CustomLinear(nn.Module):
def __init__(self, factorized):
super(CustomLinear, self).__init__()
self.factorized = factorized
def forward(self, input, weights, biases):
if self.factorized:
return torch.matmul(input.unsqueeze(3), weights).squeeze(3) + biases
else:
return torch.matmul(input, weights) + biases
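
def _demo_custom_linear():
    # Sketch of the two CustomLinear modes (illustrative shapes, not from the original file).
    B, L, N, D = 2, 5, 3, 8
    x = torch.randn(B, L, N, D)
    # Factorized: one [D, D] weight per node, applied via broadcasting.
    w_f, b_f = torch.randn(N, D, D), torch.randn(N, D)
    y = CustomLinear(factorized=True)(x, w_f, b_f)
    assert y.shape == (B, L, N, D)
    # Non-factorized: a single shared [D, D] weight.
    w_s, b_s = torch.randn(D, D), torch.randn(1, D)
    y = CustomLinear(factorized=False)(x, w_s, b_s)
    assert y.shape == (B, L, N, D)
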
class Intra_Patch_Attention(nn.Module):
def __init__(self, d_model, factorized):
super(Intra_Patch_Attention, self).__init__()
        self.head = 2
        if d_model % self.head != 0:
            raise ValueError('Hidden size is not divisible by the number of attention heads')
        self.head_size = d_model // self.head
self.custom_linear = CustomLinear(factorized)
def forward(self, query, key, value, weights_distinct, biases_distinct, weights_shared, biases_shared):
batch_size = query.shape[0]
key = self.custom_linear(key, weights_distinct[0], biases_distinct[0])
value = self.custom_linear(value, weights_distinct[1], biases_distinct[1])
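        # Fold the head split into the batch dimension: [B, L, N, D] -> [head*B, L, N, D/head].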
query = torch.cat(torch.split(query, self.head_size, dim=-1), dim=0)
key = torch.cat(torch.split(key, self.head_size, dim=-1), dim=0)
value = torch.cat(torch.split(value, self.head_size, dim=-1), dim=0)
query = query.permute((0, 2, 1, 3))
key = key.permute((0, 2, 3, 1))
value = value.permute((0, 2, 1, 3))
attention = torch.matmul(query, key)
attention /= (self.head_size ** 0.5)
attention = torch.softmax(attention, dim=-1)
x = torch.matmul(attention, value)
x = x.permute((0, 2, 1, 3))
x = torch.cat(torch.split(x, batch_size, dim=0), dim=-1)
        if x.shape[0] == 0:
            # Degenerate empty-batch guard kept from the original implementation.
            x = x.repeat(1, 1, 1, int(weights_shared[0].shape[-1] / x.shape[-1]))
x = self.custom_linear(x, weights_shared[0], biases_shared[0])
x = torch.relu(x)
x = self.custom_linear(x, weights_shared[1], biases_shared[1])
return x, attention
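
def _demo_intra_patch_attention():
    # Sketch wiring Intra_Patch_Attention to WeightGenerator (defined further below);
    # shapes mirror the intra-patch branch of Transformer_Layer.forward. Illustrative sizes.
    B, patch_size, N, D = 2, 6, 3, 16
    attn = Intra_Patch_Attention(D, factorized=True)
    w_d, b_d = WeightGenerator(D, D, mem_dim=16, num_nodes=N, factorized=True, number_of_weights=2)()
    w_s, b_s = WeightGenerator(D, D, mem_dim=None, num_nodes=N, factorized=False, number_of_weights=2)()
    query = torch.randn(B, 1, N, D)  # learned query token
    t = torch.cat([query, torch.randn(B, patch_size, N, D)], dim=1)
    out, attention = attn(query, t, t, w_d, b_d, w_s, b_s)
    assert out.shape == (B, 1, N, D)
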
class Inter_Patch_Attention(nn.Module):
def __init__(self, d_model, out_dim, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0.,
proj_dropout=0., qkv_bias=True, lsa=False):
super().__init__()
d_k = d_model // n_heads if d_k is None else d_k
d_v = d_model // n_heads if d_v is None else d_v
self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v
self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)
# Scaled Dot-Product Attention (multiple heads)
self.res_attention = res_attention
self.sdp_attn = ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout,
res_attention=self.res_attention, lsa=lsa)
        # Project output
self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, out_dim), nn.Dropout(proj_dropout))
def forward(self, Q, K=None, V=None, prev=None, key_padding_mask=None, attn_mask=None):
bs = Q.size(0)
if K is None: K = Q
if V is None: V = Q
# Linear (+ split in multiple heads)
        q_s = self.W_Q(Q).view(bs, Q.shape[1], self.n_heads, self.d_k).transpose(1, 2)  # q_s: [bs x n_heads x q_len x d_k]; q_len here is patch_num
        k_s = self.W_K(K).view(bs, K.shape[1], self.n_heads, self.d_k).permute(0, 2, 3, 1)  # k_s: [bs x n_heads x d_k x q_len] (transpose(1,2) then transpose(2,3))
        v_s = self.W_V(V).view(bs, V.shape[1], self.n_heads, self.d_v).transpose(1, 2)  # v_s: [bs x n_heads x q_len x d_v]
# Apply Scaled Dot-Product Attention (multiple heads)
if self.res_attention:
output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev,
key_padding_mask=key_padding_mask, attn_mask=attn_mask)
else:
output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        output = output.transpose(1, 2).contiguous().view(bs, Q.shape[1], self.n_heads * self.d_v)  # output: [bs x q_len x n_heads * d_v]
output = self.to_out(output)
return output, attn_weights
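
def _demo_inter_patch_attention():
    # Sketch of standalone Inter_Patch_Attention usage (illustrative sizes, not from the original file).
    mha = Inter_Patch_Attention(d_model=32, out_dim=32, n_heads=4)
    Q = torch.randn(2, 10, 32)       # [bs, q_len, d_model]
    out, attn = mha(Q)               # K and V default to Q
    assert out.shape == (2, 10, 32)
    assert attn.shape == (2, 4, 10, 10)  # [bs, n_heads, q_len, q_len]
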
class ScaledDotProductAttention(nn.Module):
r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
(Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets
by Lee et al, 2021)"""
def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
super().__init__()
self.attn_dropout = nn.Dropout(attn_dropout)
self.res_attention = res_attention
head_dim = d_model // n_heads
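        # With lsa=True the 1/sqrt(head_dim) scale becomes a learnable temperature (LSA).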
self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
self.lsa = lsa
def forward(self, q, k, v, prev=None, key_padding_mask=None, attn_mask=None):
# Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
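        # k is expected already transposed to [bs x n_heads x d_k x seq_len], so no transpose here.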
attn_scores = torch.matmul(q, k) * self.scale # attn_scores : [bs x n_heads x max_q_len x q_len]
# Add pre-softmax attention scores from the previous layer (optional)
if prev is not None: attn_scores = attn_scores + prev
# Attention mask (optional)
if attn_mask is not None: # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
if attn_mask.dtype == torch.bool:
attn_scores.masked_fill_(attn_mask, -np.inf)
else:
attn_scores += attn_mask
# Key padding mask (optional)
if key_padding_mask is not None: # mask with shape [bs x q_len] (only when max_w_len == q_len)
attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)
# normalize the attention weights
attn_weights = F.softmax(attn_scores, dim=-1) # attn_weights : [bs x n_heads x max_q_len x q_len]
attn_weights = self.attn_dropout(attn_weights)
# compute the new values given the attention weights
output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v]
return output, attn_weights
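
def _demo_sdp_attention():
    # Sketch showing the pre-transposed-k calling convention (illustrative sizes).
    sdp = ScaledDotProductAttention(d_model=32, n_heads=4)
    q = torch.randn(2, 4, 10, 8)     # [bs, n_heads, q_len, d_k]
    k = torch.randn(2, 4, 8, 10)     # note: already [bs, n_heads, d_k, seq_len]
    v = torch.randn(2, 4, 10, 8)
    out, w = sdp(q, k, v)
    assert out.shape == (2, 4, 10, 8)
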
class WeightGenerator(nn.Module):
def __init__(self, in_dim, out_dim, mem_dim, num_nodes, factorized, number_of_weights=4):
super(WeightGenerator, self).__init__()
self.number_of_weights = number_of_weights
self.mem_dim = mem_dim
self.num_nodes = num_nodes
self.factorized = factorized
self.out_dim = out_dim
if self.factorized:
            # Device placement is deferred to the enclosing module's .to(device) call.
            self.memory = nn.Parameter(torch.randn(num_nodes, mem_dim), requires_grad=True)
            self.generator = nn.Sequential(*[
nn.Linear(mem_dim, 64),
nn.Tanh(),
nn.Linear(64, 64),
nn.Tanh(),
nn.Linear(64, 100)
])
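            # The 100 generator outputs are reshaped to a [10 x 10] core in forward() below.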
            self.mem_dim = 10  # rank of the factorized core; deliberately overwrites the constructor argument
self.P = nn.ParameterList(
[nn.Parameter(torch.Tensor(in_dim, self.mem_dim), requires_grad=True) for _ in
range(number_of_weights)])
self.Q = nn.ParameterList(
[nn.Parameter(torch.Tensor(self.mem_dim, out_dim), requires_grad=True) for _ in
range(number_of_weights)])
self.B = nn.ParameterList(
[nn.Parameter(torch.Tensor(self.mem_dim ** 2, out_dim), requires_grad=True) for _ in
range(number_of_weights)])
else:
self.P = nn.ParameterList(
[nn.Parameter(torch.Tensor(in_dim, out_dim), requires_grad=True) for _ in range(number_of_weights)])
self.B = nn.ParameterList(
[nn.Parameter(torch.Tensor(1, out_dim), requires_grad=True) for _ in range(number_of_weights)])
self.reset_parameters()
def reset_parameters(self):
list_params = [self.P, self.Q, self.B] if self.factorized else [self.P]
for weight_list in list_params:
for weight in weight_list:
init.kaiming_uniform_(weight, a=math.sqrt(5))
if not self.factorized:
for i in range(self.number_of_weights):
fan_in, _ = init._calculate_fan_in_and_fan_out(self.P[i])
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.B[i], -bound, bound)
def forward(self):
if self.factorized:
memory = self.generator(self.memory.unsqueeze(1))
bias = [torch.matmul(memory, self.B[i]).squeeze(1) for i in range(self.number_of_weights)]
memory = memory.view(self.num_nodes, self.mem_dim, self.mem_dim)
weights = [torch.matmul(torch.matmul(self.P[i], memory), self.Q[i]) for i in range(self.number_of_weights)]
return weights, bias
else:
return self.P, self.B
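
def _demo_weight_generator():
    # Sketch of the two WeightGenerator modes and the shapes they emit (illustrative sizes).
    N, D = 3, 16
    # Factorized: per-node weights [N, D, D] built from a low-rank memory.
    w, b = WeightGenerator(D, D, mem_dim=16, num_nodes=N, factorized=True, number_of_weights=2)()
    assert w[0].shape == (N, D, D) and b[0].shape == (N, D)
    # Non-factorized: plain shared parameters [D, D] with [1, D] biases.
    w, b = WeightGenerator(D, D, mem_dim=None, num_nodes=N, factorized=False, number_of_weights=2)()
    assert w[0].shape == (D, D) and b[0].shape == (1, D)
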
class Transpose(nn.Module):
def __init__(self, *dims, contiguous=False):
super().__init__()
self.dims, self.contiguous = dims, contiguous
def forward(self, x):
if self.contiguous: return x.transpose(*self.dims).contiguous()
else: return x.transpose(*self.dims)
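
if __name__ == "__main__":
    # Optional smoke test for the demo sketches above (assumes layers.Embedding is importable).
    _demo_custom_linear()
    _demo_intra_patch_attention()
    _demo_inter_patch_attention()
    _demo_sdp_attention()
    _demo_weight_generator()
    _demo_transformer_layer()
    print("All demo sketches ran successfully.")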