def __init__()

in gossip/distributed.py


    def __init__(self, module, device_ids=None, rank=None, world_size=None,
                 graph=None, mixing=None, comm_device=None, push_sum=True,
                 overlap=False, synch_freq=0, verbose=False, use_streams=True,
                 nprocs_per_node=1, local_node_group=None):
        super(GossipDataParallel, self).__init__()

        # devices available locally
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        self.output_device = device_ids[0]
        self.device_ids = device_ids

        self.nprocs_per_node = nprocs_per_node

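        # infer rank/world size from the default process group when not given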
        if world_size is None or rank is None:
            assert dist.is_initialized()
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        self.process_rank = rank

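        # with multiple processes per node, convert the global rank/world
        # size to node-level values and set up an intra-node process group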
        if self.nprocs_per_node > 1:
            self.local_rank = self.process_rank % self.nprocs_per_node
            world_size //= nprocs_per_node
            rank //= nprocs_per_node
            if local_node_group is None:
                for node in range(world_size):
                    node_processes_ranks = list(
                        range(node * self.nprocs_per_node,
                              (node + 1) * self.nprocs_per_node))
                    # Process group to communicate between processes on this
                    # machine
                    new_local_group = create_process_group(
                        node_processes_ranks)
                    if self.process_rank in node_processes_ranks:
                        self.local_node_group = new_local_group
            else:
                self.local_node_group = local_node_group
        else:
            self.local_rank = 0

        # put model on output device
        self.module = module
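        # dtype of the first parameter; reused below for the push-sum
        # scalars and gossip buffers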
        first_param_dtype = next(self.module.parameters()).dtype

        # prepare local intra-node all-reduce objects
        if len(self.device_ids) > 1:
            self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
            self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

            self._module_copies = replicate(self.module, self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module
            for cmodule in self._module_copies[1:]:
                for p, cp in zip(self.module.parameters(),
                                 cmodule.parameters()):
                    cp.requires_grad = p.requires_grad
        else:
            self._module_copies = [self.module]

        # choose communication device based on backend
        if comm_device is None:
            cpu_comm = dist.get_backend() == 'gloo'
            comm_device = torch.device('cpu') if cpu_comm else torch.device('cuda')
        self.__cpu_comm = comm_device.type == 'cpu'

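        # default to the NPDDEGraph topology and uniform mixing weights
        # when the caller does not supply them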
        if graph is None:
            graph = NPDDEGraph(
                rank, world_size, self.nprocs_per_node, self.local_rank)

        if mixing is None:
            mixing = UniformMixing(graph, comm_device)

        # distributed backend config
        self.dist_config = {
            'verbose': verbose,
            'comm_device': comm_device,
            'graph': graph,
            'mixing': mixing,
            'push_sum': push_sum,
            'rank': rank,
            'process_rank': self.process_rank,
            'world_size': world_size,
            'cpu_comm': self.__cpu_comm
        }
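        # a positive synch_freq puts gossip in asynchronous mode (see
        # self.asynch below)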
        self.overlap = overlap
        self.synch_freq = synch_freq
        self.num_updates = 0
        self.asynch = synch_freq > 0

        # logger used to print to stdout
        self.logger = make_logger(rank, verbose)

        # push-sum weight=1.0 ==> distributed averaging
        self.ps_weight = torch.ones(1, device=comm_device).type(
            first_param_dtype)
        self.nprocs_per_node_device = torch.tensor(
            [self.nprocs_per_node], device=comm_device,
            dtype=first_param_dtype)
        self.is_ps_numerator = False

        # prepare parameters for gossip
        self.gossip_enable = True
        self.gossiping = False
        self.params_mixed = True
        self.gossip_ps_factor = torch.zeros(1, device=comm_device).type(
            first_param_dtype)
        self.gossip_ps_weight = self.ps_weight.clone()
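        # detached copies of the parameters, pinned on CPU (gloo) or moved
        # to GPU; note gossip_params and gossip_device_buffer are filled
        # with references to the same tensors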
        self.gossip_params = []
        self.gossip_device_buffer = []
        for p in module.parameters():
            cp = p.clone().detach_()
            cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
            self.gossip_params.append(cp)
            self.gossip_device_buffer.append(cp)

        # prepare gossip process control objects
        self.gossip_lock = threading.Lock()
        self.gossip_flag = threading.Event()
        self.train_flag = threading.Event()

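        # run gossip communication on its own CUDA stream when requested
        # and the comm device is a GPU; otherwise share the current stream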
        if self.dist_config['comm_device'].type != 'cpu' and use_streams:
            self.gossip_stream = torch.cuda.Stream()
        else:
            self.gossip_stream = torch.cuda.current_stream()

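        # only the first process on each node spawns the background thread
        # that performs the actual gossip communication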
        if self.process_rank % self.nprocs_per_node == 0:
            self.gossip_thread = threading.Thread(
                target=GossipDataParallel._gossip_target,
                args=(self.dist_config,
                      self.gossip_flag,
                      self.train_flag,
                      self.gossip_lock,
                      self.gossip_params,
                      self.gossip_device_buffer,
                      self.gossip_ps_weight,
                      self.gossip_ps_factor,
                      self.gossip_stream))
            self.gossip_thread.daemon = True
            self.gossip_thread.name = 'Gossip-Thread'
            self.gossip_thread.start()
        else:
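            # non-gossiping ranks signal readiness immediately so the
            # wait() below does not block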
            self.gossip_flag.set()
        # wait for thread to complete initialization
        self.gossip_flag.wait()
        self.gossip_flag.clear()
        # lazy mixing avoids additional bias/de-bias steps
        self.lazy_mixing = (
            not self.asynch and self.dist_config['mixing'].is_regular() and
            not self.overlap)
        self.lazy_ps_factor = self.gossip_ps_factor.clone()
        self.logger.debug('lazy mixing: {}'.format(self.lazy_mixing))

        # register ps/grad-reduction hooks
        self.__register_hooks()
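
Usage sketch (not part of this file): construction assumes torch.distributed
is already initialized; the model and argument values below are illustrative
placeholders, not defaults taken from this module.

    import torch
    import torch.distributed as dist
    from gossip.distributed import GossipDataParallel

    # one process per GPU, launched e.g. via torch.distributed.launch
    dist.init_process_group(backend='nccl', init_method='env://')

    model = torch.nn.Linear(128, 10).cuda()  # placeholder model
    model = GossipDataParallel(
        model,
        graph=None,        # falls back to NPDDEGraph built from rank/world size
        mixing=None,       # falls back to UniformMixing over the graph
        push_sum=True,     # push-sum gossip; weight 1.0 => distributed averaging
        overlap=False,
        synch_freq=0,      # 0 => synchronous gossip at every update
        nprocs_per_node=1)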