in tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py [0:0]
def from_encoder(cls, encoder, tensorspec):
"""Creates a `GatherEncoder` for encoding `tensorspec`-like values.
This method instantiates `GatherEncoder`, wrapping the functionality of
`encoder` and exposing necessary logic for encoding values compatible with
`tensorspec`. Note that the returned encoder will not accept inputs with
different properties.
Args:
encoder: An `Encoder` object to be used for encoding.
tensorspec: A `tf.TensorSpec`. The created `GatherEncoder` will be
constrained to only encode input values compatible with `tensorspec`.
Returns:
A `GatherEncoder`.
Raises:
TypeError:
If `encoder` is not an `Encoder` or `tensorspec` is not a
`tf.TensorSpec`.
"""
if not isinstance(encoder, core_encoder.Encoder):
raise TypeError('The encoder must be an instance of `Encoder`.')
if not isinstance(tensorspec, tf.TensorSpec):
raise TypeError('The tensorspec must be a tf.TensorSpec.')
if not tensorspec.shape.is_fully_defined():
raise TypeError('The shape of the provided tensorspec must be fully defined.')
commuting_structure = encoder.commuting_structure
state_update_aggregation_modes = tf.nest.flatten(
encoder.state_update_aggregation_modes)
# The following dictionaries are used to carry information known statically
# during execution of the Python code (i.e., not the resulting TensorFlow
# computation) between the `tf.function`s created below.
#
# The motivation behind this pattern is the following.
#
# The implementers of the `EncodingStageInterface` should not need to worry
# about the distinction between Python and TF values when declaring certain
# parameters. For instance, the number of quantization bits can be either a
# TensorFlow value or a Python integer, and the two should not need to be
# handled differently by the implementer.
#
# However, for the user of the `GatherEncoder`, we only want to expose
# values that actually need to be handled outside of this tool - that is,
# only the TF values. We quietly carry the Python values around - in the
# internal_structure and internal_py_values dictionaries - and place them
# at the appropriate positions at graph building time.
#
# As a consequence, it is impossible to statically determine the user-facing
# signature of the encode and decode methods before we actually execute the
# `get_params` method - the TF structure can depend on the internal
# configuration of the implementations of the `EncodingStageInterface`.
#
# A similar problem is that we cannot determine the signature of the decode
# methods before executing the encode method, because some implementations
# of `EncodingStageInterface` need the original input_shape as an input to
# their respective `decode` method. Hence, the user-facing signature can
# differ based on whether the shape is statically known or not. This
# difference, again, cannot be determined statically without executing the
# part of the relevant encoding tree above a given stage.
#
# The resulting complexity is of the good type, because neither type of user
# of the tensor_encoding tool needs to be aware of it. This argument is well
# supported, for instance, in John Ousterhout's book
# "A Philosophy of Software Design".
internal_structure = collections.OrderedDict()
internal_py_values = collections.OrderedDict()
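# Records the nested structure of `value` (with `None` leaves) the first time
# a key is seen, so that flat tuples can later be packed back into the
# original structure.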
def _add_to_structure(key, value):
if key not in internal_structure:
internal_structure[key] = tf.nest.map_structure(lambda _: None, value)
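# Records the statically known Python values for a key the first time it is
# seen; these are merged back into the TF values at graph building time.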
def _add_to_py_values(key, value):
if key not in internal_py_values:
internal_py_values[key] = value
@tf.function
def initial_state_fn():
"""See the `initial_state` method of this class."""
state = encoder.initial_state()
_add_to_structure('state', state)
return tuple(tf.nest.flatten(state))
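# Trace `initial_state_fn` eagerly once to record the state structure and to
# derive the spec of the flattened state.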
state = initial_state_fn()
flat_state_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state)
@tf.function
def get_params_fn(flat_state):
"""See the `get_params` method of this class."""
py_utils.assert_compatible(flat_state_spec, flat_state)
state = tf.nest.pack_sequence_as(internal_structure['state'], flat_state)
encode_params, decode_params = encoder.get_params(state)
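# Split the decode params into the parts applied before and after summation,
# according to the commuting structure of the encoder.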
decode_before_sum_params, decode_after_sum_params = (
core_encoder.split_params_by_commuting_structure(
decode_params, commuting_structure))
# Get the portion of input_shapes that will be relevant in the
# decode_after_sum method and fold it into the params exposed to the user.
_, _, input_shapes = encoder.encode(
tf.zeros(tensorspec.shape, tensorspec.dtype), encode_params)
_, input_shapes_after_sum = (
core_encoder.split_shapes_by_commuting_structure(
input_shapes, commuting_structure))
decode_after_sum_params = collections.OrderedDict([
(_PARAMS, decode_after_sum_params),
(_SHAPES, input_shapes_after_sum),
])
encode_params_py, encode_params_tf = py_utils.split_dict_py_tf(
encode_params)
decode_before_sum_params_py, decode_before_sum_params_tf = (
py_utils.split_dict_py_tf(decode_before_sum_params))
decode_after_sum_params_py, decode_after_sum_params_tf = (
py_utils.split_dict_py_tf(decode_after_sum_params))
_add_to_structure('encode_params', encode_params_tf)
_add_to_structure('decode_before_sum_params', decode_before_sum_params_tf)
_add_to_structure('decode_after_sum_params', decode_after_sum_params_tf)
_add_to_py_values('encode_params', encode_params_py)
_add_to_py_values('decode_before_sum_params', decode_before_sum_params_py)
_add_to_py_values('decode_after_sum_params', decode_after_sum_params_py)
return (tuple(tf.nest.flatten(encode_params_tf)),
tuple(tf.nest.flatten(decode_before_sum_params_tf)),
tuple(tf.nest.flatten(decode_after_sum_params_tf)))
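# Trace `get_params_fn` once with the initial state to record the parameter
# structures and derive the corresponding specs.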
encode_params, decode_before_sum_params, decode_after_sum_params = (
get_params_fn(state))
encode_params_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor,
encode_params)
decode_before_sum_params_spec = tf.nest.map_structure(
tf.TensorSpec.from_tensor, decode_before_sum_params)
decode_after_sum_params_spec = tf.nest.map_structure(
tf.TensorSpec.from_tensor, decode_after_sum_params)
@tf.function
def encode_fn(x, params):
"""See the `encode` method of this class."""
if not tensorspec.is_compatible_with(x):
raise ValueError(
'The provided x is not compatible with the expected tensorspec.')
py_utils.assert_compatible(encode_params_spec, params)
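# Recover the full params by packing the flat TF values into the recorded
# structure and merging in the stored Python values.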
params = py_utils.merge_dicts(
tf.nest.pack_sequence_as(internal_structure['encode_params'], params),
internal_py_values['encode_params'])
encoded_x, state_update_tensors, input_shapes = encoder.encode(x, params)
input_shapes_before_sum, _ = (
core_encoder.split_shapes_by_commuting_structure(
input_shapes, commuting_structure))
encoded_structure = collections.OrderedDict([
(_TENSORS, encoded_x),
(_SHAPES, input_shapes_before_sum),
])
encoded_structure_py, encoded_structure_tf = py_utils.split_dict_py_tf(
encoded_structure)
_add_to_structure('encoded_structure', encoded_structure_tf)
_add_to_structure('state_update_tensors', state_update_tensors)
_add_to_py_values('encoded_structure', encoded_structure_py)
return (collections.OrderedDict(
py_utils.flatten_with_joined_string_paths(encoded_structure_tf)),
tuple(tf.nest.flatten(state_update_tensors)))
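# Trace `encode_fn` once with a dummy all-zeros input to record the structure
# of the encoded values and of the state update tensors.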
encoded_structure, state_update_tensors = encode_fn(
tf.zeros(tensorspec.shape, tensorspec.dtype), encode_params)
encoded_structure_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor,
encoded_structure)
@tf.function
def decode_before_sum_fn(encoded_structure, params):
"""See the `decode_before_sum` method of this class."""
py_utils.assert_compatible(encoded_structure_spec, encoded_structure)
py_utils.assert_compatible(decode_before_sum_params_spec, params)
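# Reconstruct the full encoded structure and params from the flat TF values
# and the stored Python values.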
encoded_structure = py_utils.merge_dicts(
tf.nest.pack_sequence_as(internal_structure['encoded_structure'],
tf.nest.flatten(encoded_structure)),
internal_py_values['encoded_structure'])
params = py_utils.merge_dicts(
tf.nest.pack_sequence_as(
internal_structure['decode_before_sum_params'], params),
internal_py_values['decode_before_sum_params'])
encoded_tensors = encoded_structure[_TENSORS]
input_shapes = encoded_structure[_SHAPES]
part_decoded_structure = encoder.decode_before_sum(
encoded_tensors, params, input_shapes)
_add_to_structure('part_decoded_structure', part_decoded_structure)
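# If the partially decoded value is a dict, expose it flattened with joined
# string paths as keys; otherwise return it as is.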
if isinstance(part_decoded_structure, dict):
return collections.OrderedDict(
py_utils.flatten_with_joined_string_paths(part_decoded_structure))
else:
return part_decoded_structure
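# Trace `decode_before_sum_fn` once to record the structure of the partially
# decoded value.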
part_decoded_structure = decode_before_sum_fn(encoded_structure,
decode_before_sum_params)
part_decoded_structure_spec = tf.nest.map_structure(
tf.TensorSpec.from_tensor, part_decoded_structure)
@tf.function
def decode_after_sum_fn(part_decoded_structure, params, num_summands):
"""See the `decode_after_sum` method of this class."""
py_utils.assert_compatible(part_decoded_structure_spec,
part_decoded_structure)
py_utils.assert_compatible(decode_after_sum_params_spec, params)
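# Pack the flat inputs back into the recorded structures and merge the stored
# Python values into the params.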
part_decoded_structure = tf.nest.pack_sequence_as(
internal_structure['part_decoded_structure'],
tf.nest.flatten(part_decoded_structure))
params = py_utils.merge_dicts(
tf.nest.pack_sequence_as(
internal_structure['decode_after_sum_params'], params),
internal_py_values['decode_after_sum_params'])
actual_params = params[_PARAMS]
shapes = params[_SHAPES]
decoded_x = encoder.decode_after_sum(part_decoded_structure,
actual_params, num_summands, shapes)
return decoded_x
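# Trace `decode_after_sum_fn` once (with num_summands=1) and verify the
# decoded value is compatible with the input tensorspec.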
decoded_x = decode_after_sum_fn(part_decoded_structure,
decode_after_sum_params, 1)
assert tensorspec.is_compatible_with(decoded_x)
@tf.function
def update_state_fn(flat_state, state_update_tensors):
"""See the `update_state` method of this class."""
py_utils.assert_compatible(flat_state_spec, flat_state)
state = tf.nest.pack_sequence_as(internal_structure['state'], flat_state)
state_update_tensors = tf.nest.pack_sequence_as(
internal_structure['state_update_tensors'], state_update_tensors)
updated_state = encoder.update_state(state, state_update_tensors)
return tuple(tf.nest.flatten(updated_state))
# Ensures the update_state_fn is traced during initialization.
updated_state = update_state_fn(state, state_update_tensors)
tf.nest.assert_same_structure(state, updated_state)
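# All tf.functions are now traced and the internal structures populated;
# construct the GatherEncoder.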
return cls(tensorspec, commuting_structure, state_update_aggregation_modes,
initial_state_fn, get_params_fn, encode_fn, decode_before_sum_fn,
decode_after_sum_fn, update_state_fn)