def merge_extracts()

in tensorflow_model_analysis/utils/util.py [0:0]


def merge_extracts(extracts: List[types.Extracts]) -> types.Extracts:
  """Merges list of extracts into single extract with multi-dimensional data.

  Each extract in `extracts` is walked recursively. Nested mappings are merged
  key by key. Every non-mapping leaf value is appended to a per-key list;
  np.ndarray leaves are first converted with `tolist()`. Once all extracts
  have been merged, each leaf list is converted into a batched value:

    * lists of sparse tensor values (tf.compat.v1 or types) are stacked along a
      new leading axis and converted back with `to_tensor_value`;
    * lists of `types.RaggedTensorValue` are likewise stacked along a new
      leading axis;
    * anything else becomes an `np.array`, with a singleton second axis
      squeezed away (e.g. [[1], [2], [3]] -> [1, 2, 3]).

  Keys missing from some extracts simply produce shorter lists. Values are
  not aligned by extract position.

  Args:
    extracts: Extracts to merge. An empty list yields an empty dict.

  Returns:
    A single extract containing the union of the input keys (same nesting),
    whose leaves are batched across the input extracts.

  Raises:
    RuntimeError: If a leaf cannot be converted. The message names the key
      whose value failed, and the original exception is chained as the cause.
      Each enclosing mapping level adds its own RuntimeError wrapper.
  """

  def merge_with_lists(target: types.Extracts, key: str, value: Any):
    """Merges key and value into the target extracts as a list of values."""
    if isinstance(value, Mapping):
      # Mirror the nested structure so leaves end up at the same depth as in
      # the input extracts.
      nested = target.setdefault(key, {})
      for k, v in value.items():
        merge_with_lists(nested, k, v)
    else:
      leaves = target.setdefault(key, [])
      if isinstance(value, np.ndarray):
        # Store plain Python values; they are re-batched with np.array later.
        value = value.tolist()
      leaves.append(value)

  def merge_lists(target: types.Extracts) -> types.Extracts:
    """Recursively converts the per-key leaf lists in target to batched values.

    Mappings are converted key by key. A leaf list becomes a stacked sparse or
    ragged tensor value, or otherwise an np.array (see merge_extracts).
    """
    if isinstance(target, Mapping):
      result = {}
      for key, value in target.items():
        try:
          result[key] = merge_lists(value)
        except Exception as e:
          raise RuntimeError(
              'Failed to convert value for key "{}"'.format(key)) from e
      # Return the values converted above. Converting again would repeat all
      # the work below this level, doubling it at each nesting depth.
      return result
    # NOTE(review): only the first element's type is inspected; this presumes
    # all leaves collected for a key share the same type.
    elif target and isinstance(
        target[0], (tf.compat.v1.SparseTensorValue, types.SparseTensorValue)):
      # Stack sparse values along a new leading (batch) axis.
      t = tf.sparse.concat(
          0,
          [tf.sparse.expand_dims(to_tensorflow_tensor(v), 0) for v in target])
      return to_tensor_value(t)
    elif target and isinstance(target[0], types.RaggedTensorValue):
      # Stack ragged values along a new leading (batch) axis.
      t = tf.concat(
          [tf.expand_dims(to_tensorflow_tensor(v), 0) for v in target], 0)
      return to_tensor_value(t)
    else:
      arr = np.array(target)
      # Flatten values that were originally single item lists into a single list
      # e.g. [[1], [2], [3]] -> [1, 2, 3]
      if len(arr.shape) == 2 and arr.shape[1] == 1:
        return arr.squeeze(axis=1)
      # Special case for empty slice arrays since numpy treats empty tuples as
      # arrays with dimension 0.
      # e.g. [[()], [()], [()]] -> [(), (), ()]
      elif len(arr.shape) == 3 and arr.shape[1] == 1 and arr.shape[2] == 0:
        return arr.squeeze(axis=1)
      else:
        return arr

  result = {}
  for x in extracts:
    for k, v in x.items():
      merge_with_lists(result, k, v)
  return merge_lists(result)