in gala/gpu_gossip_buffer.py
def aggregate_message(self, rank, model):
    """
    Average messages with local model:
    Average all in-neighbours' (defined in 'self.topology') parameters with
    agent 'rank's 'model' and copy the result into 'model'.
    Agents should only aggregate messages into their own model:
    i.e., ensure 'model' belongs to agent 'rank'
    """
    with torch.no_grad():
        # Check if in-peers finished writing messages to broadcast buffers
        _, in_peers = self.topology[rank].get_peers()
        write_complete = True
        for peer in in_peers:
            peer_buffer = self.msg_buffer[peer]
            write_event = peer_buffer[2][rank]
            if not write_event.is_set():
                write_complete = False
                break

        # Check if any messages are excessively stale
        stale_assert = self.sync_list[rank] >= self.sync_freq

        # If peers done writing or message too stale, wait and clear events
        if write_complete or stale_assert:
            for peer in in_peers:
                peer_buffer = self.msg_buffer[peer]
                write_event = peer_buffer[2][rank]
                write_event.wait()
                write_event.clear()
            self.sync_list[rank] = 0
        # Not done writing, but staleness is still tolerable
        else:
            self.sync_list[rank] += 1
            print('%s: staleness %s' % (rank, self.sync_list[rank]))
            return

        # Lazy-mixing of local params
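        # (each incoming message is assumed to be pre-scaled by its sender at
        # write time, so dividing the local parameters by num_peers + 1 here
        # and summing the messages below amounts to a uniform average)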
num_peers = self.topology[rank].peers_per_itr
for p in model.parameters():
p.data.div_(num_peers + 1)
# Aggregate received messages
for peer in in_peers:
peer_buffer = self.msg_buffer[peer]
lock = peer_buffer[3]
with lock:
# Read message and update 'params'
peer_msg = peer_buffer[0]
for p, bp in zip(model.parameters(),
peer_msg.parameters()):
p.data.add_(bp.to(p.device, non_blocking=True))
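                # The copies above use non_blocking=True and may still be in
                # flight; wait on the current CUDA stream before releasing the
                # lock and flagging this message as read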
                torch.cuda.current_stream().synchronize()
                # Mark message as 'read'
                peer_buffer[1][rank].set()
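For intuition, below is a minimal, self-contained sketch of the mixing arithmetic performed above, using plain tensors in place of models. It assumes each sender scales its parameters by 1 / (num_peers + 1) before writing to the broadcast buffer (that scaling is not shown in this excerpt), so the lazy mixing reduces to a uniform average over the local model and its in-neighbours.

import torch

num_peers = 2                                    # hypothetical peers_per_itr
local = torch.tensor([1.0, 2.0, 3.0])            # agent 'rank's parameters
peer_params = [torch.tensor([4.0, 5.0, 6.0]),    # in-neighbours' parameters
               torch.tensor([7.0, 8.0, 9.0])]

# Lazy mixing: scale the local copy once, then accumulate pre-scaled messages.
mixed = local / (num_peers + 1)
for msg in (p / (num_peers + 1) for p in peer_params):  # senders pre-scale before writing
    mixed += msg

# Equivalent to a plain uniform average over the num_peers + 1 models.
assert torch.allclose(mixed, (local + sum(peer_params)) / (num_peers + 1))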