in scripts/tf_cnn_benchmarks/variable_mgr.py [0:0]
def get_post_init_ops(self):
  """Broadcast initialized values of variables to other devices.

  Returns:
    At task 0 device 0, broadcast_send.
    At all other devices and tasks, broadcast_recv.
  """
  global_vars = tf.global_variables()
  group_size = self._num_workers * self._num_gpus
  post_init_ops = []
  # Gather variables into same-var-different-device groups.
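  # Variable names follow the pattern 'v<device_id>/<suffix>', so (hypothetical)
  # names like 'v0/conv1/weights' and 'v1/conv1/weights' share the suffix
  # 'conv1/weights' and end up in the same group.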
  vars_by_suffix = dict()
  for v in global_vars:
    split_name = v.name.split('/')
    mo = re.match(r'v(\d+)$', split_name[0])
    if mo:
      device_id = int(mo.group(1))
      suffix = '/'.join(split_name[1:])
      if suffix in vars_by_suffix:
        vars_by_suffix[suffix].append(v)
      else:
        vars_by_suffix[suffix] = [v]
  # Generate broadcast ops for each such group.
  for suffix in sorted(vars_by_suffix):
    vlist = vars_by_suffix[suffix]
    assert self._num_gpus == len(vlist)
    devices = [v.device for v in vlist]
    # NOTE: this key should generate the same value for all tasks.
    group_key = allreduce.collective_group_key(devices)
    group_size = self._num_workers * len(devices)
    instance_key = self._get_instance_key(suffix)
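    # Task 0 / device 0 holds the source copy of each variable: it issues a
    # broadcast_send, while every other (task, device) pair issues a matching
    # broadcast_recv with the same group and instance keys.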
    for v in vlist:
      split_name = v.name.split('/')
      mo = re.match(r'v(\d+)$', split_name[0])
      if mo:
        device_id = int(mo.group(1))
        if self._task_id == 0 and device_id == 0:
          with tf.device(v.device):
            bcast_send = allreduce.broadcast_send(
                v, v.shape, v.dtype, group_size, group_key, instance_key)
            post_init_ops.append(v.assign(bcast_send))
        else:
          with tf.device(v.device):
            bcast_recv = allreduce.broadcast_recv(
                v.shape, v.dtype, group_size, group_key, instance_key)
            post_init_ops.append(v.assign(bcast_recv))
  return post_init_ops
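
Usage sketch (not part of the file): the returned ops are intended to run once, immediately after variable initialization, so every task and GPU starts from task 0 / device 0's values. The driver code below is an illustrative assumption for TF 1.x graph mode; `variable_mgr` stands in for an instance of the surrounding variable-manager class.

# Hypothetical driver, assuming TF 1.x graph mode and an already-constructed
# variable manager exposing get_post_init_ops().
init_op = tf.global_variables_initializer()
post_init_ops = variable_mgr.get_post_init_ops()
with tf.Session() as sess:
  sess.run(init_op)                   # each task/GPU initializes its own copy
  sess.run(tf.group(*post_init_ops))  # then task 0 / device 0 broadcasts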