in gossip/ad_psgd.py [0:0]
def __init__(self, module, device_ids=None, master_addr=None,
             master_port=None, backend=None, world_size=None, rank=None,
             graph_class=None, mixing_class=None, num_peers=1,
             comm_device=None, lr=0.1, momentum=0.9, weight_decay=1e-4,
             nesterov=True, verbose=True, network_interface_type=None,
             tcp_interface_name=None):
    super(BilatGossipDataParallel, self).__init__()
    # devices available locally
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
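    # the first listed device hosts the primary module replica and serves
    # as the output device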
    self.output_device = device_ids[0]
    self.device_ids = device_ids
    # put model on output device
    self.module = module.cuda(self.output_device)
    # prepare local intra-node all-reduce objects
    if len(self.device_ids) > 1:
        self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
        self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes
        self._module_copies = replicate(self.module, self.device_ids,
                                        detach=True)
        self._module_copies[0] = self.module
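        # mirror requires_grad flags so parameters frozen on the main
        # module stay frozen on every replica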
        for cmodule in self._module_copies[1:]:
            for p, cp in zip(self.module.parameters(),
                             cmodule.parameters()):
                cp.requires_grad = p.requires_grad
    else:
        self._module_copies = [self.module]
    # communicate over CPU if no communication device is specified
    if comm_device is None:
        comm_device = torch.device('cpu')
    self.__cpu_comm = comm_device.type == 'cpu'
    # distributed backend config
    self.dist_config = {
        'verbose': verbose,
        'graph_class': graph_class,
        'master_addr': master_addr,
        'master_port': master_port,
        'backend': backend,
        'world_size': world_size,
        'rank': rank,
        'mixing_class': mixing_class,
        'lr': lr,
        'momentum': momentum,
        'nesterov': nesterov,
        'weight_decay': weight_decay,
        'comm_device': comm_device,
        'network_interface_type': network_interface_type,
        'num_peers': num_peers
    }
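    # this config dict is handed to the background gossip process when it
    # is spawned below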
    self.num_updates = 0
    # logger used to print to stdout
    self.logger = make_logger(rank, verbose)
    # prepare parameters for gossip
    self.gossip_enable = True
    self.gossip_params = []
    self.gossip_grads = []
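    # each parameter gets a detached communication buffer; CPU buffers are
    # pinned so host<->device copies can run asynchronously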
    for p in module.parameters():
        cp = p.clone().detach_()
        cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
        cp.requires_grad = p.requires_grad
        self.gossip_params.append(cp)
        if p.requires_grad:
            g = cp.clone().zero_().detach_()
            g = g.cpu().pin_memory() if self.__cpu_comm else g.cuda()
            self.gossip_grads.append(g)
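    # inter-process synchronization primitives shared with the gossip process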
    self.gossip_queue = mp.Queue()
    self.gossip_lock = mp.Lock()
    self.gossip_enable_flag = mp.Event()
    self.train_write_flag = mp.Event()  # signal train-proc write event
    self.gossip_read_flag = mp.Event()  # signal gossip-proc read event
    self.gossip_update_flag = mp.Event()  # signal gossip-proc that params need updating
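    # the learning rate is stored in shared memory, guarded by gossip_lock,
    # so both processes see a consistent value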
    self._lr = mp.Value('f', lr, lock=self.gossip_lock)
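    # spawn the background gossip worker (named 'Gossip-Thread' but actually
    # a separate daemon process, not a thread)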
    self.gossip_thread = mp.Process(
        target=BilatGossipDataParallel._gossip_target,
        args=(self.dist_config,
              self.gossip_enable_flag,
              self.train_write_flag,
              self.gossip_read_flag,
              self.gossip_update_flag,
              self._lr,
              self.gossip_lock,
              self.gossip_queue,
              tcp_interface_name))
    self.gossip_thread.daemon = True
    self.gossip_thread.name = 'Gossip-Thread'
    self.gossip_thread.start()
    # pass handles to gossip_params and gossip_grads through the queue so
    # they are placed in shared memory
    self.gossip_queue.put((self.gossip_params, self.gossip_grads))
    # register param-sync/grad-reduction hooks
    self.__register_hooks()
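    # A minimal construction sketch (hypothetical values; assumes the caller
    # supplies the graph and mixing classes and a CUDA-capable machine):
    #
    #   model = BilatGossipDataParallel(
    #       net, master_addr='127.0.0.1', master_port=29500,
    #       backend='gloo', world_size=2, rank=0,
    #       graph_class=graph_class, mixing_class=mixing_class,
    #       comm_device=torch.device('cpu'))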