def _check_layer_support()

in tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py [0:0]


  def _check_layer_support(self, layer):
    """Returns whether the layer is supported or not.

    Mimics XNNPACK's layer compatibility checks (see References).

    Args:
      layer: Current layer in the model.

    Returns:
      True if the layer is supported, False otherwise. The result is always a
      real `bool`; it never leaks `None` or another falsy operand out of an
      `and`/`or` chain.

    References:
      - https://github.com/google/XNNPACK/blob/master/src/subgraph.c#L130
    """
    if isinstance(layer,
                  (layers.Add, layers.Multiply, layers.ZeroPadding2D,
                   layers.ReLU, layers.LeakyReLU, layers.ELU, layers.Dropout)):
      return True
    elif isinstance(layer, layers.DepthwiseConv2D):
      # NOTE: keep this branch ahead of the `Conv2D` one; in some Keras
      # versions `DepthwiseConv2D` subclasses `Conv2D`.
      #
      # Supported depthwise configurations (depth multiplier 1, no dilation):
      #   - 3x3 kernel, `SAME` padding, stride 1.
      #   - 3x3 kernel, `VALID` padding, stride 1 or 2, fed by a single
      #     `ZeroPadding2D` layer padding 1 on each side.
      #   - 5x5 kernel, `SAME` padding, stride 1.
      #   - 5x5 kernel, `VALID` padding, stride 1 or 2, fed by a single
      #     `ZeroPadding2D` layer padding 2 on each side.
      padding = layer.padding.lower()
      producers = list(self._get_producers(layer))
      zero_padding = (
          producers[0] if len(producers) == 1 and
          isinstance(producers[0], layers.ZeroPadding2D) else None)

      if layer.depth_multiplier != 1 or layer.dilation_rate != (1, 1):
        return False

      # Per-side amount the preceding ZeroPadding2D must add for the `VALID`
      # variants; also rejects every kernel size other than 3x3 and 5x5.
      if layer.kernel_size == (3, 3):
        pad = 1
      elif layer.kernel_size == (5, 5):
        pad = 2
      else:
        return False

      if padding == 'same':
        return layer.strides == (1, 1)
      if padding == 'valid':
        # Explicit `is not None` keeps the result a bool; relying on the
        # truthiness of `zero_padding` inside the `and` chain used to return
        # `None` when no ZeroPadding2D producer was present.
        return (layer.strides in ((1, 1), (2, 2)) and
                zero_padding is not None and
                zero_padding.padding == ((pad, pad), (pad, pad)))
      return False
    elif isinstance(layer, layers.Conv2D):
      # 1x1 convolution (no stride, no dilation, no padding, no groups).
      return (layer.groups == 1 and layer.dilation_rate == (1, 1) and
              layer.kernel_size == (1, 1) and layer.strides == (1, 1))
    elif isinstance(layer, layers.GlobalAveragePooling2D):
      return bool(layer.keepdims)
    elif isinstance(layer, layers.BatchNormalization):
      # Keras may hold `axis` as a plain int (e.g. before the layer is built)
      # rather than a list; `list(int)` would raise TypeError, so normalize.
      axis = layer.axis
      if isinstance(axis, int):
        axis = [axis]
      return list(axis) == [3]
    elif isinstance(layer, layers.UpSampling2D):
      return layer.interpolation == 'bilinear'
    elif isinstance(layer, layers.Activation):
      return activations.serialize(layer.activation) in ('relu', 'relu6',
                                                         'leaky_relu', 'elu',
                                                         'sigmoid')
    elif layer.__class__.__name__ == 'TFOpLambda':
      # TFOpLambda wraps a raw TF op; only these element-wise ops are accepted.
      return layer.function in (tf.identity, tf.__operators__.add, tf.math.add,
                                tf.math.subtract, tf.math.multiply)
    elif isinstance(layer, pruning_wrapper.PruneLowMagnitude):
      # A pruning wrapper is supported iff the layer it wraps is supported.
      return self._check_layer_support(layer.layer)
    return False