def _request_some()

in graphlearn_torch/python/channel/remote_channel.py
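
Keeps each server's pipeline of asynchronous fetch_one_sampled_message
requests topped up to prefetch_size. Each completed future pushes its
sampled message onto the channel's internal queue, tagged with the index
of the local server it came from.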


  def _request_some(self):

    def on_done(f: torch.futures.Future, local_server_idx):
      # Completion callback: unpack the sampled message and push it onto
      # the channel's queue, tagged with the server it came from.
      try:
        msg, end_of_epoch = f.wait()
        self.queue.put((msg, end_of_epoch, local_server_idx))
      except Exception as e:
        logging.error("Broken future while receiving remote messages: %s", e)

    def create_callback(local_server_idx):
      # Bind `local_server_idx` at call time. A closure defined directly
      # in the loop below would be late-bound, so every callback would
      # see the loop variable's final value.
      def callback(f):
        on_done(f, local_server_idx)
      return callback

    # Imported here rather than at module level, likely to avoid a
    # circular import between the channel and distributed modules.
    from ..distributed import async_request_server, DistServer

    for local_server_idx, server_rank in enumerate(self.server_rank_list):
      # Skip servers that have already reported the end of the epoch.
      if self.server_end_of_epoch[local_server_idx]:
        continue
      # Top up the pipeline: issuing (received + prefetch_size - requested)
      # new requests brings the number of in-flight requests
      # (requested - received) back up to prefetch_size.
      for _ in range(
        self.num_received_list[local_server_idx] +
        self.prefetch_size -
        self.num_request_list[local_server_idx]
      ):
        fut = async_request_server(
          server_rank, DistServer.fetch_one_sampled_message,
          self.producer_id_list[local_server_idx]
        )
        fut.add_done_callback(create_callback(local_server_idx))
        self.num_request_list[local_server_idx] += 1
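
The factory is needed because Python closures capture variables, not
values: a callback defined inline in the loop would read local_server_idx
only when the future completes, after the loop has already advanced. A
minimal self-contained sketch (not from the source; it resolves plain
torch.futures.Future objects by hand) demonstrates the pattern:

  import torch

  futs = [torch.futures.Future() for _ in range(3)]
  results = []

  def create_callback(idx):
    # `idx` is a parameter, so each callback keeps its own copy.
    def callback(f):
      results.append((f.wait(), idx))
    return callback

  for i, fut in enumerate(futs):
    fut.add_done_callback(create_callback(i))

  for i, fut in enumerate(futs):
    fut.set_result(f"msg-{i}")

  print(results)  # [('msg-0', 0), ('msg-1', 1), ('msg-2', 2)]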