def apply()

in tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_layout_transform.py [0:0]


  def apply(self, model, layer_quantize_map):
    """Implement default n-bit transforms.

    Every transform is constructed with this object's `_num_bits_weight` and
    `_num_bits_activation` values, so the same pipeline serves whatever
    weight/activation bit widths were configured (not only 8-bit).

    Currently this means the following.
      1. Pull activations into layers, and fuse activations. (TODO)
      2. Modify range in incoming layers for Concat. (TODO)
      3. Fuse Conv2D/DepthwiseConv2D + BN into single layer.

    Args:
      model: Keras model to be quantized.
      layer_quantize_map: Map with keys as layer names, and values as dicts
        containing custom `QuantizeConfig`s which may have been passed with
        layers.

    Returns:
      (Transformed Keras model to better match TensorFlow Lite backend, updated
      layer quantize map.)
    """
    # NOTE(review): the order below is kept exactly as before. It places the
    # more specific patterns (e.g. the ReLU / Reshape Conv2D+BN variants, the
    # 6/5/4/3-input Concat transforms) ahead of the more general ones. This
    # presumably matters to how ModelTransformer applies them, so confirm
    # before reordering.
    transform_classes = (
        default_n_bit_transforms.InputLayerQuantize,
        default_n_bit_transforms.SeparableConv1DQuantize,
        default_n_bit_transforms.SeparableConvQuantize,
        default_n_bit_transforms.Conv2DReshapeBatchNormReLUQuantize,
        default_n_bit_transforms.Conv2DReshapeBatchNormActivationQuantize,
        default_n_bit_transforms.Conv2DBatchNormReLUQuantize,
        default_n_bit_transforms.Conv2DBatchNormActivationQuantize,
        default_n_bit_transforms.Conv2DReshapeBatchNormQuantize,
        default_n_bit_transforms.Conv2DBatchNormQuantize,
        default_n_bit_transforms.ConcatTransform6Inputs,
        default_n_bit_transforms.ConcatTransform5Inputs,
        default_n_bit_transforms.ConcatTransform4Inputs,
        default_n_bit_transforms.ConcatTransform3Inputs,
        default_n_bit_transforms.ConcatTransform,
        default_n_bit_transforms.LayerReLUQuantize,
        default_n_bit_transforms.LayerReluActivationQuantize,
    )
    # Every transform receives the same bit-width configuration. Building it
    # once keeps all transforms in sync when the kwargs change.
    bit_width_kwargs = {
        'num_bits_weight': self._num_bits_weight,
        'num_bits_activation': self._num_bits_activation,
    }
    transforms = [
        transform_cls(**bit_width_kwargs) for transform_cls in transform_classes
    ]
    return model_transformer.ModelTransformer(
        model, transforms,
        set(layer_quantize_map.keys()), layer_quantize_map).transform()