aiops/Pathformer_ICLR2024/utils/decomposition.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat, rearrange
from contextlib import contextmanager


def svd_denoise(x, cut):
    """Denoise a batch of matrices by zeroing all but the `cut` largest singular values."""
    x_ = x.clone().detach()
    U, S, Vh = torch.linalg.svd(x_, full_matrices=False)
    S[:, cut:] = 0
    # diag_embed builds one diagonal matrix per batch element, so each
    # sample is reconstructed from its own truncated spectrum
    return U @ torch.diag_embed(S) @ Vh


@contextmanager
def null_context():
    yield


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


class NMF(nn.Module):
    """Non-negative matrix factorization, x ~ D @ C, via multiplicative updates."""

    def __init__(self, dim, n, ratio=8, K=6, eps=2e-8):
        super().__init__()
        r = dim // ratio

        D = torch.zeros(dim, r).uniform_(0, 1)
        C = torch.zeros(r, n).uniform_(0, 1)

        self.K = K
        self.D = nn.Parameter(D)
        self.C = nn.Parameter(C)

        self.eps = eps

    def forward(self, x):
        b, D, C, eps = x.shape[0], self.D, self.C, self.eps

        # x is made non-negative with relu, as proposed in the paper
        x = F.relu(x)

        D = repeat(D, 'd r -> b d r', b=b)
        C = repeat(C, 'r n -> b r n', b=b)

        # batched transpose
        t = lambda tensor: rearrange(tensor, 'b i j -> b j i')

        for k in reversed(range(self.K)):
            # only calculate gradients on the last step, per the proposed 'One-step Gradient'
            context = null_context if k == 0 else torch.no_grad

            with context():
                # standard multiplicative update rules for NMF
                C_new = C * ((t(D) @ x) / ((t(D) @ D @ C) + eps))
                D_new = D * ((x @ t(C)) / ((D @ C @ t(C)) + eps))
                C, D = C_new, D_new

        return D @ C
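

# Usage sketch: a minimal, illustrative run of both decompositions.
# All tensor shapes below are assumptions chosen for demonstration only.
if __name__ == '__main__':
    # SVD denoising: keep the 5 largest singular values of each sample
    x = torch.randn(4, 96, 16)        # assumed (batch, time, channels) layout
    x_denoised = svd_denoise(x, cut=5)
    print(x_denoised.shape)           # torch.Size([4, 96, 16])

    # NMF: factor a non-negative batch into D (64 x 8) @ C (8 x 32)
    nmf = NMF(dim=64, n=32, ratio=8, K=6)
    y = torch.rand(4, 64, 32)
    y_recon = nmf(y)                  # low-rank reconstruction D @ C
    print(y_recon.shape)              # torch.Size([4, 64, 32])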