def __init__()

in models/cifar/resnext_cnsn.py [0:0]


  def __init__(self, depth, cardinality, base_width, num_classes,
               active_num=None, pos=None, beta=None,
               crop=None, cnsn_type=None):
    """Build a CIFAR ResNeXt whose bottleneck stages may contain CNSN modules.

    Layout:
      * stem: 3x3 conv (3 -> 64 channels, stride 1, padding 1, no bias)
        followed by BatchNorm2d;
      * three stages built by ``_make_layer_cnsn`` with
        ``ResNeXtBottleneckCustom`` blocks at 64 / 128 / 256 planes and
        strides 1 / 2 / 2, each holding ``(depth - 2) // 9`` blocks;
      * ``AvgPool2d(8)`` and a linear classifier over
        ``256 * ResNeXtBottleneckCustom.expansion`` features.

    Args:
      depth: total depth; ``(depth - 2)`` must be divisible by 9.
      cardinality: stored as ``self.cardinality`` (presumably read by the
        block builder -- confirm in ``_make_layer_cnsn``).
      base_width: stored as ``self.base_width`` (same caveat as above).
      num_classes: output size of the classifier.
      active_num: stored as ``self.active_num`` only when CrossNorm is
        enabled. It must then be a positive number. NOTE(review): presumably
        the number of CrossNorm modules active at a time; its use is outside
        this method.
      pos, beta, crop, cnsn_type: forwarded unchanged to every
        ``_make_layer_cnsn`` call. ``beta`` and ``crop`` are also printed
        when not None.
      cnsn_type: if it contains ``'cn'``, CrossNorm bookkeeping is enabled
        (``self.cn_num`` and ``self.active_num`` are set). ``None`` disables
        it.

    Raises:
      AssertionError: if ``depth`` is invalid, or if CrossNorm is enabled but
        no ``CrossNorm`` module was built or ``active_num`` is not positive.
    """
    super(CifarResNeXt, self).__init__()

    norm_func = nn.BatchNorm2d

    if beta is not None:
        print('beta: {}'.format(beta))

    if crop is not None:
        print('crop mode: {}'.format(crop))

    # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model;
    # each of the three stages receives (depth - 2) // 9 bottleneck blocks.
    assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101'
    layer_blocks = (depth - 2) // 9

    self.cardinality = cardinality
    self.base_width = base_width
    self.num_classes = num_classes

    self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
    self.bn_1 = norm_func(64)

    # Input width of the first stage; presumably advanced by
    # _make_layer_cnsn as stages are built -- confirm there.
    self.inplanes = 64

    # 1st block
    self.stage_1 = self._make_layer_cnsn(ResNeXtBottleneckCustom, 64, layer_blocks, norm_func,
                                         pos=pos, beta=beta, crop=crop,
                                         cnsn_type=cnsn_type, stride=1)

    # 2nd block
    self.stage_2 = self._make_layer_cnsn(ResNeXtBottleneckCustom, 128, layer_blocks, norm_func,
                                         pos=pos, beta=beta, crop=crop,
                                         cnsn_type=cnsn_type, stride=2)

    # 3rd block
    self.stage_3 = self._make_layer_cnsn(ResNeXtBottleneckCustom, 256, layer_blocks, norm_func,
                                         pos=pos, beta=beta, crop=crop,
                                         cnsn_type=cnsn_type, stride=2)

    # NOTE(review): kernel 8 presumably assumes 32x32 inputs (8x8 maps after
    # the two stride-2 stages) -- confirm for other input sizes.
    self.avgpool = nn.AvgPool2d(8)
    self.classifier = nn.Linear(256 * ResNeXtBottleneckCustom.expansion, num_classes)

    # Initialise weights and, in the same pass, collect every CrossNorm
    # module (in self.modules() order) for later bookkeeping.
    self.cn_modules = []
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            # He-style normal init with fan computed from kernel area * out_channels.
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            init.kaiming_normal_(m.weight)
            m.bias.data.zero_()
        elif isinstance(m, CrossNorm):
            self.cn_modules.append(m)

    # cnsn_type defaults to None: test for None before the membership check,
    # otherwise `'cn' in None` raises TypeError for the plain model.
    if cnsn_type is not None and 'cn' in cnsn_type:
        self.cn_num = len(self.cn_modules)
        assert self.cn_num > 0, \
            'cnsn_type {!r} requires at least one CrossNorm module'.format(cnsn_type)
        print('cn_num: {}'.format(self.cn_num))
        self.active_num = active_num
        # Check for None explicitly so a missing active_num fails with a
        # clear assertion rather than a TypeError from `None > 0`.
        assert self.active_num is not None and self.active_num > 0, \
            'active_num must be a positive number when CrossNorm is enabled'
        print('active_num: {}'.format(self.active_num))