in tensorflow_model_remediation/min_diff/keras/utils/structure_utils.py [0:0]
def _pack_min_diff_sequence_as(struct, flat_sequence):
  # pyformat: disable
  """Pack `flat_sequence` into the same structure as `struct`.

  Has similar behavior to `tf.nest.pack_sequence_as` with the exception that
  tuples in `struct` will be considered as single elements. See
  `tf.nest.pack_sequence_as` documentation for additional details on behavior.

  Arguments:
    struct: structure that `flat_sequence` will be packed as. Must be a single
      element (including a tuple) or an unnested dict.
    flat_sequence: Flat sequence of elements to be packed. Must support `len()`
      and indexing.

  Returns:
    `flat_sequence` converted to have the same structure as `struct`:
    - If `struct` is a single element, the sole element of `flat_sequence`.
    - If `struct` is a dict, a new plain `dict` mapping the keys of `struct`,
      taken in sorted order, to the elements of `flat_sequence` at the same
      positions.

  Raises:
    ValueError: If `flat_sequence` has a different number of elements from
      `struct`. (Note: if `struct` is a dict, keys are used to count elements).
    ValueError: If `struct` is not a single element (including a tuple) or a
      dict. (Note: If `struct` is a nested dict, the nested values will be
      ignored).
  """
  # pyformat: enable
  if _is_min_diff_element(struct):
    # A single element can only be packed from exactly one flat value.
    if len(flat_sequence) != 1:
      raise ValueError(
          "The target structure is of type: {}\n\nHowever the input "
          "structure is a sequence ({}) of length {}: {}.".format(
              type(struct), type(flat_sequence), len(flat_sequence),
              flat_sequence))
    return flat_sequence[0]

  if isinstance(struct, dict):
    # Keys are paired with flat elements in sorted order, which is how
    # `tf.nest` orders dict entries. The matching flatten helper is assumed to
    # use the same ordering.
    ordered_keys = sorted(struct.keys())
    if len(flat_sequence) != len(ordered_keys):
      raise ValueError(
          "Could not pack sequence. Dict had {} keys, but flat_sequence had {} "
          "element(s). Structure: {}, flat_sequence: {}.".format(
              len(ordered_keys), len(flat_sequence), struct, flat_sequence))
    # The values of `struct` are never inspected. A nested value is replaced
    # wholesale by the flat element at the same position.
    return dict(zip(ordered_keys, flat_sequence))

  # `struct` is neither a single element nor a dict, so it is not a valid
  # MinDiff structure. Delegate to `validate_min_diff_structure` so the caller
  # gets its detailed error message.
  validate_min_diff_structure(struct)
  # Defensive: if validation unexpectedly accepts `struct`, fail loudly
  # rather than silently returning `None`.
  raise ValueError(
      "Could not pack sequence. `struct` is neither a single element "
      "(including a tuple) nor a dict: {}.".format(struct))