models/cifar/allconv_cnsn.py [135:158]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        m.bias.data.zero_()
      elif isinstance(m, CrossNorm):
          self.cn_modules.append(m)

    if 'cn' in cnsn_type:
        self.cn_num = len(self.cn_modules)
        assert self.cn_num > 0
        print('cn_num: {}'.format(self.cn_num))
        self.active_num = active_num
        assert self.active_num > 0
        print('active_num: {}'.format(self.active_num))

  def _enable_cross_norm(self):
      active_cn_idxs = np.random.choice(self.cn_num, self.active_num, replace=False).tolist()
      assert len(set(active_cn_idxs)) == self.active_num
      # print('active_cn_idxs: {}'.format(active_cn_idxs))
      for idx in active_cn_idxs:
          self.cn_modules[idx].active = True

  def forward(self, x, aug=False):
    if aug:
      # print('forward cross norm...')
      # exit()
      self._enable_cross_norm()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



models/cifar/wideresnet_cnsn.py [187:208]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        m.bias.data.zero_()
      elif isinstance(m, CrossNorm):
          self.cn_modules.append(m)

    if 'cn' in cnsn_type:
        self.cn_num = len(self.cn_modules)
        assert self.cn_num > 0
        print('cn_num: {}'.format(self.cn_num))
        self.active_num = active_num
        assert self.active_num > 0
        print('active_num: {}'.format(self.active_num))

  def _enable_cross_norm(self):
      active_cn_idxs = np.random.choice(self.cn_num, self.active_num, replace=False).tolist()
      assert len(set(active_cn_idxs)) == self.active_num
      for idx in active_cn_idxs:
          self.cn_modules[idx].active = True

  def forward(self, x, aug=False):

      if aug:
          self._enable_cross_norm()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



