def __init__()

in models/cifar/densenet_cnsn.py [0:0]


  def __init__(self, growth_rate, depth, reduction, n_classes, bottleneck,
               active_num=None, pos=None, beta=None, crop=None, cnsn_type=None):
    """Build a CIFAR DenseNet-BC backbone with optional CrossNorm/SelfNorm (CNSN).

    Args:
      growth_rate: channels each dense layer adds (k in the DenseNet paper);
        the stem produces 2 * growth_rate channels.
      depth: total depth; dense layers per block is (depth - 4) // 6 with
        bottleneck layers, else (depth - 4) // 3.
      reduction: channel compression factor applied by the two transition
        layers (out = floor(in * reduction)).
      n_classes: output size of the final linear classifier.
      bottleneck: if True, dense blocks use the bottleneck layer variant.
      active_num: number of CrossNorm modules to keep active; required
        (> 0) when cnsn_type contains 'cn', ignored otherwise.
      pos, beta, crop, cnsn_type: CNSN configuration forwarded to
        _make_dense_cnsn; cnsn_type selects which CNSN variants are built
        (None disables CrossNorm bookkeeping entirely).
    """
    super(DenseNet, self).__init__()

    norm_func = nn.BatchNorm2d

    if beta is not None:
      print('beta: {}'.format(beta))

    if crop is not None:
      print('crop mode: {}'.format(crop))

    # Bottleneck layers contain two convs each, hence the divisor of 6 vs 3.
    if bottleneck:
      n_dense_blocks = (depth - 4) // 6
    else:
      n_dense_blocks = (depth - 4) // 3

    n_channels = 2 * growth_rate
    self.conv1 = nn.Conv2d(3, n_channels, kernel_size=3, padding=1, bias=False)

    # 1st dense block + transition (compresses channels by `reduction`).
    self.dense1 = self._make_dense_cnsn(n_channels, growth_rate, n_dense_blocks,
                                        bottleneck, norm_func, pos=pos,
                                        beta=beta, crop=crop, cnsn_type=cnsn_type)

    n_channels += n_dense_blocks * growth_rate
    n_out_channels = int(math.floor(n_channels * reduction))
    self.trans1 = Transition(n_channels, n_out_channels, norm_func)

    n_channels = n_out_channels
    # 2nd dense block + transition.
    self.dense2 = self._make_dense_cnsn(n_channels, growth_rate, n_dense_blocks,
                                        bottleneck, norm_func, pos=pos,
                                        beta=beta, crop=crop, cnsn_type=cnsn_type)

    n_channels += n_dense_blocks * growth_rate
    n_out_channels = int(math.floor(n_channels * reduction))
    self.trans2 = Transition(n_channels, n_out_channels, norm_func)

    n_channels = n_out_channels
    # 3rd dense block (no transition; followed by the classifier head).
    self.dense3 = self._make_dense_cnsn(n_channels, growth_rate, n_dense_blocks,
                                        bottleneck, norm_func, pos=pos,
                                        beta=beta, crop=crop, cnsn_type=cnsn_type)

    n_channels += n_dense_blocks * growth_rate

    self.bn1 = nn.BatchNorm2d(n_channels)
    self.fc = nn.Linear(n_channels, n_classes)

    # He-style init for convs, unit/zero for BN, zero bias for the classifier
    # (linear weight keeps the PyTorch default, matching the reference impl).
    # CrossNorm modules are collected so a subset can be toggled active later.
    self.cn_modules = []
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
      elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
      elif isinstance(m, nn.Linear):
        m.bias.data.zero_()
      elif isinstance(m, CrossNorm):
        self.cn_modules.append(m)

    # Guard against cnsn_type=None: the original `'cn' in cnsn_type` raised
    # TypeError whenever the default was used.
    if cnsn_type is not None and 'cn' in cnsn_type:
      self.cn_num = len(self.cn_modules)
      assert self.cn_num > 0
      print('cn_num: {}'.format(self.cn_num))
      # Checking `is not None` first gives a clear AssertionError instead of
      # a TypeError when active_num was not supplied.
      assert active_num is not None and active_num > 0
      self.active_num = active_num
      print('active_num: {}'.format(self.active_num))