aiops/Pathformer_ICLR2024/layers/AMS.py (93 lines of code) (raw):
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from layers.Layer import Transformer_Layer
from utils.Other import SparseDispatcher, FourierLayer, series_decomp_multi, MLP
class AMS(nn.Module):
    """Sparsely gated mixture of ``Transformer_Layer`` experts, one per patch size.

    A noisy top-``k`` router scores every sample against ``num_experts``
    experts, the sample is dispatched to the ``k`` best-scoring ones, and the
    expert outputs are recombined using the gate weights. ``forward`` also
    returns a load-balancing loss built from the per-expert importance and load.
    """

    def __init__(self, input_size, output_size, num_experts, device, num_nodes=1, d_model=32, d_ff=64, dynamic=False,
                 patch_size=(8, 6, 4, 2), noisy_gating=True, k=4, layer_number=1, residual_connection=1):
        """Build the router parameters and one expert per entry of ``patch_size``.

        Args:
            input_size: Length of the axis that is split into patches; also the
                row count of the gating weight matrices.
            output_size: Output size handed to ``end_MLP``.
            num_experts: Number of experts the router chooses between.
            device: Forwarded unchanged to every ``Transformer_Layer``.
            num_nodes: Input width of ``start_linear``; also forwarded to the experts.
            d_model: Forwarded to the experts.
            d_ff: Forwarded to the experts.
            dynamic: Forwarded to the experts.
            patch_size: Iterable of patch lengths; one expert is created per entry.
            noisy_gating: Add learned Gaussian noise to the router logits while training.
            k: Number of experts selected per sample.
            layer_number: Forwarded to the experts.
            residual_connection: When truthy, ``forward`` adds its input to the output.

        Raises:
            ValueError: If ``k`` exceeds ``num_experts``, or if ``patch_size``
                yields fewer experts than ``num_experts``.
        """
        super().__init__()
        # Validate first so a bad configuration fails before any sub-module is
        # built, and so the check survives ``python -O`` (unlike ``assert``).
        if k > num_experts:
            raise ValueError(f"k ({k}) must not exceed num_experts ({num_experts})")

        self.num_experts = num_experts
        self.output_size = output_size
        self.input_size = input_size
        self.k = k

        # Collapses the trailing ``num_nodes`` axis to one value before routing.
        self.start_linear = nn.Linear(in_features=num_nodes, out_features=1)
        self.seasonality_model = FourierLayer(pred_len=0, k=3)
        self.trend_model = series_decomp_multi(kernel_size=[4, 8, 12])

        self.experts = nn.ModuleList()
        self.MLPs = nn.ModuleList()
        for patch in patch_size:
            patch_nums = int(input_size / patch)
            self.experts.append(
                Transformer_Layer(device=device, d_model=d_model, d_ff=d_ff,
                                  dynamic=dynamic, num_nodes=num_nodes, patch_nums=patch_nums,
                                  patch_size=patch, factorized=True, layer_number=layer_number)
            )
        # ``forward`` indexes ``self.experts[i]`` for every ``i < num_experts``,
        # so too few patch sizes would otherwise only surface as an IndexError there.
        if len(self.experts) < num_experts:
            raise ValueError(
                f"patch_size provides {len(self.experts)} experts but num_experts is {num_experts}"
            )

        # Router weights start at zero, so all clean logits are initially equal.
        self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
        self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)

        self.residual_connection = residual_connection
        self.end_MLP = MLP(input_size=input_size, output_size=output_size)

        self.noisy_gating = noisy_gating
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        # Parameters of the standard normal used by ``_prob_in_top_k``; kept as
        # buffers so they follow the module across devices.
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))

    def cv_squared(self, x):
        """Return the squared coefficient of variation, ``var(x) / (mean(x)**2 + eps)``.

        A tensor whose first dimension is 1 yields ``tensor([0])`` in ``x``'s
        dtype and on ``x``'s device, since a single value has no spread.
        """
        eps = 1e-10
        if x.shape[0] == 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        values = x.float()
        return values.var() / (values.mean() ** 2 + eps)

    def _gates_to_load(self, gates):
        """Count, for each column of ``gates``, how many rows have a positive gate."""
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Return the probability that each logit lands in the top ``k`` under fresh noise.

        Args:
            clean_values: Logits before noise, one row per sample.
            noisy_values: The same logits after noise was added.
            noise_stddev: Standard deviation of the added noise, same shape as the logits.
            noisy_top_values: Per-row top values of ``noisy_values`` in the order
                returned by ``topk``; must have more than ``k`` columns.

        Returns:
            Tensor with the same shape as ``clean_values``.
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        # Flat index of column ``k`` in each row: the (k+1)-th largest noisy
        # value, which an entry already in the top k has to stay above.
        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k
        threshold_if_in = torch.gather(top_values_flat, 0, threshold_positions_if_in).unsqueeze(1)
        is_in = torch.gt(noisy_values, threshold_if_in)

        # Column ``k - 1``: the k-th largest noisy value, which an entry outside
        # the top k has to beat.
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.gather(top_values_flat, 0, threshold_positions_if_out).unsqueeze(1)

        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        return torch.where(is_in, prob_if_in, prob_if_out)

    def seasonality_and_trend_decompose(self, x):
        """Return the router input: the series plus its seasonality and trend parts.

        ``x`` must be 4-D; only index 0 of its last axis is used. The trend is
        the second output of ``trend_model`` and the seasonality is the first
        output of ``seasonality_model``.
        """
        x = x[:, :, :, 0]
        _, trend = self.trend_model(x)
        seasonality, _ = self.seasonality_model(x)
        return x + seasonality + trend

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """Compute sparse gate weights and the per-expert load.

        Args:
            x: Router input whose last axis has ``num_nodes`` entries.
                NOTE(review): presumably (batch, input_size, num_nodes) -- confirm.
            train: Whether the module is in training mode; noise is only added
                when this and ``self.noisy_gating`` are both truthy.
            noise_epsilon: Lower bound added to the noise standard deviation.

        Returns:
            ``(gates, load)``. ``gates`` has one row per sample and one column
            per expert, with at most ``k`` non-zero entries per row that sum to
            1. ``load`` has one entry per expert.
        """
        x = self.start_linear(x).squeeze(-1)
        clean_logits = x @ self.w_gate

        use_noise = self.noisy_gating and train
        if use_noise:
            # Softplus keeps the learned deviation positive; epsilon keeps it away from zero.
            noise_stddev = self.softplus(x @ self.w_noise) + noise_epsilon
            noisy_logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
            logits = noisy_logits
        else:
            logits = clean_logits

        # Fetch one value beyond k (when available): ``_prob_in_top_k`` needs
        # the (k+1)-th largest logit as a threshold.
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
        top_k_indices = top_indices[:, :self.k]
        top_k_gates = self.softmax(top_logits[:, :self.k])

        # ``scatter`` is out-of-place, so the zero tensor itself is never modified.
        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        if use_noise and self.k < self.num_experts:
            # Differentiable load estimate: summed probability of being selected.
            load = self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load

    def forward(self, x, loss_coef=1e-2):
        """Route ``x`` through the selected experts and combine their outputs.

        Args:
            x: 4-D input tensor.
            loss_coef: Multiplier applied to the load-balancing loss.

        Returns:
            ``(output, balance_loss)`` where ``output`` is the gate-weighted
            combination of expert outputs (plus ``x`` when
            ``residual_connection`` is truthy).
        """
        # Multi-scale router: gate on the decomposed series, not on the raw input.
        new_x = self.seasonality_and_trend_decompose(x)
        gates, load = self.noisy_top_k_gating(new_x, self.training)

        # Balance loss: penalise uneven total gate weight and uneven load.
        importance = gates.sum(0)
        balance_loss = (self.cv_squared(importance) + self.cv_squared(load)) * loss_coef

        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(x)
        # Each expert returns a sequence; only its first element is combined.
        expert_outputs = [self.experts[i](expert_inputs[i])[0] for i in range(self.num_experts)]
        output = dispatcher.combine(expert_outputs)
        if self.residual_connection:
            output = output + x
        return output, balance_loss