in scripts/tf_cnn_benchmarks/variable_mgr_util.py [0:0]
def aggregate_indexed_slices_gradients(grads):
"""Aggregates gradients containing `IndexedSlices`s."""
if len(grads) < 1:
return None
elif len(grads) == 1:
return grads[0]
else:
grads = [g for g in grads if g is not None]
# If any gradient is a `Tensor`, sum them up and return a dense tensor
# object.
if any(isinstance(g, ops.Tensor) for g in grads):
return math_ops.add_n(grads)
# The following `_as_indexed_slices_list` casts ids of IndexedSlices into
# int64. It is to make sure the inputs of `concat` all have same the data
# type.
grads = math_ops._as_indexed_slices_list(grads) # pylint: disable=protected-access
grads = [flatten_nested_indexed_slices(x) for x in grads]
# Form IndexedSlices out of the concatenated values and indices.
concat_grad = ops.IndexedSlices(
array_ops.concat([x.values for x in grads], axis=0),
array_ops.concat([x.indices for x in grads], axis=0),
grads[0].dense_shape)
return concat_grad