def __init__()

in tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py [0:0]


  def __init__(self, encoder, tensorspec):
    """Creates a `SimpleEncoder` for encoding `tensorspec`-like values.

    This method instantiates `SimpleEncoder`, wrapping the functionality of
    `encoder` and exposing necessary logic for encoding values compatible with
    `tensorspec`. Note that the created encoder will not accept inputs with
    other properties: the encoding `tf.function` is built with `tensorspec` as
    its input signature.

    During construction, three `tf.function`s are built and traced eagerly via
    `get_concrete_function()`: one producing the initial state, one encoding a
    value given a flat state, and one decoding the encoded structure. Tracing
    them in this order records Python-side structure information that each
    later function needs (see the comments in the body).

    Args:
      encoder: An `Encoder` object to be used for encoding.
      tensorspec: A `tf.TensorSpec`. The created `SimpleEncoder` will be
        constrained to only encode input values compatible with `tensorspec`.

    Raises:
      TypeError:
        If `encoder` is not an `Encoder`, if `tensorspec` is not a
        `tf.TensorSpec`, or if the shape of `tensorspec` is not fully defined.
    """
    if not isinstance(encoder, core_encoder.Encoder):
      raise TypeError('The encoder must be an instance of `Encoder`.')
    if not isinstance(tensorspec, tf.TensorSpec):
      raise TypeError('The tensorspec must be a tf.TensorSpec.')
    # NOTE(review): a partially-defined shape is a value problem rather than a
    # type problem, but TypeError is kept since callers may already catch it.
    if not tensorspec.shape.is_fully_defined():
      raise TypeError('The shape of provided tensorspec must be fully defined.')
    self._tensorspec = tensorspec

    # These dictionaries are filled inside of the initial_state_fn and encode_fn
    # methods, to be used in encode_fn and decode_fn methods, respectively.
    # Decorated by tf.function, their necessary side effects are realized during
    # call to get_concrete_function().
    #
    # state_py_structure['state']: nested structure of the encoder state with
    #   every leaf replaced by None; used to repack the flat state tuple.
    # encoded_py_structure['full']: nested structure of the full encoded
    #   structure (tensors, decode params, shapes) with leaves replaced by None.
    # encoded_py_structure['flat_py']: the first part returned by
    #   py_utils.split_dict_py_tf on the flattened encoded structure
    #   (presumably the non-tensor Python values -- confirm in py_utils). It is
    #   not returned by encode_fn, so decode_fn merges it back in.
    state_py_structure = collections.OrderedDict()
    encoded_py_structure = collections.OrderedDict()

    # Returns the encoder's initial state flattened into a tuple. The nested
    # structure of the state is recorded only on the first trace.
    @tf.function
    def initial_state_fn():
      state = encoder.initial_state()
      if not state_py_structure:
        state_py_structure['state'] = tf.nest.map_structure(
            lambda _: None, state)
      # Simplify the structure that needs to be manipulated by the user.
      return tuple(tf.nest.flatten(state))

    # The input signature is (value compatible with `tensorspec`, flat state).
    # Building the state spec calls initial_state_fn.get_concrete_function(),
    # which traces initial_state_fn and thus fills state_py_structure before
    # encode_fn is traced.
    @tf.function(input_signature=[
        tensorspec,
        tf.nest.map_structure(
            tf.TensorSpec.from_tensor,
            initial_state_fn.get_concrete_function().structured_outputs)
    ])  # pylint: disable=missing-docstring
    def encode_fn(x, flat_state):
      # Restore the nested state structure the underlying encoder expects.
      state = tf.nest.pack_sequence_as(state_py_structure['state'], flat_state)
      encode_params, decode_params = encoder.get_params(state)
      encoded_x, state_update_tensors, input_shapes = encoder.encode(
          x, encode_params)
      # The new state is flattened the same way initial_state_fn flattens it.
      updated_flat_state = tuple(
          tf.nest.flatten(encoder.update_state(state, state_update_tensors)))

      # The following code converts the nested structures necessary for the
      # underlying encoder, to a single flat dictionary, which is simpler to
      # manipulate by the users of SimpleEncoder.
      full_encoded_structure = collections.OrderedDict([
          (_TENSORS, encoded_x),
          (_PARAMS, decode_params),
          (_SHAPES, input_shapes),
      ])
      flat_encoded_structure = collections.OrderedDict(
          py_utils.flatten_with_joined_string_paths(full_encoded_structure))
      flat_encoded_py_structure, flat_encoded_tf_structure = (
          py_utils.split_dict_py_tf(flat_encoded_structure))

      # Record the structure information decode_fn needs, on first trace only.
      if not encoded_py_structure:
        encoded_py_structure['full'] = tf.nest.map_structure(
            lambda _: None, full_encoded_structure)
        encoded_py_structure['flat_py'] = flat_encoded_py_structure
      return flat_encoded_tf_structure, updated_flat_state

    # The input signature mirrors the first output of encode_fn (the flat
    # dictionary returned as flat_encoded_tf_structure). Building it calls
    # encode_fn.get_concrete_function(), which traces encode_fn and thus fills
    # encoded_py_structure before decode_fn is traced.
    @tf.function(input_signature=[
        tf.nest.map_structure(
            tf.TensorSpec.from_tensor,
            encode_fn.get_concrete_function().structured_outputs[0])
    ])  # pylint: disable=missing-docstring
    def decode_fn(encoded_structure):
      # Merge back the values recorded while tracing encode_fn, then repack the
      # flat dictionary into the full nested structure before decoding.
      # NOTE(review): this relies on tf.nest.flatten of the merged flat dict
      # yielding leaves in the same order as flattening the 'full' structure;
      # presumably guaranteed by the key format produced by
      # flatten_with_joined_string_paths -- confirm.
      encoded_structure = py_utils.merge_dicts(encoded_structure,
                                               encoded_py_structure['flat_py'])
      encoded_structure = tf.nest.pack_sequence_as(
          encoded_py_structure['full'], tf.nest.flatten(encoded_structure))
      return encoder.decode(encoded_structure[_TENSORS],
                            encoded_structure[_PARAMS],
                            encoded_structure[_SHAPES])

    # Ensures the decode_fn is traced during initialization.
    decode_fn.get_concrete_function()

    self._initial_state_fn = initial_state_fn
    self._encode_fn = encode_fn
    self._decode_fn = decode_fn