def _gossip_target()

in gossip/ad_psgd.py [0:0]


    def _gossip_target(dist_config, gossip_enable_flag, train_write_flag,
                       gossip_read_flag, gossip_update_flag, gossip_lr,
                       gossip_lock, gossip_queue, tcp_interface_name):
        """ Gossip thread, which performs push-sum on model params """
        with torch.no_grad():
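            # the train process hands over the model params and a matching
            # gradient buffer through the queue at startup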
            gossip_params, gossip_grads = gossip_queue.get()

            # prepare gossip process control objects
            gossip_optimizer = torch.optim.SGD(
                gossip_params,
                lr=dist_config['lr'],
                momentum=dist_config['momentum'],
                weight_decay=dist_config['weight_decay'],
                nesterov=dist_config['nesterov'])

            # bind the backend's sockets to the requested TCP interface
            if dist_config['backend'] == 'gloo':
                assert dist_config['network_interface_type'] == 'ethernet'
            if dist_config['network_interface_type'] == 'ethernet':
                if dist_config['backend'] == 'nccl':
                    os.environ['NCCL_SOCKET_IFNAME'] = tcp_interface_name
                    os.environ['NCCL_IB_DISABLE'] = '1'
                elif dist_config['backend'] == 'gloo':
                    os.environ['GLOO_SOCKET_IFNAME'] = tcp_interface_name
                else:
                    raise NotImplementedError

            # initialize torch distributed backend
            os.environ['MASTER_ADDR'] = dist_config['master_addr']
            os.environ['MASTER_PORT'] = dist_config['master_port']
            dist.init_process_group(backend=dist_config['backend'],
                                    world_size=dist_config['world_size'],
                                    rank=dist_config['rank'])

            logger = make_logger(dist.get_rank(), dist_config['verbose'])
            logger.debug('init rcvd: gossip_params {}, gossip_grads {}'.format(
                gossip_params[0].norm(), gossip_grads[0].norm()))

            # init gossip instance
            graph_class = dist_config['graph_class']
            mixing_class = dist_config['mixing_class']
            graph, mixing = None, None

            if graph_class:
                # dist.barrier is done here to ensure the NCCL communicator
                # is created at this point. This prevents an error which may
                # occur if the NCCL communicator is created at a time gap of
                # more than 5 minutes in different processes
                dist.barrier()
                graph = graph_class(
                    dist_config['rank'], dist_config['world_size'],
                    peers_per_itr=dist_config['num_peers'])
            if mixing_class and graph:
                mixing = mixing_class(graph, dist_config['comm_device'])

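            # gossiper that exchanges the flattened model with peers chosen
            # by the graph manager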
            gossiper = BilatPushPull(flatten_tensors(gossip_params),
                                     graph=graph,
                                     mixing=mixing,
                                     logger=logger)

            dist_config['graph'] = gossiper._graph_manager
            dist_config['mixing'] = gossiper._mixing_manager
            dist_config['gossiper'] = gossiper
            model_meter = Meter(ptag='Model', stateful=True, csv_format=False)
            gossip_meter = Meter(
                ptag='Gossip', stateful=True, csv_format=False)
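            # signal the train process that initialization is complete and
            # the gradient buffer may be written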
            gossip_read_flag.set()

            # gossip loop
            while True:
                # we may be asked to hold off on gossip for some time
                gossip_enable_flag.wait()

                # we may be notified to update our learning rate
                if gossip_update_flag.is_set():
                    for pg in gossip_optimizer.param_groups:
                        pg['lr'] = gossip_lr.value
                    logger.debug('updated lr to {}'.format(gossip_lr.value))
                    gossip_update_flag.clear()

                # train process is telling us it computed the new grads
                if train_write_flag.is_set():
                    bt = time.time()
                    with gossip_lock:
                        i = 0
                        for p in gossip_params:
                            if p.requires_grad:
                                p.grad = gossip_grads[i]
                                i += 1
                        gossip_optimizer.step()
                        gossip_optimizer.zero_grad()
                    train_write_flag.clear()
                    gossip_read_flag.set()
                    model_meter.update(time.time() - bt)
                    logger.debug(model_meter)

                try:
                    # construct gossip tensor
                    bt = time.time()
                    with gossip_lock:
                        out_msg = flatten_tensors(gossip_params).to(
                            dist_config['comm_device'])
                    # gossip step
                    in_msg, completed = gossiper.mix(out_msg)
                    # update gossip params (local model): average own params
                    # with the received peer params, p <- (p + g) / 2
                    if completed:
                        with gossip_lock:
                            for p, g in zip(
                                    gossip_params, unflatten_tensors(
                                        in_msg, gossip_params)):
                                p.data.add_(g.to(p.device)).mul_(0.5)
                    gossip_meter.update(time.time() - bt)
                    logger.debug(gossip_meter)
                except RuntimeError as e:
                    logger.warning('received runtime error {}'.format(e))
                    gossiper.clean_msg_buffers_()
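
For context, the sketch below shows one plausible way the train process could
wire this thread up and drive it. Every name in it (spawn_gossip_thread,
publish_grads) is hypothetical, as is the choice of threading primitives; only
the handshake it exercises (the queue handoff, the train_write_flag /
gossip_read_flag pair, and gossip_lock) is taken from the code above.

    import queue
    import threading
    import multiprocessing as mp

    import torch

    from gossip.ad_psgd import _gossip_target  # path per this repo's layout

    def spawn_gossip_thread(model, dist_config, tcp_interface_name):
        """ Hypothetical launcher matching _gossip_target's signature """
        gossip_enable_flag = threading.Event()
        train_write_flag = threading.Event()
        gossip_read_flag = threading.Event()
        gossip_update_flag = threading.Event()
        gossip_lr = mp.Value('f', dist_config['lr'])  # exposes .value
        gossip_lock = threading.Lock()
        gossip_queue = queue.Queue()

        thread = threading.Thread(
            target=_gossip_target,
            args=(dist_config, gossip_enable_flag, train_write_flag,
                  gossip_read_flag, gossip_update_flag, gossip_lr,
                  gossip_lock, gossip_queue, tcp_interface_name),
            daemon=True)
        thread.start()

        # hand over the live parameter tensors plus a grad buffer
        gossip_params = list(model.parameters())
        gossip_grads = [torch.zeros_like(p) for p in gossip_params
                        if p.requires_grad]
        gossip_queue.put((gossip_params, gossip_grads))
        gossip_enable_flag.set()
        gossip_read_flag.wait()  # thread finished initializing
        return gossip_grads, gossip_lock, train_write_flag, gossip_read_flag

    def publish_grads(model, gossip_grads, gossip_lock,
                      train_write_flag, gossip_read_flag):
        """ Called once per train step, after backward() """
        gossip_read_flag.wait()   # previous grads were consumed
        gossip_read_flag.clear()
        with gossip_lock:
            i = 0
            for p in model.parameters():
                if p.requires_grad:
                    gossip_grads[i].copy_(p.grad)
                    i += 1
        train_write_flag.set()    # gossip thread applies them via its SGD step

The real train script may instead run the gossip loop in a separate process
with multiprocessing primitives; the Event/Lock/Queue choice above is just one
consistent reading of the signature.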