in scripts/tf_cnn_benchmarks/batch_allreduce.py
def batch_all_reduce(self,
all_device_tensors,
num_splits,
compact_tensors,
defer_tensors,
xla_compile=False):
"""Performs a batch all-reduce.
The reduction done is a sum.
`all_device_tensors` is a list of lists of tensors that will be batch
all-reduced. All tensors within a single inner list must be on the same
device. The nth element in each list, for any n, will be reduced together.
The return value is in the same form as `all_device_tensors`, except that
each tensor is reduced.
For example, if `all_device_tensors` is:
[[ A, B ], # A and B are on GPU 0
[ C, D ]] # C and D are on GPU 1
Then the return value will be:
[[ A+C, B+D ], # These two tensors are on GPU 0
[ A+C, B+D ]] # These two tensors are on GPU 1
Args:
all_device_tensors: A list of lists of tensors. `all_device_tensors[i][j]`
is a tensor where `i` is the device index and `j` is the tensor index.
num_splits: If not None, tensors will be concatenated and then split into
this many pieces during the all-reduce, then restored to their original
shapes afterwards. Has no impact on correctness and can improve
performance. Requires all tensors to have the same dtype.
compact_tensors: If True, tensors are cast to fp16 before being all-
reduced. Improves performance, but hurts numerical stability.
defer_tensors: If True, every time the return value
`reduced_all_device_tensors` is evaluated, the result will be the reduced
values of `all_device_tensors` from the previous session run instead of
the current one, or zeros on the first session run. When training neural
networks, deferring gradients by one step often does not harm training,
so this can be used to improve performance.
xla_compile: If True, use XLA to compile the gradient packing and
unpacking ops.
Returns:
reduced_all_device_tensors: A list in the same form as
`all_device_tensors`, except each tensor has been reduced.
warmup_ops: A list of ops that must be run once before the all-reduce can
occur.
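Example:
A minimal usage sketch. `algorithm` stands for an instance of a concrete
subclass; the per-GPU tensors a0, b0, a1, b1 and the session `sess` are
assumptions made for this example only:
  reduced, warmup_ops = algorithm.batch_all_reduce(
      [[a0, b0], [a1, b1]], num_splits=None, compact_tensors=False,
      defer_tensors=False)
  sess.run(warmup_ops)  # Empty unless defer_tensors=True; safe to run.
  sess.run(reduced)     # [[a0+a1, b0+b1], [a0+a1, b0+b1]]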
"""
# Before all-reducing the tensors, we apply several preprocessing steps that
# can speed up the all-reduce. We undo these steps after all-reducing the
# tensors.
# all_device_packed_tensors is a 2-d list of tensors indexed by
# [device_id][tensor_id], holding the packed tensors from every device
# involved in the all-reduce.
all_device_packed_tensors = []
# all_device_warmup_ops is a 2-d list of ops indexed by
# [device_id][tensor_id], holding warmup_ops that need to be run once before
# the all-reduce can occur.
all_device_warmup_ops = []
# all_device_put_ops is a 2-d list of ops indexed by
# [device_id][tensor_id], holding the put ops for deferred tensors. They run
# automatically on every all-reduce step via control dependencies.
all_device_put_ops = []
# packers is a list of _TensorPacker objects, one for each device involved
# in the all-reduce.
packers = [
_TensorPacker(num_splits, compact_tensors) for _ in all_device_tensors
]
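# Each packer records the shapes and dtypes it packs so that the
# corresponding undo_maybe_* methods below can restore the original tensors
# after the all-reduce.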
for packer, device_tensors in zip(packers, all_device_tensors):
def pack_single_device_tensors(packer=packer,
device_tensors=device_tensors):
"""Pack gradient tensors of a device."""
packed_tensors = packer.maybe_concat_tensors(device_tensors)
packed_tensors = packer.maybe_compact_tensors(packed_tensors)
# When xla_compile=False, defer tensors after concat for better
# performance.
if defer_tensors and not xla_compile:
packed_tensors, put_ops, warmup_ops = defer_single_device_tensors(
packed_tensors)
all_device_put_ops.append(put_ops)
all_device_warmup_ops.append(warmup_ops)
packed_tensors = packer.maybe_split_tensors(packed_tensors)
return packed_tensors
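# Build the packing ops on the device that already holds this device's
# tensors, so packing does not introduce cross-device copies.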
with tf.device(device_tensors[0].device):
if xla_compile:
packed_tensors = tf.xla.experimental.compile(
pack_single_device_tensors)
# When xla_compile=True, intermediate tensors in the packing process are
# not materialized. Thus, we defer tensors after the packing process has
# completed instead of in the middle of it.
if defer_tensors:
packed_tensors, put_ops, warmup_ops = defer_single_device_tensors(
packed_tensors)
all_device_put_ops.append(put_ops)
all_device_warmup_ops.append(warmup_ops)
else:
packed_tensors = pack_single_device_tensors()
all_device_packed_tensors.append(packed_tensors)
# Perform all-reduce on packed tensors.
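# _do_batch_all_reduce is expected to be implemented by the concrete
# algorithm subclass; it takes and returns tensors in the same
# [device_id][tensor_id] layout.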
all_device_tensors = self._do_batch_all_reduce(all_device_packed_tensors)
all_device_unpacked_tensors = []
for packer, device_tensors in zip(packers, all_device_tensors):
def unpack_single_device_tensors(packer=packer,
device_tensors=device_tensors):
"""Unpack gradient tensors of a device."""
unpacked_tensors = packer.undo_maybe_split_tensors(device_tensors)
unpacked_tensors = packer.undo_maybe_compact_tensors(unpacked_tensors)
unpacked_tensors = packer.undo_maybe_concat_tensors(unpacked_tensors)
return unpacked_tensors
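# Unpacking undoes the packing steps in reverse order (split, then compact,
# then concat), on the same device that produced the packed tensors.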
with tf.device(device_tensors[0].device):
if xla_compile:
unpacked_device_tensor = tf.xla.experimental.compile(
unpack_single_device_tensors)
else:
unpacked_device_tensor = unpack_single_device_tensors()
all_device_unpacked_tensors.append(unpacked_device_tensor)
# Note: There is no undo operation for deferring tensors. But we do need to
# call _add_put_op_control_deps at the end if we deferred the tensors.
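# The control dependencies make every evaluation of the returned tensors
# also run the put ops, which store the current packed gradients so the
# next session run can fetch and reduce them.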
if defer_tensors:
all_device_unpacked_tensors = _add_put_op_control_deps(
all_device_unpacked_tensors, num_splits, all_device_put_ops)
return all_device_unpacked_tensors, all_device_warmup_ops