in uimnet/algorithms/base.py [0:0]
def get_l2_reg(self):
    """Return the summed squared L2 norm of the weights of all networks.

    Walks every module of every network in ``self.networks`` and adds the sum
    of squared entries of its ``weight`` tensor. Batch-norm layers are skipped,
    so their affine scale parameters are not regularized. Biases are never
    included because only the ``weight`` attribute is read.

    ``self.hparams["weight_decay"]`` acts only as an on/off switch here: the
    returned value is NOT multiplied by it. (Presumably the caller applies the
    coefficient; confirm at the call site.)

    Returns:
        ``0.`` (a Python float) when weight decay is disabled (``<= 0``) or no
        module carries a weight. Otherwise, a scalar tensor that stays attached
        to the autograd graph.
    """
    l2_reg = 0.
    if self.hparams["weight_decay"] <= 0.:
        return l2_reg
    for net in self.networks.values():
        for module in net.modules():
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                continue
            weight = getattr(module, 'weight', None)
            if weight is None:
                continue
            # The sum of squares equals the squared Frobenius norm. It avoids
            # the deprecated ``torch.norm`` and its sqrt/square round trip, and
            # its gradient (2 * weight) is well-defined even at zero.
            l2_reg = l2_reg + weight.pow(2).sum()
    return l2_reg