in tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_layout_transform.py
def apply(self, model, layer_quantize_map):
"""Implement default 8-bit transforms.
Currently this means the following.
1. Pull activations into layers, and apply fuse activations. (TODO)
2. Modify range in incoming layers for Concat. (TODO)
3. Fuse Conv2D/DepthwiseConv2D + BN into single layer.
Args:
model: Keras model to be quantized.
layer_quantize_map: Map with keys as layer names, and values as dicts
containing custom `QuantizeConfig`s which may have been passed with
layers.
Returns:
(Transformed Keras model to better match TensorFlow Lite backend, updated
layer quantize map.)
"""
  transforms = [
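      # Transforms are matched in list order, so more specific patterns must
      # precede their more general prefixes (e.g. Conv2D+BN+ReLU before
      # Conv2D+BN, ConcatTransform6Inputs before ConcatTransform).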
      default_n_bit_transforms.InputLayerQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
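      # Decompose SeparableConv1D/SeparableConv2D into DepthwiseConv + Conv so
      # each piece can be quantized separately.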
      default_n_bit_transforms.SeparableConv1DQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.SeparableConvQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
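      # Fuse Conv2D/DepthwiseConv2D (optionally followed by a Reshape) with
      # BatchNorm and a trailing ReLU/activation, mirroring the folding done
      # by the TensorFlow Lite backend.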
      default_n_bit_transforms.Conv2DReshapeBatchNormReLUQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.Conv2DReshapeBatchNormActivationQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.Conv2DBatchNormReLUQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.Conv2DBatchNormActivationQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.Conv2DReshapeBatchNormQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.Conv2DBatchNormQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
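      # Concat transforms, ordered from most inputs to fewest so the widest
      # matching pattern is applied first; quantization happens once after
      # the concat rather than separately per input.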
      default_n_bit_transforms.ConcatTransform6Inputs(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.ConcatTransform5Inputs(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.ConcatTransform4Inputs(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.ConcatTransform3Inputs(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.ConcatTransform(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
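      # Keep a layer and its following ReLU (or `activation='relu'` layer)
      # together so no fake-quant op is inserted between them.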
      default_n_bit_transforms.LayerReLUQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
      default_n_bit_transforms.LayerReluActivationQuantize(
          num_bits_weight=self._num_bits_weight,
          num_bits_activation=self._num_bits_activation),
  ]
  return model_transformer.ModelTransformer(
      model, transforms,
      set(layer_quantize_map.keys()), layer_quantize_map).transform()
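

# Illustrative usage sketch (not part of the original module): this layout
# transform is normally invoked indirectly, via `quantize_apply` with the
# experimental n-bit scheme, rather than by calling `apply` directly. The
# module path and `DefaultNBitQuantizeScheme` kwargs below reflect the
# tensorflow_model_optimization package but should be treated as assumptions
# to verify against the installed version.
if __name__ == '__main__':
  import tensorflow as tf
  import tensorflow_model_optimization as tfmot
  from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_scheme

  # Conv2D -> BatchNorm -> ReLU is exactly the pattern that
  # Conv2DBatchNormReLUQuantize above folds into a single quantized op.
  base = tf.keras.Sequential([
      tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.ReLU(),
  ])
  annotated = tfmot.quantization.keras.quantize_annotate_model(base)
  quantized = tfmot.quantization.keras.quantize_apply(
      annotated,
      scheme=default_n_bit_quantize_scheme.DefaultNBitQuantizeScheme(
          num_bits_weight=4, num_bits_activation=4))
  quantized.summary()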