in gossip/distributed.py [0:0]
def __init__(self, module, device_ids=None, rank=None, world_size=None,
             graph=None, mixing=None, comm_device=None, push_sum=True,
             overlap=False, synch_freq=0, verbose=False, use_streams=True,
             nprocs_per_node=1, local_node_group=None):
    super(GossipDataParallel, self).__init__()

    # devices available locally
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    self.output_device = device_ids[0]
    self.device_ids = device_ids

    self.nprocs_per_node = nprocs_per_node

    if world_size is None or rank is None:
        assert dist.is_initialized()
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    self.process_rank = rank

    if self.nprocs_per_node > 1:
        self.local_rank = self.process_rank % self.nprocs_per_node
        world_size //= nprocs_per_node
        rank //= nprocs_per_node
        if local_node_group is None:
            for node in range(world_size):
                node_processes_ranks = list(
                    range(node * self.nprocs_per_node,
                          (node + 1) * self.nprocs_per_node))
                # process group to communicate between processes on this
                # node; every node's group is constructed on every process,
                # but only the group containing this process is kept
                new_local_group = create_process_group(
                    node_processes_ranks)
                if self.process_rank in node_processes_ranks:
                    self.local_node_group = new_local_group
        else:
            self.local_node_group = local_node_group
    else:
        self.local_rank = 0

    # put model on output device
    self.module = module
    first_param_dtype = next(self.module.parameters()).dtype

    # prepare local intra-node all-reduce objects
    if len(self.device_ids) > 1:
        self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
        self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

        self._module_copies = replicate(self.module, self.device_ids,
                                        detach=True)
        self._module_copies[0] = self.module
        for cmodule in self._module_copies[1:]:
            for p, cp in zip(self.module.parameters(),
                             cmodule.parameters()):
                cp.requires_grad = p.requires_grad
    else:
        self._module_copies = [self.module]

    # choose communication device based on backend
    if comm_device is None:
        cpu_comm = dist.get_backend() == 'gloo'
        comm_device = torch.device('cpu') if cpu_comm \
            else torch.device('cuda')
    self.__cpu_comm = comm_device.type == 'cpu'

    if graph is None:
        graph = NPDDEGraph(
            rank, world_size, self.nprocs_per_node, self.local_rank)

    if mixing is None:
        mixing = UniformMixing(graph, comm_device)

    # distributed backend config
    self.dist_config = {
        'verbose': verbose,
        'comm_device': comm_device,
        'graph': graph,
        'mixing': mixing,
        'push_sum': push_sum,
        'rank': rank,
        'process_rank': self.process_rank,
        'world_size': world_size,
        'cpu_comm': self.__cpu_comm
    }
    self.overlap = overlap
    self.synch_freq = synch_freq
    self.num_updates = 0
    self.asynch = synch_freq > 0

    # logger used to print to stdout
    self.logger = make_logger(rank, verbose)

    # push-sum weight=1.0 ==> distributed averaging
    self.ps_weight = torch.ones(1, device=comm_device).type(
        first_param_dtype)
    self.nprocs_per_node_device = torch.tensor(
        [self.nprocs_per_node], device=comm_device,
        dtype=first_param_dtype)
    # tracks whether module params currently hold the biased
    # push-sum numerator
    self.is_ps_numerator = False

    # prepare parameters for gossip
    self.gossip_enable = True
    self.gossiping = False
    self.params_mixed = True
    self.gossip_ps_factor = torch.zeros(1, device=comm_device).type(
        first_param_dtype)
    self.gossip_ps_weight = self.ps_weight.clone()
    # per-parameter communication buffers placed on the comm device
    self.gossip_params = []
    self.gossip_device_buffer = []
    for p in module.parameters():
        cp = p.clone().detach_()
        cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
        self.gossip_params.append(cp)
        self.gossip_device_buffer.append(cp)

    # prepare gossip process control objects
    self.gossip_lock = threading.Lock()
    self.gossip_flag = threading.Event()
    self.train_flag = threading.Event()

    if self.dist_config['comm_device'].type != 'cpu' and use_streams:
        self.gossip_stream = torch.cuda.Stream()
    else:
        self.gossip_stream = torch.cuda.current_stream()

    # only the first process on each node runs the gossip thread
    if self.process_rank % self.nprocs_per_node == 0:
        self.gossip_thread = threading.Thread(
            target=GossipDataParallel._gossip_target,
            args=(self.dist_config,
                  self.gossip_flag,
                  self.train_flag,
                  self.gossip_lock,
                  self.gossip_params,
                  self.gossip_device_buffer,
                  self.gossip_ps_weight,
                  self.gossip_ps_factor,
                  self.gossip_stream))
        self.gossip_thread.daemon = True
        self.gossip_thread.name = 'Gossip-Thread'
        self.gossip_thread.start()
    else:
        # this process spawns no gossip thread; pre-set the flag so the
        # wait below returns immediately
        self.gossip_flag.set()

    # wait for thread to complete initialization
    self.gossip_flag.wait()
    self.gossip_flag.clear()

    # lazy mixing avoids additional bias/de-bias steps
    self.lazy_mixing = (
        not self.asynch and self.dist_config['mixing'].is_regular() and
        not self.overlap)
    self.lazy_ps_factor = self.gossip_ps_factor.clone()
    self.logger.debug('lazy mixing: {}'.format(self.lazy_mixing))

    # register ps/grad-reduction hooks
    self.__register_hooks()
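
For context, a minimal usage sketch of this constructor. It assumes the class is importable as gossip.distributed.GossipDataParallel and that forward() delegates to the wrapped module (as in PyTorch's DataParallel-style wrappers); the dummy model, data, rendezvous address, and training loop are illustrative, not the repository's training script, and any gossip-specific calls the real script makes around the loop are omitted.

import torch
import torch.distributed as dist
from gossip.distributed import GossipDataParallel  # assumed import path

def main(rank, world_size):
    # one process per GPU; with the NCCL backend the constructor above
    # keeps comm_device on the GPU (gloo would switch it to CPU)
    dist.init_process_group('nccl', init_method='tcp://127.0.0.1:29500',
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    net = torch.nn.Linear(128, 10).cuda()
    model = GossipDataParallel(
        net,
        device_ids=[rank],   # one local device per process
        graph=None,          # defaults to NPDDEGraph
        mixing=None,         # defaults to UniformMixing
        push_sum=True,       # push-sum (SGP-style) averaging
        overlap=False,
        synch_freq=0,        # gossip synchronously every step
        verbose=False)

    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(32, 128).cuda()
    y = torch.randint(0, 10, (32,)).cuda()
    for _ in range(10):
        loss = torch.nn.functional.cross_entropy(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()

if __name__ == '__main__':
    # launch one process per available GPU
    n = torch.cuda.device_count()
    torch.multiprocessing.spawn(main, args=(n,), nprocs=n)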