in segmentation/model/cnsn_resnet.py [0:0]
def __init__(self, inplanes, planes, pos, beta, crop, cnsn_type, stride=1, downsample=None, groups=1,
             base_width=64, dilation=1, norm_layer=None):
    """Build a ResNet basic block that carries a CrossNorm/SelfNorm (CNSN) module.

    Args:
        inplanes: number of input channels (fed to ``conv1``).
        planes: number of output channels of ``conv1`` / ``conv2``.
        pos: where the CNSN module is used; must be one of
            'residual', 'identity', 'pre', 'post', or None.
        beta: forwarded to ``CrossNorm``.
        crop: forwarded to ``CrossNorm``.
        cnsn_type: 'cn' (CrossNorm only), 'sn' (SelfNorm only) or
            'cnsn' (both).
        stride: stride of ``conv1``.
        downsample: optional module stored as ``self.downsample``.
        groups: only 1 is supported.
        base_width: only 64 is supported.
        dilation: only values <= 1 are supported.
        norm_layer: normalization layer factory; defaults to
            ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
        AssertionError: if ``cnsn_type`` or a non-None ``pos`` is not one
            of the accepted values.
    """
    super(BasicBlockCustom, self).__init__()
    if norm_layer is None:
        norm_layer = nn.BatchNorm2d
    if groups != 1 or base_width != 64:
        raise ValueError('BasicBlock only supports groups=1 and base_width=64')
    if dilation > 1:
        raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
    # Both self.conv1 and self.downsample layers downsample the input when stride != 1
    self.conv1 = conv3x3(inplanes, planes, stride)
    self.bn1 = norm_layer(planes)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = norm_layer(planes)
    self.downsample = downsample
    self.stride = stride

    assert cnsn_type in ['sn', 'cn', 'cnsn']
    # Substring tests: 'cnsn' contains both 'cn' and 'sn', so it enables
    # both modules; 'sn' does not contain 'cn' and vice versa.
    if 'cn' in cnsn_type:
        crossnorm = CrossNorm(beta=beta, crop=crop)
    else:
        crossnorm = None

    if 'sn' in cnsn_type:
        print('using SelfNorm module')
        # The channel-equality flag is derived from this block's own
        # arguments. The previous code referenced `in_planes`, `out_planes`
        # and `self.is_in_equal_out`, none of which exist here, so this
        # branch raised NameError/AttributeError.
        is_in_equal_out = inplanes == planes
        if pos == 'pre' and not is_in_equal_out:
            # 'pre' with a channel change: size SelfNorm by the input width.
            selfnorm = SelfNorm(inplanes)
        else:
            selfnorm = SelfNorm(planes)
    else:
        selfnorm = None

    self.cnsn = CNSN(selfnorm=selfnorm, crossnorm=crossnorm)

    self.pos = pos
    if pos is not None:
        print('{} in residual module: {}'.format(cnsn_type, pos))
        assert pos in ['residual', 'identity', 'pre', 'post']