def get_l2_reg()

in uimnet/algorithms/base.py [0:0]


  def get_l2_reg(self):
    """Return the summed squared L2 norm of the networks' weights.

    For every network in ``self.networks`` and every module yielded by its
    ``modules()``, the squared L2 (Frobenius) norm of the module's ``weight``
    attribute is accumulated. BatchNorm layers (subclasses of
    ``torch.nn.modules.batchnorm._BatchNorm``) are skipped, and only
    ``weight`` is read, so biases never contribute.

    ``self.hparams["weight_decay"]`` is used only as an on/off switch: when it
    is not strictly positive nothing is computed. The coefficient itself is
    not multiplied in here.

    Returns:
      ``0.`` (a Python float) when weight decay is disabled or no module has
      a non-None ``weight``; otherwise a scalar tensor connected to the
      autograd graph of the weights.
    """
    l2_reg = 0.
    # `not (x > 0.)` rather than `x <= 0.` so that NaN also disables the term,
    # matching the original `if x > 0.:` gate.
    if not self.hparams["weight_decay"] > 0.:
      return l2_reg
    for net in self.networks.values():
      for module in net.modules():
        # BatchNorm scale parameters are deliberately exempt from the penalty.
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
          continue
        weight = getattr(module, 'weight', None)
        if weight is None:
          continue
        # Sum of squares == squared Frobenius norm; avoids the deprecated
        # Tensor.norm and its sqrt-then-square round trip.
        l2_reg = l2_reg + weight.pow(2).sum()
    return l2_reg