in models/cifar/densenet_cnsn.py [0:0]
def __init__(self, growth_rate, depth, reduction, n_classes, bottleneck,
             active_num=None, pos=None, beta=None, crop=None, cnsn_type=None):
    """Build a three-stage DenseNet with optional CrossNorm/SelfNorm layers.

    The network is: a 3x3 stem convolution, three dense stages built by
    ``_make_dense_cnsn`` separated by two ``Transition`` layers, then a final
    ``BatchNorm2d`` (``bn1``) and a linear classifier (``fc``).

    Args:
        growth_rate: channels added by each dense layer; the stem outputs
            ``2 * growth_rate`` channels.
        depth: network depth. Each dense stage gets ``(depth - 4) / 6``
            layers when ``bottleneck`` is true, otherwise ``(depth - 4) / 3``
            (truncated to int).
        reduction: factor applied (then floored) to the channel count by each
            transition layer.
        n_classes: number of outputs of the final linear layer.
        bottleneck: forwarded to ``_make_dense_cnsn`` and selects the
            per-stage layer count above.
        active_num: stored on ``self.active_num`` when CrossNorm is enabled;
            it must then be a positive number.
        pos, beta, crop, cnsn_type: forwarded unchanged to every
            ``_make_dense_cnsn`` call. ``beta`` and ``crop`` are also printed
            when given. ``cnsn_type`` may be ``None``. When it contains
            ``'cn'``, the built ``CrossNorm`` modules are counted into
            ``self.cn_num`` and ``active_num`` is validated.

    Raises:
        AssertionError: if ``'cn' in cnsn_type`` but no ``CrossNorm`` module
            was built, or ``active_num`` is missing or not positive.
    """
    super(DenseNet, self).__init__()
    norm_func = nn.BatchNorm2d
    if beta is not None:
        print('beta: {}'.format(beta))
    if crop is not None:
        print('crop mode: {}'.format(crop))

    # A bottleneck layer presumably holds two convolutions instead of one,
    # hence the doubled divisor.
    if bottleneck:
        n_dense_blocks = int((depth - 4) / 6)
    else:
        n_dense_blocks = int((depth - 4) / 3)

    n_channels = 2 * growth_rate
    self.conv1 = nn.Conv2d(3, n_channels, kernel_size=3, padding=1, bias=False)

    # NOTE: attribute names and creation order below determine state_dict
    # keys/order; keep them stable so existing checkpoints still load.
    # 1st block
    self.dense1 = self._make_dense_cnsn(n_channels, growth_rate, n_dense_blocks,
                                        bottleneck, norm_func, pos=pos,
                                        beta=beta, crop=crop, cnsn_type=cnsn_type)
    n_channels += n_dense_blocks * growth_rate
    n_out_channels = int(math.floor(n_channels * reduction))
    self.trans1 = Transition(n_channels, n_out_channels, norm_func)
    n_channels = n_out_channels

    # 2nd block
    self.dense2 = self._make_dense_cnsn(n_channels, growth_rate, n_dense_blocks,
                                        bottleneck, norm_func, pos=pos,
                                        beta=beta, crop=crop, cnsn_type=cnsn_type)
    n_channels += n_dense_blocks * growth_rate
    n_out_channels = int(math.floor(n_channels * reduction))
    self.trans2 = Transition(n_channels, n_out_channels, norm_func)
    n_channels = n_out_channels

    # 3rd block (no transition afterwards)
    self.dense3 = self._make_dense_cnsn(n_channels, growth_rate, n_dense_blocks,
                                        bottleneck, norm_func, pos=pos,
                                        beta=beta, crop=crop, cnsn_type=cnsn_type)
    n_channels += n_dense_blocks * growth_rate

    self.bn1 = nn.BatchNorm2d(n_channels)
    self.fc = nn.Linear(n_channels, n_classes)

    # Plain list (not nn.ModuleList): the CrossNorm modules are already
    # registered as descendants of the dense blocks, so this is only a lookup
    # and must not register them a second time.
    self.cn_modules = []
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            # He/Kaiming normal init with fan-out = kernel area * out_channels.
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.bias.data.zero_()
        elif isinstance(m, CrossNorm):
            self.cn_modules.append(m)

    # cnsn_type defaults to None; guard it so that `'cn' in None` does not
    # raise TypeError when CrossNorm/SelfNorm is disabled.
    if cnsn_type is not None and 'cn' in cnsn_type:
        self.cn_num = len(self.cn_modules)
        assert self.cn_num > 0, \
            'cnsn_type requests CrossNorm but no CrossNorm module was built'
        print('cn_num: {}'.format(self.cn_num))
        self.active_num = active_num
        # Check for None explicitly: `None > 0` raises TypeError in Python 3.
        assert self.active_num is not None and self.active_num > 0, \
            'active_num must be a positive number when CrossNorm is enabled'
        print('active_num: {}'.format(self.active_num))