in graphlearn_torch/python/channel/remote_channel.py [0:0]
def recv(self, **kwargs) -> SampleMessage:
    """Return the next non-end-of-epoch message taken from ``self.queue``.

    Entries are dequeued as ``(msg, end_of_epoch, local_server_idx)`` tuples.
    Every dequeued entry increments that server's slot in
    ``num_received_list``. An entry flagged ``end_of_epoch`` is never
    returned. Instead it marks the server as finished in
    ``server_end_of_epoch``, and the method keeps waiting on the queue.

    Args:
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The ``msg`` of the first dequeued entry whose ``end_of_epoch`` flag
        is false.

    Raises:
        StopIteration: If ``global_end_of_epoch`` is (or becomes) set and
            ``_all_received()`` returns true.
    """
    if not self.global_end_of_epoch:
        # Not every server has finished yet; call _request_some() before
        # blocking (presumably tops up outstanding fetches — confirm there).
        self._request_some()
    elif self._all_received():
        raise StopIteration

    while True:
        msg, end_of_epoch, server_idx = self.queue.get()
        self.num_received_list[server_idx] += 1
        if not end_of_epoch:
            return msg
        # Server guarantees msg is None whenever end_of_epoch is true, so an
        # end-of-epoch entry only updates bookkeeping before we wait again.
        self.server_end_of_epoch[server_idx] = True
        if sum(self.server_end_of_epoch) == len(self.server_rank_list):
            self.global_end_of_epoch = True
            if self._all_received():
                raise StopIteration