in imnet_evaluate/pnasnet.py [0:0]
def __init__(self, in_channels_left, out_channels_left, in_channels_right,
             out_channels_right, is_reduction=False, zero_pad=False,
             match_prev_layer_dimensions=False):
    """Build the input projections and the five branch pairs of a cell.

    Args:
        in_channels_left: Channel count of the left input; projected to
            ``out_channels_left`` by ``conv_prev_1x1``.
        out_channels_left: Channel width of the left-input branches
            (``comb_iter_0_left`` and ``comb_iter_4_left``).
        in_channels_right: Channel count of the right input; projected to
            ``out_channels_right`` by ``conv_1x1`` (a ``ReluConvBn`` with
            ``kernel_size=1``).
        out_channels_right: Channel width of every other branch.
        is_reduction: When true, the branches are built with stride 2 and
            an extra ``comb_iter_4_right`` (``ReluConvBn``, kernel 1,
            stride 2) is created; otherwise stride 1 is used and
            ``comb_iter_4_right`` is ``None``.
        zero_pad: Forwarded to every ``BranchSeparables`` / ``MaxPool``
            branch except ``comb_iter_3_left``.
        match_prev_layer_dimensions: When true, the left input is projected
            with ``FactorizedReduction`` instead of a kernel-1
            ``ReluConvBn``; the flag is also stored on ``self``.

    NOTE(review): attribute names and the order in which they are assigned
    are kept exactly as-is. If this class is a torch ``nn.Module``, they
    become ``state_dict`` keys, which pretrained checkpoints are presumably
    matched against — confirm before renaming anything.
    """
    super(Cell, self).__init__()

    # Reduction cells use stride 2 in their convolution and pooling branches
    # so the cell output is roughly half the spatial size of its input.
    if is_reduction:
        stride = 2
    else:
        stride = 1
    # Keyword arguments shared by every strided BranchSeparables / MaxPool.
    strided = {'stride': stride, 'zero_pad': zero_pad}
    left_ch = out_channels_left
    right_ch = out_channels_right

    # Left-input projection. FactorizedReduction roughly halves the spatial
    # size so the previous layer's output matches this cell's resolution.
    self.match_prev_layer_dimensions = match_prev_layer_dimensions
    self.conv_prev_1x1 = (
        FactorizedReduction(in_channels_left, left_ch)
        if match_prev_layer_dimensions
        else ReluConvBn(in_channels_left, left_ch, kernel_size=1))
    # Right-input projection.
    self.conv_1x1 = ReluConvBn(in_channels_right, right_ch, kernel_size=1)

    # Combination 0: kernel-5 BranchSeparables at the left width / MaxPool.
    self.comb_iter_0_left = BranchSeparables(left_ch, left_ch,
                                             kernel_size=5, **strided)
    self.comb_iter_0_right = MaxPool(3, **strided)

    # Combination 1: kernel-7 BranchSeparables / MaxPool.
    self.comb_iter_1_left = BranchSeparables(right_ch, right_ch,
                                             kernel_size=7, **strided)
    self.comb_iter_1_right = MaxPool(3, **strided)

    # Combination 2: kernel-5 and kernel-3 BranchSeparables.
    self.comb_iter_2_left = BranchSeparables(right_ch, right_ch,
                                             kernel_size=5, **strided)
    self.comb_iter_2_right = BranchSeparables(right_ch, right_ch,
                                              kernel_size=3, **strided)

    # Combination 3: the left branch is given neither `stride` nor
    # `zero_pad`, so it runs with BranchSeparables' own defaults.
    # NOTE(review): presumably forward() feeds it an already-reduced
    # tensor — confirm.
    self.comb_iter_3_left = BranchSeparables(right_ch, right_ch,
                                             kernel_size=3)
    self.comb_iter_3_right = MaxPool(3, **strided)

    # Combination 4: kernel-3 BranchSeparables at the left width. The right
    # side only gets a module in reduction cells; otherwise it is None.
    self.comb_iter_4_left = BranchSeparables(left_ch, left_ch,
                                             kernel_size=3, **strided)
    self.comb_iter_4_right = (
        ReluConvBn(right_ch, right_ch, kernel_size=1, stride=stride)
        if is_reduction
        else None)