mpi_utils.py:
from mpi4py import MPI
import numpy as np
import tensorflow as tf
import blocksparse as bs
from blocksparse import nccl

def mpi_init(initializer):
    """Variable initializer for MPI. Used so that allreduce syncs variables
    at the beginning of training: rank 0 keeps the real initializer and all
    other ranks start from zeros. This is better than multiplying the values
    by 0, which requires extra memory. Alternatively, a broadcast can be used."""
    if mpi_rank() == 0:
        return initializer
    return tf.zeros_initializer()
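
# Example (sketch): how these initializers are typically wired up. The
# variable and shape names below are illustrative, not part of this module.
#
#     w = tf.get_variable("w", [n_in, n_out],
#                         initializer=random_or_zeros_init(stddev=0.02))
#     sess = get_session()
#     sess.run(tf.global_variables_initializer())
#     sync_variables(sess)  # allreduce of zero-init ranks makes all ranks match rank 0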

def random_or_zeros_init(stddev):
    return mpi_init(tf.random_normal_initializer(stddev=stddev))

def constant_or_zeros_init(constant):
    return mpi_init(tf.constant_initializer(constant))

def zeros_init():
    return tf.zeros_initializer()

def num_comms():
    # perhaps make this editable later
    return 2

def mpi_size():
    return MPI.COMM_WORLD.Get_size()

def mpi_rank():
    return MPI.COMM_WORLD.Get_rank()

def num_nodes():
    # assumes nodes with 8 GPUs each
    if mpi_size() > 8:
        return mpi_size() // 8
    return 1

def gpus_per_node():
    size = mpi_size()
    if size > 1:
        return max(size // num_nodes(), 1)
    return 1

def local_mpi_rank():
    return mpi_rank() % gpus_per_node()

def prereduce_size():
    if mpi_size() > 8:
        if mpi_size() % num_nodes() != 0:
            raise ValueError('MPI size not evenly divisible across nodes')
        return gpus_per_node()
    return 0
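
# Worked example of the topology helpers above (values assume the fixed
# 8-GPUs-per-node layout): with mpi_size() == 32, num_nodes() == 4,
# gpus_per_node() == 8, local_mpi_rank() is in 0..7, and prereduce_size()
# == 8, so the NCCL allreduce can pre-reduce within each node before
# reducing across nodes. With mpi_size() <= 8 there is a single node and
# prereduce is disabled (0).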

def allreduce(val):
    if mpi_size() == 1:
        return val
    return nccl.allreduce(val, num_comms=num_comms(), prereduce=prereduce_size())

def sync_variables(sess):
    sess.run(bs.nccl.sync_globals_zero_init_op(
        num_comms=num_comms(), prereduce=prereduce_size()))

def group_allreduce(grads, params, search_strings=None, cast_all=None):
    if mpi_size() == 1:
        return grads
    return nccl.group_allreduce(
        grads, params,
        search_strings=search_strings,
        cast_all=cast_all,
        num_comms=num_comms(),
        prereduce=prereduce_size())
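
# Example (sketch): a typical data-parallel training step allreduces
# gradients before applying them. Names below are illustrative only; if the
# allreduce is a sum (the NCCL default), scale by 1/mpi_size() when a mean
# gradient is wanted.
#
#     grads = tf.gradients(loss, params)
#     grads = group_allreduce(grads, params)
#     train_op = optimizer.apply_gradients(zip(grads, params))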

def mpi_dtype(dtype):
    # numpy dtype name -> MPI datatype
    return {
        "float32": MPI.FLOAT,
        "float64": MPI.DOUBLE,
        "int8": MPI.CHAR,
        "uint8": MPI.UNSIGNED_CHAR,
        "int16": MPI.SHORT,
        "uint16": MPI.UNSIGNED_SHORT,
        "int32": MPI.INT,
        "uint32": MPI.UNSIGNED,
        "int64": MPI.LONG,
        "uint64": MPI.UNSIGNED_LONG,
    }[dtype]

def mpi_barrier():
    MPI.COMM_WORLD.Barrier()

def mpi_allgather(arr):
    # Gathers a numpy array from all ranks along the leading (batch) axis.
    comm = MPI.COMM_WORLD
    n = comm.Get_size()
    batch, *other = arr.shape  # renamed from `bs` to avoid shadowing the blocksparse import
    out = np.zeros((batch * n, *other), dtype=arr.dtype)
    dtype = mpi_dtype(arr.dtype.name)
    comm.Allgather([arr, dtype], [out, dtype])
    return out
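
# Example: with 4 ranks each holding an array of shape (16, 10),
# mpi_allgather returns shape (64, 10) on every rank, with rank i's rows at
# out[16 * i : 16 * (i + 1)]. Arrays must be contiguous and identically
# shaped/typed on all ranks.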

def get_session(mpi=True, disable_swapping=True, log=print):
    config = tf.ConfigProto()
    # if mpi:
    #     log('local rank', local_mpi_rank(), 'rank', mpi_rank())
    #     config.gpu_options.visible_device_list = str(local_mpi_rank())
    config.allow_soft_placement = False
    if disable_swapping:
        # Disable the swapping heuristic used by TF to reduce memory;
        # it is faster to recompute gradients than to swap params out.
        # (1 == RewriterConfig.NO_MEM_OPT)
        config.graph_options.rewrite_options.memory_optimization = 1
    # Don't need a timeout session if mpi4py is used when invoking MPI.
    # sess = TimeoutSession(timeout=timeout, config=config, log=log)
    sess = tf.Session(config=config)
    return sess
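
# Example usage / smoke test: exercises only the mpi4py-based helpers above,
# run with e.g. `mpirun -np 2 python mpi_utils.py`.
if __name__ == "__main__":
    rank, size = mpi_rank(), mpi_size()
    local = np.full((2, 3), rank, dtype=np.float32)
    gathered = mpi_allgather(local)
    mpi_barrier()
    if rank == 0:
        print("world size:", size)
        print("gathered shape:", gathered.shape)  # (2 * size, 3)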