in tensorflow_model_analysis/utils/util.py [0:0]
def merge_extracts(extracts: List[types.Extracts]) -> types.Extracts:
  """Merges a list of extracts into a single extract with multi-dimensional data.

  Leaf values found at the same (possibly nested) key path across the input
  extracts are first collected into a list, and each collected list is then
  converted into a single batched value:

  * sparse tensor values (`tf.compat.v1.SparseTensorValue` or
    `types.SparseTensorValue`) are stacked along a new leading dimension via
    `tf.sparse.concat`;
  * `types.RaggedTensorValue`s are stacked along a new leading dimension via
    `tf.concat`;
  * everything else is converted with `np.array`, and a singleton second
    dimension is squeezed away (e.g. `[[1], [2], [3]] -> [1, 2, 3]`).

  Nested mappings are merged recursively. A key that is absent from some of
  the inputs simply contributes fewer entries to its list (no padding is
  inserted).

  Args:
    extracts: The extracts to merge. Leaf values that are `np.ndarray`s are
      converted to Python lists (via `tolist()`) before being collected.

  Returns:
    A single extract with the same nested key structure whose leaves hold the
    batched values.

  Raises:
    RuntimeError: If converting the collected values under some key fails; the
      underlying exception is chained as the cause.
  """

  def merge_with_lists(target: types.Extracts, key: str, value: Any):
    """Appends `value` under `key` in `target`, recursing into mappings."""
    if isinstance(value, Mapping):
      # Mirror the nested mapping structure inside the accumulator.
      nested = target.setdefault(key, {})
      for k, v in value.items():
        merge_with_lists(nested, k, v)
    else:
      if isinstance(value, np.ndarray):
        value = value.tolist()
      target.setdefault(key, []).append(value)

  def merge_lists(target: types.Extracts) -> types.Extracts:
    """Converts target's leaves which are lists to batched np.array's, etc."""
    if isinstance(target, Mapping):
      result = {}
      for key, value in target.items():
        try:
          result[key] = merge_lists(value)
        except Exception as e:
          raise RuntimeError(
              'Failed to convert value for key "{}"'.format(key)) from e
      # Return the values converted above. Re-converting them here would
      # redo all of the work for every sub-tree (doubling it per nesting
      # level).
      return result
    if target and isinstance(
        target[0], (tf.compat.v1.SparseTensorValue, types.SparseTensorValue)):
      # Give each entry a leading dimension of size 1, then concatenate along
      # that dimension to form the batch.
      stacked = tf.sparse.concat(
          0,
          [tf.sparse.expand_dims(to_tensorflow_tensor(v), 0) for v in target])
      return to_tensor_value(stacked)
    if target and isinstance(target[0], types.RaggedTensorValue):
      stacked = tf.concat(
          [tf.expand_dims(to_tensorflow_tensor(v), 0) for v in target], 0)
      return to_tensor_value(stacked)
    arr = np.array(target)
    # Flatten values that were originally single item lists into a single
    # list, e.g. [[1], [2], [3]] -> [1, 2, 3].
    # Empty slice arrays need the same treatment because numpy treats empty
    # tuples as arrays with dimension 0, e.g. [[()], [()], [()]] -> [(), (), ()].
    if (arr.ndim == 2 and arr.shape[1] == 1) or (
        arr.ndim == 3 and arr.shape[1] == 1 and arr.shape[2] == 0):
      return arr.squeeze(axis=1)
    return arr

  result = {}
  for extract in extracts:
    for k, v in extract.items():
      merge_with_lists(result, k, v)
  return merge_lists(result)