aiops/Pathformer_ICLR2024/utils/Other.py (137 lines of code) (raw):
import torch
import torch.nn as nn
import numpy as np
import math
import torch.fft as fft
from einops import rearrange, reduce, repeat
class SparseDispatcher(object):
    """Dispatch a batch to sparse experts and recombine their outputs.

    Built from a ``gates`` tensor of shape (batch, num_experts) in which a
    nonzero entry means "route this sample to this expert" (sparsely-gated
    mixture-of-experts style dispatcher).
    """

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher.

        Args:
            num_experts: number of experts (columns of ``gates``).
            gates: (batch, num_experts) tensor; nonzero entries select experts.
        """
        self._gates = gates
        self._num_experts = num_experts
        # sort experts: nonzero() yields (sample, expert) index pairs; the
        # column-wise sort groups the assignments by expert id
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        # expert id of each assignment, in the grouped (expert-sorted) order
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # get according batch index for each expert (which row of `gates`
        # each assignment came from, same grouped order)
        self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
        # number of samples routed to each expert (split sizes for dispatch)
        self._part_sizes = (gates > 0).sum(0).tolist()
        # gate value of each assignment, aligned with the grouped order above
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Split ``inp`` into one tensor per expert; a sample routed to k
        experts appears k times across the returned tensors."""
        # assigns samples to experts whose gate is nonzero
        # expand according to batch index so we can just split by _part_sizes
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Recombine per-expert 4-D outputs into one tensor per sample.

        NOTE(review): operates in log space — outputs are exponentiated,
        optionally gate-weighted, summed per original sample, then logged.
        """
        # apply exp to expert outputs, so we are not longer in log space
        stitched = torch.cat(expert_out, 0).exp()
        if multiply_by_gates:
            # weight each assignment by its gate value
            stitched = torch.einsum("ijkh,ik -> ijkh", stitched, self._nonzero_gates)
        zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), expert_out[-1].size(2), expert_out[-1].size(3),
                            requires_grad=True, device=stitched.device)
        # combine samples that have been processed by the same k experts
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        # add eps to all zero values in order to avoid nans when going back to log space
        combined[combined == 0] = np.finfo(float).eps
        # back to log space
        return combined.log()

    def expert_to_gates(self):
        """Return the gate values split per expert (parallel to dispatch())."""
        # split nonzero gates for each expert
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
class MLP(nn.Module):
    """Pointwise channel projection implemented as a 1x1 2-D convolution.

    Maps (batch, input_size, H, W) -> (batch, output_size, H, W); a 1x1
    convolution applies the same linear map at every spatial location.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Conv2d(
            in_channels=input_size,
            out_channels=output_size,
            kernel_size=(1, 1),
            bias=True,
        )

    def forward(self, x):
        """Project the channel dimension of `x` with the 1x1 convolution."""
        return self.fc(x)
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.
    Pads by replicating the first/last time steps, then average-pools
    along the time dimension (dim 1 of a (batch, time, dim) input).
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Replicate the boundary steps: a total of kernel_size - 1 padding
        # steps, split between both ends with the odd extra step in front,
        # so a stride-1 pool preserves the sequence length.
        pad_total = self.kernel_size - 1
        pad_end = pad_total // 2
        pad_front = pad_total - pad_end
        head = x[:, :1, :].repeat(1, pad_front, 1)
        tail = x[:, -1:, :].repeat(1, pad_end, 1)
        padded = torch.cat((head, x, tail), dim=1)
        # AvgPool1d pools over the last dim, so swap time to the end and back.
        return self.avg(padded.transpose(1, 2)).transpose(1, 2)
class series_decomp(nn.Module):
    """
    Series decomposition block: split a series into residual and trend,
    where the trend is a stride-1 moving average of the input.
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        # Stride 1 keeps the trend the same length as the input.
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return (residual, trend) with residual = x - trend."""
        trend = self.moving_avg(x)
        return x - trend, trend
class series_decomp_multi(nn.Module):
    """
    Series decomposition block with multiple moving-average kernels.

    Computes one stride-1 moving average per kernel size and mixes the
    candidate trends with data-dependent softmax weights produced by a
    learned linear layer, returning (residual, trend).
    """

    def __init__(self, kernel_size):
        """
        Args:
            kernel_size: iterable of ints — one moving-average window per branch.
        """
        super(series_decomp_multi, self).__init__()
        # FIX: use nn.ModuleList instead of a plain Python list so the
        # sub-modules are properly registered (children()/state_dict(),
        # .to(device) / .train() propagation).
        self.moving_avg = nn.ModuleList(
            [moving_avg(kernel, stride=1) for kernel in kernel_size]
        )
        # Maps each scalar series value to one mixing logit per branch.
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        """Return (residual, trend); assumes x is (batch, time, dim) — TODO confirm."""
        # Stack the per-kernel trends along a new last axis (..., K).
        # Renamed loop variable: the original shadowed the `moving_avg` class.
        trends = torch.cat(
            [avg(x).unsqueeze(-1) for avg in self.moving_avg], dim=-1
        )
        # Data-dependent convex combination of the K candidate trends.
        weights = nn.Softmax(-1)(self.layer(x.unsqueeze(-1)))
        moving_mean = torch.sum(trends * weights, dim=-1)
        res = x - moving_mean
        return res, moving_mean
class FourierLayer(nn.Module):
    """Frequency-domain extrapolation layer.

    Keeps the k largest-amplitude frequency components of the input series
    and renders them as cosines over ``t + pred_len`` time steps, producing
    a forecast by Fourier extrapolation.
    """

    def __init__(self, pred_len, k=None, low_freq=1, output_attention=False):
        super().__init__()
        # self.d_model = d_model
        self.pred_len = pred_len  # number of future steps to extrapolate
        self.k = k  # how many frequency bins to keep (top-k by magnitude)
        self.low_freq = low_freq  # drop the first `low_freq` bins (e.g. DC component)
        self.output_attention = output_attention  # route through dft_forward, which also returns an attention map

    def forward(self, x):
        """x: (b, t, d)"""
        if self.output_attention:
            return self.dft_forward(x)
        b, t, d = x.shape
        x_freq = fft.rfft(x, dim=1)
        if t % 2 == 0:
            # even length: also drop the last (Nyquist) bin alongside the low ones
            x_freq = x_freq[:, self.low_freq:-1]
            f = fft.rfftfreq(t)[self.low_freq:-1]
        else:
            x_freq = x_freq[:, self.low_freq:]
            f = fft.rfftfreq(t)[self.low_freq:]
        # keep only the k strongest bins per (batch, channel)
        x_freq, index_tuple = self.topk_freq(x_freq)
        # broadcast the frequency grid to (b, f, d), then pick the kept bins
        # with the same index tuple used for x_freq
        f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2))
        f = f.to(x_freq.device)
        f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device)
        # second slot mirrors the usual (output, attention) return convention
        return self.extrapolate(x_freq, f, t), None

    def extrapolate(self, x_freq, f, t):
        """Render the kept complex components as cosines over t + pred_len steps."""
        # append the conjugate half so the reconstructed signal is real-valued
        x_freq = torch.cat([x_freq, x_freq.conj()], dim=1)
        f = torch.cat([f, -f], dim=1)
        t_val = rearrange(torch.arange(t + self.pred_len, dtype=torch.float),
                          't -> () () t ()').to(x_freq.device)
        # amplitude normalised by the original series length t
        amp = rearrange(x_freq.abs() / t, 'b f d -> b f () d')
        phase = rearrange(x_freq.angle(), 'b f d -> b f () d')
        x_time = amp * torch.cos(2 * math.pi * f * t_val + phase)
        # sum over the frequency axis -> (b, t + pred_len, d)
        return reduce(x_time, 'b f t d -> b t d', 'sum')

    def topk_freq(self, x_freq):
        """Select the k largest-magnitude bins along the frequency dim (dim 1)."""
        values, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True)
        # advanced-indexing mesh over (batch, channel) so each pair keeps its own top-k
        mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2)))
        index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1))
        x_freq = x_freq[index_tuple]
        return x_freq, index_tuple

    def dft_forward(self, x):
        """Explicit-DFT variant of forward that also builds an attention map.

        NOTE(review): mixes numpy scalars with torch tensors when building the
        inverse-DFT matrix — presumably relies on numpy/torch interop; verify
        on the target torch version.
        """
        T = x.size(1)
        # DFT as an explicit T x T matrix (FFT of the identity)
        dft_mat = fft.fft(torch.eye(T))
        i, j = torch.meshgrid(torch.arange(self.pred_len + T), torch.arange(T))
        omega = np.exp(2 * math.pi * 1j / T)
        # inverse-DFT matrix extended to pred_len + T output steps
        idft_mat = (np.power(omega, i * j) / T).cfloat()
        x_freq = torch.einsum('ft,btd->bfd', [dft_mat, x.cfloat()])
        if T % 2 == 0:
            x_freq = x_freq[:, self.low_freq:T // 2]
        else:
            x_freq = x_freq[:, self.low_freq:T // 2 + 1]
        # top-k bins by magnitude, shifted back to absolute bin indices;
        # negative indices add the mirrored (conjugate) bins
        _, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True)
        indices = indices + self.low_freq
        indices = torch.cat([indices, -indices], dim=1)
        dft_mat = repeat(dft_mat, 'f t -> b f t d', b=x.shape[0], d=x.shape[-1])
        idft_mat = repeat(idft_mat, 't f -> b t f d', b=x.shape[0], d=x.shape[-1])
        mesh_a, mesh_b = torch.meshgrid(torch.arange(x.size(0)), torch.arange(x.size(2)))
        # mask both matrices so only the selected frequency rows/columns survive
        dft_mask = torch.zeros_like(dft_mat)
        dft_mask[mesh_a, indices, :, mesh_b] = 1
        dft_mat = dft_mat * dft_mask
        idft_mask = torch.zeros_like(idft_mat)
        idft_mask[mesh_a, :, indices, mesh_b] = 1
        idft_mat = idft_mat * idft_mask
        # composed (iDFT . masked DFT): a linear map from input steps to output steps
        attn = torch.einsum('bofd,bftd->botd', [idft_mat, dft_mat]).real
        return torch.einsum('botd,btd->bod', [attn, x]), rearrange(attn, 'b o t d -> b d o t')