tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py [32:104]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
LayerNode = transforms.LayerNode
LayerPattern = transforms.LayerPattern

keras = tf.keras


def _get_conv_bn_layers(bn_layer_node):
  bn_layer = bn_layer_node.layer
  conv_layer = bn_layer_node.input_layers[0].layer
  return conv_layer, bn_layer


def _get_weights(bn_layer_node):
  """Returns weight values for fused layer, including copying original values in unfused version."""

  return collections.OrderedDict(
      list(bn_layer_node.input_layers[0].weights.items())
      + list(bn_layer_node.weights.items()))


def _get_params(conv_layer, bn_layer, relu_layer=None):
  """Retrieve conv_bn params within wrapped layers."""
  if 'use_bias' in conv_layer['config']:
    if conv_layer['config']['use_bias']:
      raise ValueError(
          'use_bias should not be set to True in a Conv layer when followed '
          'by BatchNormalization. The bias in the Conv would be redundant '
          'with the one in the BatchNormalization.')

    del conv_layer['config']['use_bias']

  if 'name' in bn_layer['config']:
    del bn_layer['config']['name']

  # TODO(pulkitb): remove key conflicts
  params = dict(
      list(conv_layer['config'].items()) + list(bn_layer['config'].items()))

  if relu_layer is not None:
    params['post_activation'] = keras.layers.deserialize(relu_layer)

  return params


def _get_layer_node(fused_layer, weights):
  layer_config = keras.layers.serialize(fused_layer)
  layer_config['name'] = layer_config['config']['name']
  # This config tracks which layers get quantized, and whether they have a
  # custom QuantizeConfig.
  layer_metadata = {'quantize_config': None}

  return LayerNode(layer_config, weights, metadata=layer_metadata)


def _get_quantize_config(layer_node):
  return layer_node.metadata.get('quantize_config')


def _has_custom_quantize_config(*layer_nodes):
  for layer_node in layer_nodes:
    if _get_quantize_config(layer_node) is not None:
      return True
  return False


def _normalize_tuple(value):
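  """Returns `value` as a tuple, wrapping a bare int in a 1-tuple."""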
  if isinstance(value, int):
    return (value,)
  else:
    return tuple(value)


class Conv2DBatchNormQuantize(transforms.Transform):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
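
A minimal usage sketch (for illustration only; these helpers are private to the module listed above, so importing them directly is an assumption) showing how _get_params merges the serialized Conv2D and BatchNormalization configs into the kwargs for a fused layer:

import tensorflow as tf

from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
    default_8bit_transforms)

conv = tf.keras.layers.Conv2D(8, 3, use_bias=False, name='conv')
bn = tf.keras.layers.BatchNormalization(name='bn')
relu = tf.keras.layers.ReLU(name='relu')

# keras.layers.serialize returns {'class_name': ..., 'config': {...}} dicts,
# which is the shape _get_params expects.
conv_config = tf.keras.layers.serialize(conv)
bn_config = tf.keras.layers.serialize(bn)
relu_config = tf.keras.layers.serialize(relu)

params = default_8bit_transforms._get_params(conv_config, bn_config, relu_config)
# params now holds the Conv2D config (minus 'use_bias') merged with the
# BatchNormalization config (minus 'name'), plus the deserialized ReLU layer
# under 'post_activation'.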



tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py [31:103]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
LayerNode = transforms.LayerNode
LayerPattern = transforms.LayerPattern

keras = tf.keras


def _get_conv_bn_layers(bn_layer_node):
  bn_layer = bn_layer_node.layer
  conv_layer = bn_layer_node.input_layers[0].layer
  return conv_layer, bn_layer


def _get_weights(bn_layer_node):
  """Returns weight values for fused layer, including copying original values in unfused version."""

  return collections.OrderedDict(
      list(bn_layer_node.input_layers[0].weights.items())
      + list(bn_layer_node.weights.items()))


def _get_params(conv_layer, bn_layer, relu_layer=None):
  """Retrieve conv_bn params within wrapped layers."""
  if 'use_bias' in conv_layer['config']:
    if conv_layer['config']['use_bias']:
      raise ValueError(
          'use_bias should not be set to True in a Conv layer when followed '
          'by BatchNormalization. The bias in the Conv would be redundant '
          'with the one in the BatchNormalization.')

    del conv_layer['config']['use_bias']

  if 'name' in bn_layer['config']:
    del bn_layer['config']['name']

  # TODO(pulkitb): remove key conflicts
  params = dict(
      list(conv_layer['config'].items()) + list(bn_layer['config'].items()))

  if relu_layer is not None:
    params['post_activation'] = keras.layers.deserialize(relu_layer)

  return params


def _get_layer_node(fused_layer, weights):
  layer_config = keras.layers.serialize(fused_layer)
  layer_config['name'] = layer_config['config']['name']
  # This config tracks which layers get quantized, and whether they have a
  # custom QuantizeConfig.
  layer_metadata = {'quantize_config': None}

  return LayerNode(layer_config, weights, metadata=layer_metadata)


def _get_quantize_config(layer_node):
  return layer_node.metadata.get('quantize_config')


def _has_custom_quantize_config(*layer_nodes):
  for layer_node in layer_nodes:
    if _get_quantize_config(layer_node) is not None:
      return True
  return False


def _normalize_tuple(value):
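  """Returns `value` as a tuple, wrapping a bare int in a 1-tuple."""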
  if isinstance(value, int):
    return (value,)
  else:
    return tuple(value)


class Conv2DBatchNormQuantize(transforms.Transform):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
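
Because the two blocks above are identical, one possible cleanup (a sketch under an assumed module name and layout, not an existing API) is to hoist the shared helpers into a single module that both transform files import:

# Hypothetical shared module, e.g.
# tensorflow_model_optimization/python/core/quantization/keras/conv_bn_transform_utils.py
import collections


def get_conv_bn_layers(bn_layer_node):
  """Returns the (conv, bn) layer configs for a Conv -> BatchNorm pattern."""
  return bn_layer_node.input_layers[0].layer, bn_layer_node.layer


def get_weights(bn_layer_node):
  """Returns the ordered weights of the unfused Conv and BatchNorm layers."""
  return collections.OrderedDict(
      list(bn_layer_node.input_layers[0].weights.items())
      + list(bn_layer_node.weights.items()))

# ... the remaining helpers (_get_params, _get_layer_node, _get_quantize_config,
# _has_custom_quantize_config, _normalize_tuple) would move here unchanged ...

# default_8bit_transforms.py and default_n_bit_transforms.py could then alias
# the shared versions instead of redefining them:
#
# from tensorflow_model_optimization.python.core.quantization.keras import (
#     conv_bn_transform_utils as _utils)
#
# _get_conv_bn_layers = _utils.get_conv_bn_layers
# _get_weights = _utils.get_weights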



