in sparseconvnet/spectral_norm.py [0:0]
def compute_weight(self, module):
    """Return the spectrally rescaled weight and the updated ``u`` vector.

    Reads ``<self.name>_orig`` (the raw weight) and ``<self.name>_u`` (the
    current estimate of the first left singular vector) from ``module``,
    refines the estimate with ``self.n_power_iterations`` rounds of power
    iteration, and rescales the raw weight by ``self.norm / sigma`` where
    ``sigma = u^T W v`` is the estimated spectral norm.

    Args:
        module: Object exposing the ``<self.name>_orig`` and
            ``<self.name>_u`` attributes.

    Returns:
        A ``(weight, u)`` tuple: the rescaled weight (same shape as the raw
        weight) and the ``u`` vector after power iteration.
    """
    weight = getattr(module, self.name + '_orig')
    u = getattr(module, self.name + '_u')

    # View the weight as a 2-D matrix whose rows are indexed by `self.dim`.
    weight_mat = weight
    if self.dim != 0:
        # Permute `self.dim` to the front, keeping the other dims in order.
        other_dims = [d for d in range(weight_mat.dim()) if d != self.dim]
        weight_mat = weight_mat.permute(self.dim, *other_dims)
    height = weight_mat.size(0)
    weight_mat = weight_mat.reshape(height, -1)

    v = None
    with torch.no_grad():
        for _ in range(self.n_power_iterations):
            # Spectral norm of weight equals `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
            v = normalize(torch.matmul(weight_mat.t(), u),
                          dim=0, eps=self.eps)
            u = normalize(torch.matmul(weight_mat, v),
                          dim=0, eps=self.eps)
        if v is None:
            # No power iteration ran (n_power_iterations <= 0): derive `v`
            # from the stored `u` so sigma can still be estimated, and
            # leave `u` untouched.
            v = normalize(torch.matmul(weight_mat.t(), u),
                          dim=0, eps=self.eps)

    # NOTE(review): sigma is computed outside the no_grad block so that
    # autograd can track it through `weight_mat`; the original indentation
    # was lost, confirm this matches the intended behavior.
    sigma = torch.dot(u, torch.matmul(weight_mat, v))
    # Rescale so the estimated spectral norm of the result is `self.norm`.
    weight = weight * (self.norm / sigma)
    return weight, u