graphlearn_torch/python/channel/remote_channel.py (82 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import logging import queue import torch from .base import SampleMessage, ChannelBase from typing import Union, List class RemoteReceivingChannel(ChannelBase): r""" A pull-based receiving channel that can fetch sampled messages from remote sampling servers. Args: server_rank (int or List[int]): The ranks of target server to fetch sampled messages. producer_id (int or List[int]) ): The sequence ids of created sampling producer on the target server. prefetch_size (int): The number of messages to prefetch for every server. (Default ``2``). """ def __init__( self, server_rank: Union[int, List[int]], producer_id: Union[int, List[int]], prefetch_size: int = 2 ): self.server_rank_list = server_rank if isinstance(server_rank, List) else [server_rank] self.producer_id_list = producer_id if isinstance(producer_id, List) else [producer_id] self.prefetch_size = prefetch_size assert len(self.server_rank_list) == len(self.producer_id_list) self.num_request_list = [0] * len(self.server_rank_list) self.num_received_list = [0] * len(self.server_rank_list) self.server_end_of_epoch = [False] * len(self.server_rank_list) self.global_end_of_epoch = False self.queue = queue.Queue(maxsize=self.prefetch_size * len(self.server_rank_list)) def reset(self): r""" Reset all states to start a new epoch consuming. """ # Discard messages that have not been consumed. while not self.queue.empty(): _ = self.queue.get() self.server_end_of_epoch = [False] * len(self.server_rank_list) self.num_request_list = [0] * len(self.server_rank_list) self.num_received_list = [0] * len(self.server_rank_list) self.global_end_of_epoch = False def send(self, msg: SampleMessage, **kwargs): raise RuntimeError( f"'{self.__class__.__name__}': cannot send " f"message with a receiving channel." ) def recv(self, **kwargs) -> SampleMessage: if self.global_end_of_epoch: if self._all_received(): raise StopIteration else: self._request_some() msg, end_of_epoch, local_server_idx = self.queue.get() self.num_received_list[local_server_idx] += 1 # server guarantees that when end_of_epoch is true, msg must be None while end_of_epoch: self.server_end_of_epoch[local_server_idx] = True if sum(self.server_end_of_epoch) == len(self.server_rank_list): self.global_end_of_epoch = True if self._all_received(): raise StopIteration msg, end_of_epoch, local_server_idx = self.queue.get() self.num_received_list[local_server_idx] += 1 return msg def _all_received(self): return sum(self.num_received_list) == sum(self.num_request_list) def _request_some(self): def on_done(f: torch.futures.Future, local_server_idx): try: msg, end_of_epoch = f.wait() self.queue.put((msg, end_of_epoch, local_server_idx)) except Exception as e: logging.error("broken future of receiving remote messages: %s", e) def create_callback(local_server_idx): def callback(f): on_done(f, local_server_idx) return callback from ..distributed import async_request_server, DistServer for local_server_idx, server_rank in enumerate(self.server_rank_list): if not self.server_end_of_epoch[local_server_idx]: for _ in range( self.num_received_list[local_server_idx] + self.prefetch_size - self.num_request_list[local_server_idx] ): fut = async_request_server( server_rank, DistServer.fetch_one_sampled_message, self.producer_id_list[local_server_idx] ) cb = create_callback(local_server_idx) fut.add_done_callback(cb) self.num_request_list[local_server_idx] += 1