in ppuda/utils/utils.py [0:0]
def adjust_net(net, large_input=False):
    """
    Adjusts the first layers of the network so that small images (32x32) can be processed.

    For small inputs, the stem of the supported architectures is modified in place
    (first conv shrunk and/or set to stride 1, early max-pooling replaced by nn.Identity)
    so that the spatial resolution is not reduced too aggressively. Unsupported
    architectures are left unchanged and a warning is printed.

    :param net: neural network (modified in place)
    :param large_input: True if the input images are large (224x224 or more).
    :return: the adjusted network (the same object as `net`); `net.expected_image_sz`
             is set to 224 if large_input else 32
    """
    net.expected_image_sz = 224 if large_input else 32
    if large_input:
        return net

    def adjust_first_conv(conv1, ks=(3, 3), stride=1):
        # Crop the first (RGB) conv to a `ks` kernel keeping the center of its filters,
        # and set padding to ks // 2 and the given stride.
        assert conv1.in_channels == 3, conv1
        ks_org = conv1.weight.data.shape[2:]
        if ks_org[0] > ks[0] or ks_org[1] > ks[1]:
            # use the center of the filters; explicit end indices avoid the empty slice
            # that `offset:-offset` would produce when the offset is 0
            offset = ((ks_org[0] - ks[0]) // 2, (ks_org[1] - ks[1]) // 2)
            # .contiguous() so the parameter does not keep a view onto the larger original kernel
            conv1.weight.data = conv1.weight.data[:, :,
                                                  offset[0]:offset[0] + ks[0],
                                                  offset[1]:offset[1] + ks[1]].contiguous()
        assert conv1.weight.data.shape[2:] == ks, (conv1.weight.data.shape, ks)
        conv1.kernel_size = ks
        conv1.padding = (ks[0] // 2, ks[1] // 2)
        conv1.stride = (stride, stride)

    def set_stem_stride(stem, stride=1):
        # The stem may be a bare nn.Conv2d or a container wrapping one (in torchvision,
        # ShuffleNetV2.conv1 is an nn.Sequential and GoogLeNet.conv1 is a BasicConv2d).
        # Setting `.stride` on a container has no effect, so set it on the first Conv2d inside.
        conv = next((m for m in stem.modules() if isinstance(m, nn.Conv2d)), None)
        assert conv is not None, type(stem)
        conv.stride = (stride, stride)

    if isinstance(net, ResNet):
        # 3x3 stride-1 first conv, and no max-pooling after it
        adjust_first_conv(net.conv1)
        assert hasattr(net, 'maxpool'), type(net)
        net.maxpool = nn.Identity()
    elif isinstance(net, DenseNet):
        # features[0] is the first conv; features[3] is expected to be the first max-pool
        adjust_first_conv(net.features[0])
        assert isinstance(net.features[3], nn.MaxPool2d), (net.features[3], type(net))
        net.features[3] = nn.Identity()
    elif isinstance(net, (MobileNetV2, MobileNetV3)):  # requires torchvision 0.9+
        def reduce_stride(m):
            if isinstance(m, nn.Conv2d):
                m.stride = 1

        # stride 1 for every conv in the first 5 feature blocks
        for m in net.features[:5]:
            m.apply(reduce_stride)
    elif isinstance(net, VGG):
        # remove the max-pooling layers among the first 10 feature layers
        for layer, mod in enumerate(net.features[:10]):
            if isinstance(mod, nn.MaxPool2d):
                net.features[layer] = nn.Identity()
    elif isinstance(net, AlexNet):
        # stride 1 for the first conv; features[2] is presumably the first max-pool (as in torchvision)
        net.features[0].stride = 1
        net.features[2] = nn.Identity()
    elif isinstance(net, MNASNet):
        net.layers[0].stride = 1
    elif isinstance(net, ShuffleNetV2):
        set_stem_stride(net.conv1)
        net.maxpool = nn.Identity()
    elif isinstance(net, GoogLeNet):
        set_stem_stride(net.conv1)
        net.maxpool1 = nn.Identity()
    else:
        print('WARNING: the network (%s) is not adapted for small inputs which may result in lower performance' % str(
            type(net)))
    return net