in models/cifar/resnext_cnsn.py [0:0]
def __init__(self,
             inplanes,
             planes,
             cardinality,
             base_width,
             norm_func,
             pos, beta, crop, cnsn_type,
             stride=1,
             downsample=None):
    """Build the layers of a ResNeXt bottleneck block plus its CNSN module.

    The block is a 1x1 reduce conv, a 3x3 grouped conv and a 1x1 expand
    conv, each followed by a layer produced by ``norm_func``.

    Args:
        inplanes: Number of input channels of the 1x1 reduce convolution.
        planes: Base channel count; the expand convolution outputs
            ``planes * 4`` channels.
        cardinality: Number of groups of the 3x3 convolution.
        base_width: Scales the per-group width; the inner width is
            ``floor(planes * base_width / 64) * cardinality``.
        norm_func: Callable that takes a channel count and returns the
            layer placed after each convolution.
        pos: One of ``'residual'``, ``'identity'``, ``'pre'``, ``'post'``.
            Stored as ``self.pos``; it also selects the SelfNorm channel
            count (``inplanes`` for ``'pre'``/``'identity'``, otherwise
            ``planes * 4``). How the position is applied is decided
            outside this method (presumably in ``forward``).
        beta: Forwarded to ``CrossNorm`` as ``beta``.
        crop: Forwarded to ``CrossNorm`` as ``crop``.
        cnsn_type: One of ``'sn'``, ``'cn'``, ``'cnsn'``; selects whether
            SelfNorm, CrossNorm or both are created.
        stride: Stride of the 3x3 grouped convolution.
        downsample: Stored unchanged as ``self.downsample``.

    Raises:
        AssertionError: If ``cnsn_type`` or ``pos`` is not one of the
            supported values.
    """
    super(ResNeXtBottleneckCustom, self).__init__()

    # Validate the configuration up front so an unsupported value fails
    # before any layer is allocated and before `pos` is used to size
    # SelfNorm. Lists (not sets) keep the check safe for unhashable input.
    assert cnsn_type in ['sn', 'cn', 'cnsn'], \
        'unsupported cnsn_type: {!r}'.format(cnsn_type)
    assert pos in ['residual', 'identity', 'pre', 'post'], \
        'unsupported pos: {!r}'.format(pos)

    dim = int(math.floor(planes * (base_width / 64.0)))
    width = dim * cardinality  # channels of the grouped 3x3 convolution
    out_planes = planes * 4    # channels produced by the expand convolution

    self.conv_reduce = nn.Conv2d(
        inplanes,
        width,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=False)
    self.bn_reduce = norm_func(width)

    self.conv_conv = nn.Conv2d(
        width,
        width,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=cardinality,
        bias=False)
    self.bn = norm_func(width)

    self.conv_expand = nn.Conv2d(
        width,
        out_planes,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=False)
    self.bn_expand = norm_func(out_planes)

    self.downsample = downsample

    # Substring tests: 'cnsn' enables both branches, 'cn' / 'sn' only one.
    if 'cn' in cnsn_type:
        print('using CrossNorm with crop: {}'.format(crop))
        crossnorm = CrossNorm(crop=crop, beta=beta)
    else:
        crossnorm = None

    if 'sn' in cnsn_type:
        print('using SelfNorm')
        # 'pre' and 'identity' use the block's input width; the other
        # positions use the expanded output width.
        if pos in ['pre', 'identity']:
            selfnorm = SelfNorm(inplanes)
        else:
            selfnorm = SelfNorm(out_planes)
    else:
        selfnorm = None

    self.cnsn = CNSN(crossnorm=crossnorm, selfnorm=selfnorm)
    self.pos = pos
    print('{} in residual module: {}'.format(cnsn_type, pos))