def quantize_apply()

in tensorflow_model_optimization/python/core/quantization/keras/quantize.py


def quantize_apply(
    model,
    scheme=default_8bit_quantize_scheme.Default8BitQuantizeScheme()):
  """Quantize a `tf.keras` model that has been annotated for quantization.

  Quantization constructs a model which emulates quantization during training.
  This allows the model to learn parameters that are robust to quantization
  loss, and also to model the accuracy of a quantized model.

  For more information, see
  https://www.tensorflow.org/model_optimization/guide/quantization/training
  TODO(tfmot): Link blog once launched.

  This function takes a `tf.keras` model in which the desired layers for
  quantization have already been annotated. See `quantize_annotate_model`
  and `quantize_annotate_layer`.

  Example:

  ```python
  model = keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      quantize_annotate_layer(layers.Dense(2, activation='sigmoid'))
  ])

  # Only the second Dense layer is quantized.
  quantized_model = quantize_apply(model)
  ```
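
  The whole model can also be annotated at once with
  `quantize_annotate_model`; a minimal variant of the example above:

  ```python
  annotated_model = quantize_annotate_model(model)

  # All layers supported by the scheme are quantized.
  quantized_model = quantize_apply(annotated_model)
  ```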

  Note that this function removes the optimizer from the original model.

  The returned model copies the weights from the original model, so it starts
  from the same pre-trained state. Training the returned model will not
  modify the weights of the original model.

  Args:
    model: A `tf.keras` Sequential or Functional model which has been annotated
      with `quantize_annotate`. It can have pre-trained weights.
    scheme: A `QuantizeScheme` which specifies the layout transformer and
      quantize registry to use. Defaults to `Default8BitQuantizeScheme()`.

  Returns:
    A new `tf.keras` model in which the annotated layers have been
    prepared for quantization.
  """
  if model is None:
    raise ValueError('`model` cannot be None')

  if not isinstance(model, keras.Model):
    raise ValueError('`model` can only be a `tf.keras.Model` instance. '
                     'You passed an instance of type: {input}.'.format(
                         input=model.__class__.__name__))

  if not isinstance(model, keras.Sequential) and not model._is_graph_network:  # pylint: disable=protected-access
    raise ValueError('`model` must be either a tf.keras Sequential or '
                     'Functional model.')

  # Have at least 1 layer annotated with QuantizeAnnotate
  if not any(isinstance(layer, quantize_annotate_mod.QuantizeAnnotate)
             for layer in model.layers):
    raise ValueError('`model` must contain at least one layer which has been '
                     'annotated with `quantize_annotate*`. There are no layers '
                     'to quantize.')

  if not model.built:
    raise ValueError('`model` must be a built model. It has not '
                     'been built yet. Please call `model.build(input_shape)` '
                     'before quantizing your model.')

  def _clone_model_with_weights(model_to_clone):
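    """Clones the model and copies the original weights into the clone."""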
    cloned_model = keras.models.clone_model(model_to_clone)
    cloned_model.set_weights(model_to_clone.get_weights())

    return cloned_model

  def _extract_original_model(model_to_unwrap):
    """Extracts original model by removing wrappers."""
    layer_quantize_map = {}
    requires_output_quantize = set()

    def _unwrap(layer):
      if not isinstance(layer, quantize_annotate_mod.QuantizeAnnotate):
        return layer

      annotate_wrapper = layer
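      # If this annotated layer consumes the output of exactly one layer
      # that is not itself annotated, record that producer: its output
      # still needs quantizing so this layer sees quantized inputs.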
      # pylint: disable=protected-access
      if layer._inbound_nodes and len(layer._inbound_nodes) == 1:
        node = layer._inbound_nodes[0]
        inbound_layers = tf.nest.flatten(node.inbound_layers)
        if len(inbound_layers) == 1 and not isinstance(
            inbound_layers[0], quantize_annotate_mod.QuantizeAnnotate):
          requires_output_quantize.add(inbound_layers[0].name)
      # pylint: enable=protected-access

      layer_quantize_map[annotate_wrapper.layer.name] = {
          'quantize_config': annotate_wrapper.quantize_config
      }
      return annotate_wrapper.layer

    unwrapped_model = keras.models.clone_model(
        model_to_unwrap, input_tensors=None, clone_function=_unwrap)

    return unwrapped_model, layer_quantize_map, requires_output_quantize

  def _quantize(layer):  # pylint: disable=missing-docstring
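    # Returns the layer wrapped for quantization when a QuantizeConfig can
    # be resolved (from the annotation map or the scheme's registry), and
    # the layer unchanged otherwise.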
    if ((layer.name not in layer_quantize_map and
         layer.name not in requires_output_quantize) or
        (isinstance(layer, quantize_wrapper.QuantizeWrapper))):
      # Returning the layer as-is supports custom QuantizeWrapper layers.
      return layer

    # If layer is a QuantizeLayer, possibly rebuild it with a modified
    # config using the parameters stored in the map.
    if isinstance(layer, quantize_layer.QuantizeLayer):
      # If there is more than one usage of the input, even if all are concat,
      # we need to quantize.
      if len(layer._outbound_nodes) > 1:  # pylint: disable=protected-access
        return layer
      layer_config = layer.get_config()
      for key, value in layer_quantize_map[layer.name].items():
        layer_config[key] = value
      return quantize_layer.QuantizeLayer.from_config(layer_config)

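    # Layers that only feed annotated layers need just their outputs
    # quantized, so their full QuantizeConfig is narrowed to an
    # output-only config.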
    if layer.name in requires_output_quantize:
      if not quantize_registry.supports(layer):
        return layer
      full_quantize_config = quantize_registry.get_quantize_config(layer)
      if not full_quantize_config:
        return layer
      quantize_config = quantize_config_mod.OutputOnlyConfig(
          full_quantize_config)
    else:
      quantize_config = layer_quantize_map[layer.name].get('quantize_config')
      if not quantize_config and quantize_registry.supports(layer):
        quantize_config = quantize_registry.get_quantize_config(layer)

    if not quantize_config:
      error_msg = (
          'Layer {}:{} is not supported. You can quantize this '
          'layer by passing a `tfmot.quantization.keras.QuantizeConfig` '
          'instance to the `quantize_annotate_layer` '
          'API.')
      raise RuntimeError(
          error_msg.format(layer.name, layer.__class__))

    # `QuantizeWrapper` does not copy any additional layer params from
    # `QuantizeAnnotate`. This should generally be fine, but occasionally
    # `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
    # TODO(pulkitb): Ensure this does not affect model cloning.
    return quantize_wrapper.QuantizeWrapperV2(
        layer, quantize_config)

  # 1. Create a copy of the model with the same weights. This ensures
  # modifications don't affect the original model, or its weights.
  try:
    model_copy = _clone_model_with_weights(model)
  except ValueError as e:
    raise ValueError(
        'Unable to clone model. This generally happens if you used custom '
        'Keras layers or objects in your model. Please specify them via '
        '`quantize_scope` for your calls to `quantize_model` and '
        '`quantize_apply`.') from e

  # 2. Remove QuantizeAnnotate wrappers from the layers in the model. This
  # extracts the original model structure (easier to transform), and
  # stores relevant quantization information in a map.
  (unwrapped_model, layer_quantize_map,
   requires_output_quantize) = _extract_original_model(model_copy)
  # Model cloning excludes input layers. Add input layers into the map
  # since they need to be matched for patterns as well.
  # pylint: disable=protected-access
  for input_layer in unwrapped_model._input_layers:
    for outbound_node in input_layer._outbound_nodes:
      if outbound_node.outbound_layer.name in layer_quantize_map:
        layer_quantize_map[input_layer.name] = {}
  # pylint: enable=protected-access

  # 3. Apply the graph transformations required to match the quantization
  # patterns of the target device/dialect.
  quantize_transform = scheme.get_layout_transformer()
  # layer_quantize_map gets modified by the transformations.
  transformed_model, layer_quantize_map = quantize_transform.apply(
      unwrapped_model, layer_quantize_map)

  # TODO(pulkitb): Think more about how to introduce Default specific code.
  quantize_registry = scheme.get_quantize_registry()

  # 4. Actually quantize all the relevant layers in the model. This is done by
  # wrapping the layers with QuantizeWrapper, and passing the associated
  # `QuantizeConfig`.

  return keras.models.clone_model(
      transformed_model, input_tensors=None, clone_function=_quantize)
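
A minimal usage sketch through the public `tfmot` API (using only built-in
Keras layers, so no `quantize_scope` is needed; the names below mirror the
exports in `tfmot.quantization.keras`):

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_apply = tfmot.quantization.keras.quantize_apply

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(100,)),
    quantize_annotate_layer(tf.keras.layers.Dense(2, activation='sigmoid')),
])

# quantize_apply returns a new model; compile and train that one.
quantized_model = quantize_apply(model)
quantized_model.compile(optimizer='adam', loss='binary_crossentropy')
```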