in gossip/gossiper.py [0:0]
def __init__(self, msg, graph, device=None, mixing=None, logger=None,
             rank=None, world_size=None):
    """
    Initialize generic averaging class designed for multi-peer comms

    :param msg: (tensor) message used to initialize recv buffer; its dtype
        is used for every internal buffer, and its device is used when
        ``device`` is not given
    :param graph: (GraphManager) Subclass of GraphManager
    :param device: (torch.device or str) device on which to initialize the
        recv buffer and the other tensors owned by this object; defaults to
        ``msg.device``
    :param mixing: (MixingManager) Subclass of MixingManager; defaults to
        ``UniformMixing`` over ``graph``
    :param logger: (python logger) module used to log results
    :param rank: (int) rank of this process; if either ``rank`` or
        ``world_size`` is None, BOTH are queried from torch.distributed
    :param world_size: (int) number of processes; see ``rank``
    """
    self.logger = logger
    if rank is None or world_size is None:
        assert dist.is_initialized(), (
            'torch.distributed must be initialized when rank or '
            'world_size is not given')
        # for now p2p communication only supported with tcp and mpi
        assert dist._backend != dist_backend.GLOO, (
            'p2p gossip is not supported with the GLOO backend')
        assert dist._backend != dist_backend.NCCL, (
            'p2p gossip is not supported with the NCCL backend')
        rank = dist.get_rank()
        world_size = dist.get_world_size()

    # Resolve the working device once, before any tensor is created, so the
    # peer-count tensor, the default mixing weights and the msg buffers all
    # live on the same device (previously a None ``device`` put the first two
    # on the default device while the buffers followed ``msg.device``).
    # Normalizing through torch.device also accepts strings like 'cuda:0'.
    if device is None:
        self.device = msg.device
    elif isinstance(device, torch.device):
        self.device = device
    else:
        self.device = torch.device(device)

    # graph topology properties
    self.rank = rank
    self.world_size = world_size
    assert isinstance(graph, GraphManager), (
        'graph must be a GraphManager instance')
    self._graph_manager = graph
    self.peers_per_itr_device = torch.tensor(
        [self._graph_manager.peers_per_itr], device=self.device,
        dtype=msg.dtype)
    self.passive = self._graph_manager.is_passive()
    self.refresh_peers_(rotate=False)  # sets in- and out-peers attributes

    # mixing matrix
    if mixing is None:
        mixing = UniformMixing(self._graph_manager, self.device)
    assert isinstance(mixing, MixingManager), (
        'mixing must be a MixingManager instance')
    self._mixing_manager = mixing
    self.refresh_mixing_weights_()  # sets mixing-weights attribute

    # regular ==> we don't need to keep track of ps-weight explicitly
    self.regular = self._mixing_manager.is_regular()

    # msg buffers used during send/recv
    self.out_msg_buffer = []
    # detach before cloning so the copy is not recorded in msg's autograd
    # graph; the result is a grad-free copy of msg on self.device
    self.in_msg_buffer = msg.detach().clone().to(self.device)
    self._ps_weight = torch.ones(1, device=self.device, dtype=msg.dtype)
    # not using regular comms ==> need to communicate ps-weight, so the recv
    # buffer gets one extra trailing element for it (torch.cat with a 1-D
    # tensor requires msg to be 1-D here).
    # NOTE(review): ``ps_weight`` is not assigned in this method; presumably
    # a property exposing ``_ps_weight`` — confirm in the class body.
    if not self.regular:
        self.in_msg_buffer = torch.cat([self.in_msg_buffer, self.ps_weight])
    if self.device.type == 'cpu':
        # Best-effort: pinned host memory can speed up host<->GPU transfers,
        # but pinning can fail (e.g. without CUDA); the buffer then simply
        # stays unpinned and the error is logged if a logger was given.
        try:
            self.in_msg_buffer = self.in_msg_buffer.pin_memory()
        except Exception as e:
            if self.logger is not None:
                self.logger.error(e)
    # copy with the recv buffer's shape/dtype/device (and pinning, if it
    # succeeded); how it is used is outside this method
    self.placeholder = self.in_msg_buffer.clone()

    self._pending_req = None