in graphlearn_torch/python/channel/remote_channel.py [0:0]
def _request_some(self):
  r""" Issue enough asynchronous fetch requests to keep ``prefetch_size``
  messages in flight for every remote server that has not yet reached the
  end of its epoch.
  """
  def on_done(f: torch.futures.Future, local_server_idx):
    # Completion handler: unpack the sampled message and enqueue it for
    # the consumer, tagged with the local index of the server it came from.
    try:
      msg, end_of_epoch = f.wait()
      self.queue.put((msg, end_of_epoch, local_server_idx))
    except Exception as e:
      logging.error("broken future of receiving remote messages: %s", e)

  def create_callback(local_server_idx):
    # Factory that binds `local_server_idx` by value. Attaching a bare
    # lambda inside the loop below would capture the loop variable by
    # reference, so every callback would see its final value.
    def callback(f):
      on_done(f, local_server_idx)
    return callback

  # Deferred import: resolved at call time rather than at module load.
  from ..distributed import async_request_server, DistServer
  for local_server_idx, server_rank in enumerate(self.server_rank_list):
    if not self.server_end_of_epoch[local_server_idx]:
      # Top the in-flight window back up to `prefetch_size` outstanding
      # requests: (received + prefetch) - requested = requests to add.
      for _ in range(
        self.num_received_list[local_server_idx] +
        self.prefetch_size -
        self.num_request_list[local_server_idx]
      ):
        fut = async_request_server(
          server_rank, DistServer.fetch_one_sampled_message,
          self.producer_id_list[local_server_idx]
        )
        cb = create_callback(local_server_idx)
        fut.add_done_callback(cb)
        self.num_request_list[local_server_idx] += 1
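
For illustration, here is a minimal self-contained sketch (not graphlearn_torch code) of the same prefetch-window pattern against a single fake server: keep `prefetch_size` futures in flight and refill the window from each completion callback. `fake_async_request` and `PrefetchDemo` are hypothetical stand-ins for `async_request_server` and the channel, assumed only for this example.

import queue
import threading
import torch

def fake_async_request(msg_id, last_id):
  # Hypothetical stand-in for an async RPC: completes the future on a
  # timer thread with a (message, end_of_epoch) pair.
  fut = torch.futures.Future()
  threading.Timer(
    0.01, fut.set_result, args=[(f"msg-{msg_id}", msg_id == last_id)]
  ).start()
  return fut

class PrefetchDemo:
  def __init__(self, prefetch_size=2, total=5):
    self.queue = queue.Queue()
    self.prefetch_size = prefetch_size
    self.total = total
    self.num_requested = 0
    self.num_received = 0
    self.end_of_epoch = False
    self.lock = threading.Lock()

  def request_some(self):
    with self.lock:
      if self.end_of_epoch:
        return
      # Same bookkeeping as above: (received + prefetch) - requested
      # requests are needed to keep the window full.
      for _ in range(
          self.num_received + self.prefetch_size - self.num_requested):
        if self.num_requested >= self.total:
          break  # nothing left to ask for in this toy epoch
        fut = fake_async_request(self.num_requested, self.total - 1)
        fut.add_done_callback(self.on_done)
        self.num_requested += 1

  def on_done(self, f):
    msg, end_of_epoch = f.wait()  # future is already complete here
    self.queue.put((msg, end_of_epoch))
    with self.lock:
      self.num_received += 1
      self.end_of_epoch = self.end_of_epoch or end_of_epoch
    self.request_some()  # refill the window outside the lock

if __name__ == "__main__":
  # Usage: prime the window once, then drain the queue until the
  # end-of-epoch marker arrives. Completion order is not guaranteed,
  # which is why the real channel tags each message with its server.
  demo = PrefetchDemo()
  demo.request_some()
  while True:
    msg, end_of_epoch = demo.queue.get()
    print(msg)
    if end_of_epoch:
      break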