in tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py [0:0]
def _check_layer_support(self, layer):
"""Returns whether the layer is supported or not.
Mimics XNNPACK's behaviour of compatibility function.
Args:
layer: Current layer in the model.
Returns:
True if the layer is supported, False otherwise.
References:
- https://github.com/google/XNNPACK/blob/master/src/subgraph.c#L130
"""
if isinstance(layer,
(layers.Add, layers.Multiply, layers.ZeroPadding2D,
layers.ReLU, layers.LeakyReLU, layers.ELU, layers.Dropout)):
return True
elif isinstance(layer, layers.DepthwiseConv2D):
    # Supported depthwise configurations (depth multiplier 1, no dilation):
    #   - 3x3 kernel, `SAME` padding, stride 1.
    #   - 3x3 kernel, `VALID` padding, stride 1 or 2, preceded by a
    #     `ZeroPadding2D` layer with padding 1 on each side.
    #   - 5x5 kernel, `SAME` padding, stride 1.
    #   - 5x5 kernel, `VALID` padding, stride 1 or 2, preceded by a
    #     `ZeroPadding2D` layer with padding 2 on each side.
padding = layer.padding.lower()
producers = list(self._get_producers(layer))
zero_padding = (
producers[0] if len(producers) == 1 and
isinstance(producers[0], layers.ZeroPadding2D) else None)
supported_case_1 = (
layer.kernel_size == (3, 3) and layer.strides == (1, 1) and
padding == 'same')
supported_case_2 = (
layer.kernel_size == (3, 3) and
(layer.strides == (1, 1) or layer.strides == (2, 2)) and
padding == 'valid' and zero_padding and
zero_padding.padding == ((1, 1), (1, 1)))
supported_case_3 = (
layer.kernel_size == (5, 5) and layer.strides == (1, 1) and
padding == 'same')
supported_case_4 = (
layer.kernel_size == (5, 5) and
(layer.strides == (1, 1) or layer.strides == (2, 2)) and
padding == 'valid' and zero_padding and
zero_padding.padding == ((2, 2), (2, 2)))
supported = (
layer.depth_multiplier == 1 and layer.dilation_rate == (1, 1) and
(supported_case_1 or supported_case_2 or supported_case_3 or
supported_case_4))
return supported
elif isinstance(layer, layers.Conv2D):
    # 1x1 convolution with stride 1, no dilation and a single group (padding
    # is irrelevant for a 1x1 kernel).
return (layer.groups == 1 and layer.dilation_rate == (1, 1) and
layer.kernel_size == (1, 1) and layer.strides == (1, 1))
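  # Global average pooling is supported only when the spatial dimensions are
  # kept (`keepdims=True`).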
elif isinstance(layer, layers.GlobalAveragePooling2D):
return layer.keepdims
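  # Only channels-last batch normalization (axis 3 of an NHWC tensor) is
  # supported.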
elif isinstance(layer, layers.BatchNormalization):
return list(layer.axis) == [3]
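  # Only bilinear upsampling is supported.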
elif isinstance(layer, layers.UpSampling2D):
return layer.interpolation == 'bilinear'
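  # Standalone `Activation` layers are supported only for the activation
  # functions listed below.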
elif isinstance(layer, layers.Activation):
return activations.serialize(layer.activation) in ('relu', 'relu6',
'leaky_relu', 'elu',
'sigmoid')
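  # `TFOpLambda` layers are created when raw TF ops (e.g. `tf.math.add`) are
  # applied to Keras tensors; only a few element-wise ops are supported.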
elif layer.__class__.__name__ == 'TFOpLambda':
return layer.function in (tf.identity, tf.__operators__.add, tf.math.add,
tf.math.subtract, tf.math.multiply)
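  # Recurse into pruning wrappers and check the wrapped layer instead.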
elif isinstance(layer, pruning_wrapper.PruneLowMagnitude):
return self._check_layer_support(layer.layer)
return False
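
As a usage note, the sketch below shows how this check is typically exercised through the public pruning API rather than called directly. The toy model is hypothetical and chosen so that every layer passes `_check_layer_support` (a 1x1 `Conv2D`, channels-last `BatchNormalization`, `ReLU`, and a `GlobalAveragePooling2D` with `keepdims=True`); it assumes the `PruneForLatencyOnXNNPack` policy defined in this file and exported as `tfmot.sparsity.keras.PruneForLatencyOnXNNPack`.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Hypothetical toy model built only from layer configurations that the
# support check above accepts (requires TF >= 2.6 for `keepdims`).
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(16, kernel_size=1)(inputs)  # 1x1, stride 1, one group
x = tf.keras.layers.BatchNormalization()(x)  # channels-last, axis 3
x = tf.keras.layers.ReLU()(x)
outputs = tf.keras.layers.GlobalAveragePooling2D(keepdims=True)(x)
model = tf.keras.Model(inputs, outputs)

# The policy validates the model against the support check and restricts
# pruning to layers that XNNPACK can accelerate with sparse inference.
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_policy=tfmot.sparsity.keras.PruneForLatencyOnXNNPack())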