def __init__()

in gossip/ad_psgd.py

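The constructor relies on the file's module-level imports. A plausible reconstruction
(hedged: inferred from the calls used below, not copied from the source file) is:

    import torch
    import torch.multiprocessing as mp
    from torch.nn.parallel import replicate  # replicates a module across local GPUs

plus the repository's local make_logger helper.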

    def __init__(self, module, device_ids=None, master_addr=None,
                 master_port=None, backend=None, world_size=None, rank=None,
                 graph_class=None, mixing_class=None, num_peers=1,
                 comm_device=None, lr=0.1, momentum=0.9, weight_decay=1e-4,
                 nesterov=True, verbose=True, network_interface_type=None,
                 tcp_interface_name=None):
        super(BilatGossipDataParallel, self).__init__()

        # devices available locally
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        self.output_device = device_ids[0]
        self.device_ids = device_ids

        # put model on output device
        self.module = module.cuda(self.output_device)

        # prepare local intra-node all-reduce objects
        if len(self.device_ids) > 1:
            self.broadcast_bucket_size = 10 * 1024 * 1024  # bytes
            self.nccl_reduce_bucket_size = 256 * 1024 * 1024  # bytes

            self._module_copies = replicate(self.module, self.device_ids,
                                            detach=True)
            self._module_copies[0] = self.module
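            # mirror requires_grad flags so frozen parameters stay frozen in each replica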
            for cmodule in self._module_copies[1:]:
                for p, cp in zip(self.module.parameters(),
                                 cmodule.parameters()):
                    cp.requires_grad = p.requires_grad
        else:
            self._module_copies = [self.module]

        # default to communicating over CPU if no comm device is specified
        if comm_device is None:
            comm_device = torch.device('cpu')
        self.__cpu_comm = comm_device.type == 'cpu'
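        # __cpu_comm decides below whether gossip buffers live in pinned CPU memory or on GPU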

        # distributed backend config
        self.dist_config = {
            'verbose': verbose,
            'graph_class': graph_class,
            'master_addr': master_addr,
            'master_port': master_port,
            'backend': backend,
            'world_size': world_size,
            'rank': rank,
            'mixing_class': mixing_class,
            'lr': lr,
            'momentum': momentum,
            'nesterov': nesterov,
            'weight_decay': weight_decay,
            'comm_device': comm_device,
            'network_interface_type': network_interface_type,
            'num_peers': num_peers
        }
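        # the full config is forwarded to the background gossip process at spawn time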
        self.num_updates = 0

        # logger used to print to stdout
        self.logger = make_logger(rank, verbose)

        # prepare parameters for gossip
        self.gossip_enable = True
        self.gossip_params = []
        self.gossip_grads = []
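        # build detached copies of the parameters (and zeroed gradient buffers)
        # to serve as communication buffers for the gossip process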
        for p in module.parameters():
            cp = p.clone().detach_()
            cp = cp.cpu().pin_memory() if self.__cpu_comm else cp.cuda()
            cp.requires_grad = p.requires_grad
            self.gossip_params.append(cp)
            if p.requires_grad:
                g = cp.clone().zero_().detach_()
                g = g.cpu().pin_memory() if self.__cpu_comm else g.cuda()
                self.gossip_grads.append(g)

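        # IPC primitives shared between the training process and the gossip process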
        self.gossip_queue = mp.Queue()
        self.gossip_lock = mp.Lock()
        self.gossip_enable_flag = mp.Event()
        self.train_write_flag = mp.Event()  # signal train-proc write event
        self.gossip_read_flag = mp.Event()  # signal gossip-proc read event
        self.gossip_update_flag = mp.Event()  # signal gossip-proc that an update is needed
        self._lr = mp.Value('f', lr, lock=self.gossip_lock)
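        # launch the gossip loop as a separate process (an mp.Process, despite the 'Thread' naming)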
        self.gossip_thread = mp.Process(
            target=BilatGossipDataParallel._gossip_target,
            args=(self.dist_config,
                  self.gossip_enable_flag,
                  self.train_write_flag,
                  self.gossip_read_flag,
                  self.gossip_update_flag,
                  self._lr,
                  self.gossip_lock,
                  self.gossip_queue,
                  tcp_interface_name))
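        # daemonize so the gossip process is terminated when the training process exits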
        self.gossip_thread.daemon = True
        self.gossip_thread.name = 'Gossip-Thread'
        self.gossip_thread.start()

        # pass handle to gossip_params and gossip_grads, and put in shared
        # memory
        self.gossip_queue.put((self.gossip_params, self.gossip_grads))

        # register parameter-synchronization/grad-reduction hooks
        self.__register_hooks()
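
For context, a minimal usage sketch of this constructor (hedged: the model, address,
port, and hyperparameter values are illustrative placeholders, and default
graph_class/mixing_class handling is assumed to happen inside the class):

    import torch.nn as nn

    model = nn.Linear(784, 10)
    net = BilatGossipDataParallel(
        model,
        master_addr='127.0.0.1',  # placeholder rendezvous address
        master_port=29500,        # placeholder port
        backend='gloo',           # assumed torch.distributed backend
        world_size=2, rank=0,
        lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)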