in tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py [0:0]
def __init__(self, encoder, tensorspec):
  """Creates a `SimpleEncoder` for encoding `tensorspec`-like values.

  This method instantiates `SimpleEncoder`, wrapping the functionality of
  `encoder` and exposing necessary logic for encoding values compatible with
  `tensorspec`. Note that the returned encoder will not accept inputs of other
  properties.

  Internally, this builds three `tf.function`s — `initial_state_fn`,
  `encode_fn` and `decode_fn` — and traces all of them eagerly during
  construction. The tracing order matters: tracing `initial_state_fn` and
  `encode_fn` has deliberate Python side effects (filling in
  `state_py_structure` and `encoded_py_structure`) that the later functions
  rely on to rebuild nested structures from flattened values.

  Args:
    encoder: An `Encoder` object to be used for encoding.
    tensorspec: A `tf.TensorSpec`. The created `SimpleEncoder` will be
      constrained to only encode input values compatible with `tensorspec`.

  Raises:
    TypeError:
      If `encoder` is not an `Encoder`, if `tensorspec` is not a
      `tf.TensorSpec`, or if the shape of `tensorspec` is not fully defined.
  """
  if not isinstance(encoder, core_encoder.Encoder):
    raise TypeError('The encoder must be an instance of `Encoder`.')
  if not isinstance(tensorspec, tf.TensorSpec):
    raise TypeError('The tensorspec must be a tf.TensorSpec.')
  # NOTE(review): presumably a static, fully defined shape is required so the
  # encode/decode functions can be traced with a fixed input signature —
  # confirm against the `Encoder` contract.
  if not tensorspec.shape.is_fully_defined():
    raise TypeError('The shape of provided tensorspec must be fully defined.')
  self._tensorspec = tensorspec

  # These dictionaries are filled inside of the initial_state_fn and encode_fn
  # methods, to be used in encode_fn and decode_fn methods, respectively.
  # Decorated by tf.function, their necessary side effects are realized during
  # call to get_concrete_function().
  state_py_structure = collections.OrderedDict()
  encoded_py_structure = collections.OrderedDict()

  @tf.function
  def initial_state_fn():
    state = encoder.initial_state()
    if not state_py_structure:
      # Record the Python nesting of `state` (leaf values replaced by None)
      # so encode_fn can later pack a flat state back into this structure.
      state_py_structure['state'] = tf.nest.map_structure(
          lambda _: None, state)
    # Simplify the structure that needs to be manipulated by the user.
    return tuple(tf.nest.flatten(state))

  # Building this input_signature calls
  # initial_state_fn.get_concrete_function(), which traces initial_state_fn
  # and thereby populates state_py_structure before encode_fn is traced.
  @tf.function(input_signature=[
      tensorspec,
      tf.nest.map_structure(
          tf.TensorSpec.from_tensor,
          initial_state_fn.get_concrete_function().structured_outputs)
  ])  # pylint: disable=missing-docstring
  def encode_fn(x, flat_state):
    # Restore the nested state structure recorded during initial_state_fn's
    # tracing.
    state = tf.nest.pack_sequence_as(state_py_structure['state'], flat_state)
    encode_params, decode_params = encoder.get_params(state)
    encoded_x, state_update_tensors, input_shapes = encoder.encode(
        x, encode_params)
    updated_flat_state = tuple(
        tf.nest.flatten(encoder.update_state(state, state_update_tensors)))

    # The following code converts the nested structures necessary for the
    # underlying encoder, to a single flat dictionary, which is simpler to
    # manipulate by the users of SimpleEncoder.
    full_encoded_structure = collections.OrderedDict([
        (_TENSORS, encoded_x),
        (_PARAMS, decode_params),
        (_SHAPES, input_shapes),
    ])
    flat_encoded_structure = collections.OrderedDict(
        py_utils.flatten_with_joined_string_paths(full_encoded_structure))
    flat_encoded_py_structure, flat_encoded_tf_structure = (
        py_utils.split_dict_py_tf(flat_encoded_structure))
    if not encoded_py_structure:
      # Record (a) the full nesting of the encoded structure and (b) its
      # non-TensorFlow (static Python) values, so decode_fn can reconstruct
      # the exact input that encoder.decode expects from only the TF part.
      encoded_py_structure['full'] = tf.nest.map_structure(
          lambda _: None, full_encoded_structure)
      encoded_py_structure['flat_py'] = flat_encoded_py_structure
    # Only the TF-valued part is returned; the Python part was captured above.
    return flat_encoded_tf_structure, updated_flat_state

  # Building this input_signature calls encode_fn.get_concrete_function(),
  # which traces encode_fn and thereby populates encoded_py_structure before
  # decode_fn is traced. Index [0] selects the encoded structure (the first
  # element of encode_fn's return value), excluding the updated state.
  @tf.function(input_signature=[
      tf.nest.map_structure(
          tf.TensorSpec.from_tensor,
          encode_fn.get_concrete_function().structured_outputs[0])
  ])  # pylint: disable=missing-docstring
  def decode_fn(encoded_structure):
    # Re-insert the static Python values stripped out in encode_fn, then
    # restore the original nesting recorded during encode_fn's tracing.
    encoded_structure = py_utils.merge_dicts(encoded_structure,
                                             encoded_py_structure['flat_py'])
    encoded_structure = tf.nest.pack_sequence_as(
        encoded_py_structure['full'], tf.nest.flatten(encoded_structure))
    return encoder.decode(encoded_structure[_TENSORS],
                          encoded_structure[_PARAMS],
                          encoded_structure[_SHAPES])

  # Ensures the decode_fn is traced during initialization.
  decode_fn.get_concrete_function()

  self._initial_state_fn = initial_state_fn
  self._encode_fn = encode_fn
  self._decode_fn = decode_fn