def _pack_min_diff_sequence_as()

in tensorflow_model_remediation/min_diff/keras/utils/structure_utils.py [0:0]


def _pack_min_diff_sequence_as(struct, flat_sequence):
  # pyformat: disable
  """Pack `flat_sequence` into the same structure as `struct`.

  Has similar behavior to `tf.nest.pack_sequence_as` with the exception that
  tuples in `struct` will be considered as single elements. See
  `tf.nest.pack_sequence_as` documentation for additional details on behavior.

  Arguments:
    struct: structure that `flat_sequence` will be packed as. Must be a single
      element (including a tuple) or an unnested dict.
    flat_sequence: Flat sequence of elements to be packed. Must support `len()`
      and indexing (e.g. a list or tuple).

  Returns:
    `flat_sequence` converted to have the same structure as `struct`. If
    `struct` is a single element, this is the sole element of `flat_sequence`.
    If `struct` is a dict, this is a new `dict` mapping the sorted keys of
    `struct` to the elements of `flat_sequence`, in order.

  Raises:
    ValueError: If `flat_sequence` has a different number of elements from
      `struct`. (Note: if `struct` is a dict, keys are used to count elements).
    ValueError: If `struct` is not a single element (including a tuple) or a
      dict. (Note: If `struct` is a nested dict, the nested values will be
      ignored).
  """
  # pyformat: enable
  if _is_min_diff_element(struct):
    # A single element (tuples included) packs to exactly one value.
    if len(flat_sequence) != 1:
      raise ValueError(
          "The target structure is of type: {}\n\nHowever the input "
          "structure is a sequence ({}) of length {}: {}.".format(
              type(struct), type(flat_sequence), len(flat_sequence),
              flat_sequence))
    return flat_sequence[0]

  if isinstance(struct, dict):
    # Sort keys for a deterministic packing order, as `tf.nest` does for dicts.
    # Only the keys are used; nested values of `struct` are ignored.
    ordered_keys = sorted(struct)
    if len(flat_sequence) != len(ordered_keys):
      raise ValueError(
          "Could not pack sequence. Dict had {} keys, but flat_sequence had {} "
          "element(s).  Structure: {}, flat_sequence: {}.".format(
              len(ordered_keys), len(flat_sequence), struct, flat_sequence))

    return dict(zip(ordered_keys, flat_sequence))

  # If the code reaches here, then `struct` is not a valid MinDiff structure.
  # `validate_min_diff_structure` is expected to raise the descriptive error.
  validate_min_diff_structure(struct)
  # Defensive fallback: if the validator unexpectedly returns normally, still
  # honor the documented contract instead of silently returning `None`.
  raise ValueError(
      "`struct` is not a recognized MinDiff structure. It must be a single "
      "element (including a tuple) or a dict. Given: {}".format(struct))