def from_encoder()

in tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py [0:0]


  def from_encoder(cls, encoder, tensorspec):
    """Creates a `GatherEncoder` for encoding `tensorspec`-like values.

    This method instantiates `GatherEncoder`, wrapping the functionality of
    `encoder` and exposing necessary logic for encoding values compatible with
    `tensorspec`. Note that the returned encoder will not accept inputs that
    are incompatible with `tensorspec`.

    As part of construction, every `tf.function` created here is called once:
    on zero-valued inputs shaped like `tensorspec`, and with `num_summands=1`
    for the decode-after-sum step. These calls record the nested structures
    and Python-valued parameters in closure-level dictionaries. The functions
    later need them to convert between the flat, user-facing tuples/dicts and
    the nested values expected by `encoder`.

    Args:
      encoder: An `Encoder` object to be used for encoding.
      tensorspec: A `tf.TensorSpec`. The created `GatherEncoder` will be
        constrained to only encode input values compatible with `tensorspec`.
        Its shape must be fully defined.

    Returns:
      An instance of `cls` (a `GatherEncoder`), built from `tensorspec` and
      the encoder's commuting structure. It also receives the encoder's
      flattened state update aggregation modes and the `tf.function`s created
      below.

    Raises:
      TypeError:
        If `encoder` is not an `Encoder`, `tensorspec` is not a
        `tf.TensorSpec`, or the shape of `tensorspec` is not fully defined.
    """
    if not isinstance(encoder, core_encoder.Encoder):
      raise TypeError('The encoder must be an instance of `Encoder`.')
    if not isinstance(tensorspec, tf.TensorSpec):
      raise TypeError('The tensorspec must be a tf.TensorSpec.')
    if not tensorspec.shape.is_fully_defined():
      raise TypeError('The shape of provided tensorspec must be fully defined.')

    commuting_structure = encoder.commuting_structure
    # Flattened into a flat sequence; presumably this lines up with the flat
    # tuple of state update tensors returned by `encode_fn` below.
    state_update_aggregation_modes = tf.nest.flatten(
        encoder.state_update_aggregation_modes)

    # The following dictionaries carry information that is known statically
    # during execution of the Python code (i.e., not the resulting TensorFlow
    # computation) between the `tf.function`s created below.
    #
    # The motivation behind this pattern is the following.
    #
    # Implementers of the `EncodingStageInterface` should not need to worry
    # about the distinction between Python and TF values when declaring
    # certain parameters. For instance, the number of quantization bits can be
    # either a TensorFlow value or a Python integer, and the implementer should
    # not need to handle the two differently.
    #
    # However, users of the `GatherEncoder` should only see values that
    # actually need to be handled outside of this tool, which means only the
    # TF values. We quietly carry the Python values around, in the
    # internal_structure and internal_py_values dictionaries. We then place
    # them where they belong at graph building time.
    #
    # As a consequence, the user-facing signature of the encode and decode
    # methods cannot be determined statically before we actually execute the
    # `get_params` method. The TF structure can depend on the internal
    # configuration of the implementations of the `EncodingStageInterface`.
    #
    # A similar problem is that we cannot determine the signature of the
    # decode methods before executing the encode method. Some implementations
    # of `EncodingStageInterface` need the original input_shape as an input to
    # their respective `decode` method. Hence, the user-facing signature can
    # differ based on whether the shape is statically known or not. This
    # difference, again, cannot be statically determined without executing
    # the part of the relevant encoding tree above a given stage.
    #
    # The resulting complexity is of the good type, because neither type of
    # user of the tensor_encoding tool even needs to be aware of it. This
    # argument is well supported, for instance, in John Ousterhout's book
    # "A Philosophy of Software Design".
    internal_structure = collections.OrderedDict()
    internal_py_values = collections.OrderedDict()

    def _add_to_structure(key, value):
      # Records only the nesting of `value` (every leaf replaced by None), so
      # that flat tuples can later be re-nested with
      # `tf.nest.pack_sequence_as`. Only the first structure recorded for a
      # key is kept; later calls with the same key are no-ops.
      if key not in internal_structure:
        internal_structure[key] = tf.nest.map_structure(lambda _: None, value)

    def _add_to_py_values(key, value):
      # Records `value` (the Python-valued part of some parameters, as passed
      # by the callers below). Only the first value recorded for a key is kept.
      if key not in internal_py_values:
        internal_py_values[key] = value

    @tf.function
    def initial_state_fn():
      """See the `initial_state` method of this class."""
      state = encoder.initial_state()
      _add_to_structure('state', state)
      return tuple(tf.nest.flatten(state))

    # Calling the function traces it. Tracing records internal_structure['state'],
    # which get_params_fn and update_state_fn read. The flat state returned
    # here also defines the specs that later state inputs are checked against.
    state = initial_state_fn()
    flat_state_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state)

    @tf.function
    def get_params_fn(flat_state):
      """See the `get_params` method of this class."""
      py_utils.assert_compatible(flat_state_spec, flat_state)
      # Re-nest the user-provided flat state using the recorded structure.
      state = tf.nest.pack_sequence_as(internal_structure['state'], flat_state)

      encode_params, decode_params = encoder.get_params(state)
      decode_before_sum_params, decode_after_sum_params = (
          core_encoder.split_params_by_commuting_structure(
              decode_params, commuting_structure))

      # Get the portion of input_shapes that will be relevant in the
      # decode_after_sum method and fold it into the params exposed to user.
      # `encode` runs on a zero tensor shaped like `tensorspec` only to obtain
      # `input_shapes`; the encoded values and state updates are discarded.
      _, _, input_shapes = encoder.encode(
          tf.zeros(tensorspec.shape, tensorspec.dtype), encode_params)
      _, input_shapes_after_sum = (
          core_encoder.split_shapes_by_commuting_structure(
              input_shapes, commuting_structure))
      decode_after_sum_params = collections.OrderedDict([
          (_PARAMS, decode_after_sum_params),
          (_SHAPES, input_shapes_after_sum),
      ])

      # Only the TF-valued parts are returned to the caller. The Python-valued
      # parts are stored in internal_py_values, and the functions below merge
      # them back in.
      encode_params_py, encode_params_tf = py_utils.split_dict_py_tf(
          encode_params)
      decode_before_sum_params_py, decode_before_sum_params_tf = (
          py_utils.split_dict_py_tf(decode_before_sum_params))
      decode_after_sum_params_py, decode_after_sum_params_tf = (
          py_utils.split_dict_py_tf(decode_after_sum_params))

      _add_to_structure('encode_params', encode_params_tf)
      _add_to_structure('decode_before_sum_params', decode_before_sum_params_tf)
      _add_to_structure('decode_after_sum_params', decode_after_sum_params_tf)
      _add_to_py_values('encode_params', encode_params_py)
      _add_to_py_values('decode_before_sum_params', decode_before_sum_params_py)
      _add_to_py_values('decode_after_sum_params', decode_after_sum_params_py)

      return (tuple(tf.nest.flatten(encode_params_tf)),
              tuple(tf.nest.flatten(decode_before_sum_params_tf)),
              tuple(tf.nest.flatten(decode_after_sum_params_tf)))

    # Tracing get_params_fn records the params structures and Python values
    # that encode_fn and the decode functions read. The returned flat params
    # define the specs their inputs are checked against.
    encode_params, decode_before_sum_params, decode_after_sum_params = (
        get_params_fn(state))
    encode_params_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor,
                                               encode_params)
    decode_before_sum_params_spec = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, decode_before_sum_params)
    decode_after_sum_params_spec = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, decode_after_sum_params)

    @tf.function
    def encode_fn(x, params):
      """See the `encode` method of this class."""
      # Python-level check: inside a tf.function it runs while the function is
      # traced, not as an op in the resulting graph.
      if not tensorspec.is_compatible_with(x):
        raise ValueError(
            'The provided x is not compatible with the expected tensorspec.')
      py_utils.assert_compatible(encode_params_spec, params)

      # Restore the nested params and merge back the Python-valued entries.
      params = py_utils.merge_dicts(
          tf.nest.pack_sequence_as(internal_structure['encode_params'], params),
          internal_py_values['encode_params'])
      encoded_x, state_update_tensors, input_shapes = encoder.encode(x, params)
      input_shapes_before_sum, _ = (
          core_encoder.split_shapes_by_commuting_structure(
              input_shapes, commuting_structure))

      encoded_structure = collections.OrderedDict([
          (_TENSORS, encoded_x),
          (_SHAPES, input_shapes_before_sum),
      ])
      encoded_structure_py, encoded_structure_tf = py_utils.split_dict_py_tf(
          encoded_structure)

      _add_to_structure('encoded_structure', encoded_structure_tf)
      _add_to_structure('state_update_tensors', state_update_tensors)
      _add_to_py_values('encoded_structure', encoded_structure_py)

      # The encoded TF values are returned as a flat dict, presumably keyed by
      # the joined string paths of the nested structure (see
      # py_utils.flatten_with_joined_string_paths).
      return (collections.OrderedDict(
          py_utils.flatten_with_joined_string_paths(encoded_structure_tf)),
              tuple(tf.nest.flatten(state_update_tensors)))

    # Trace encode_fn on a zero input. This records 'encoded_structure' and
    # 'state_update_tensors', plus the Python-valued part of the encoded
    # structure.
    encoded_structure, state_update_tensors = encode_fn(
        tf.zeros(tensorspec.shape, tensorspec.dtype), encode_params)
    encoded_structure_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor,
                                                   encoded_structure)

    @tf.function
    def decode_before_sum_fn(encoded_structure, params):
      """See the `decode_before_sum` method of this class."""
      py_utils.assert_compatible(encoded_structure_spec, encoded_structure)
      py_utils.assert_compatible(decode_before_sum_params_spec, params)

      # NOTE(review): this re-nests the leaves of the flat, user-facing dict
      # in its `tf.nest.flatten` order. Confirm that this always matches the
      # leaf order of the nested structure recorded in encode_fn.
      encoded_structure = py_utils.merge_dicts(
          tf.nest.pack_sequence_as(internal_structure['encoded_structure'],
                                   tf.nest.flatten(encoded_structure)),
          internal_py_values['encoded_structure'])
      params = py_utils.merge_dicts(
          tf.nest.pack_sequence_as(
              internal_structure['decode_before_sum_params'], params),
          internal_py_values['decode_before_sum_params'])

      encoded_tensors = encoded_structure[_TENSORS]
      input_shapes = encoded_structure[_SHAPES]
      part_decoded_structure = encoder.decode_before_sum(
          encoded_tensors, params, input_shapes)

      _add_to_structure('part_decoded_structure', part_decoded_structure)
      # A dict result is exposed as a flat dict; any other structure is
      # returned as-is.
      if isinstance(part_decoded_structure, dict):
        return collections.OrderedDict(
            py_utils.flatten_with_joined_string_paths(part_decoded_structure))
      else:
        return part_decoded_structure

    # Trace decode_before_sum_fn to record 'part_decoded_structure' and fix
    # the spec of the values passed on to decode_after_sum_fn.
    part_decoded_structure = decode_before_sum_fn(encoded_structure,
                                                  decode_before_sum_params)
    part_decoded_structure_spec = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, part_decoded_structure)

    @tf.function
    def decode_after_sum_fn(part_decoded_structure, params, num_summands):
      """See the `decode_after_sum` method of this class."""
      py_utils.assert_compatible(part_decoded_structure_spec,
                                 part_decoded_structure)
      py_utils.assert_compatible(decode_after_sum_params_spec, params)

      # Only a structure (no Python values) is recorded for the part-decoded
      # values, so nothing is merged back in here.
      part_decoded_structure = tf.nest.pack_sequence_as(
          internal_structure['part_decoded_structure'],
          tf.nest.flatten(part_decoded_structure))
      params = py_utils.merge_dicts(
          tf.nest.pack_sequence_as(
              internal_structure['decode_after_sum_params'], params),
          internal_py_values['decode_after_sum_params'])
      # Unpack the actual decode params and the input shapes that
      # get_params_fn folded in under the _PARAMS and _SHAPES keys.
      actual_params = params[_PARAMS]
      shapes = params[_SHAPES]
      decoded_x = encoder.decode_after_sum(part_decoded_structure,
                                           actual_params, num_summands, shapes)
      return decoded_x

    # Trace with a single summand and sanity-check the decoded output.
    # NOTE(review): `assert` is stripped under `python -O`. Consider raising an
    # explicit error if this check must always run.
    decoded_x = decode_after_sum_fn(part_decoded_structure,
                                    decode_after_sum_params, 1)
    assert tensorspec.is_compatible_with(decoded_x)

    @tf.function
    def update_state_fn(flat_state, state_update_tensors):
      """See the `update_state` method of this class."""
      py_utils.assert_compatible(flat_state_spec, flat_state)
      state = tf.nest.pack_sequence_as(internal_structure['state'], flat_state)
      state_update_tensors = tf.nest.pack_sequence_as(
          internal_structure['state_update_tensors'], state_update_tensors)
      updated_state = encoder.update_state(state, state_update_tensors)
      return tuple(tf.nest.flatten(updated_state))

    # Ensures the update_state_fn is traced during initialization, and that
    # the updated state has the same structure as the initial state.
    updated_state = update_state_fn(state, state_update_tensors)
    tf.nest.assert_same_structure(state, updated_state)

    return cls(tensorspec, commuting_structure, state_update_aggregation_modes,
               initial_state_fn, get_params_fn, encode_fn, decode_before_sum_fn,
               decode_after_sum_fn, update_state_fn)