chatlearn/utils/error_monitor.py (45 lines of code) (raw):
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Error monitor"""
import logging
import time

import ray
import ray.util.collective as col

from chatlearn.utils import future
@ray.remote
class ErrorMonitor:
    """Error Monitor.

    Ray actor that polls an error-signal actor until an error is flagged.
    Once an error is flagged, it tears down the registered collective groups
    and models, then raises an exception describing the failure.
    """

    def __init__(self, error_signal, remote_models, group_names):
        """
        Args:
            error_signal: actor handle exposing remote ``is_set``,
                ``error_msg`` and ``error_address`` methods (presumably an
                ``ErrorSignalActor`` handle -- confirm against callers).
            remote_models: objects whose ``terminate()`` is called on error.
            group_names: names of ``ray.util.collective`` groups to destroy
                on error.
        """
        self.error_signal = error_signal
        self.remote_models = remote_models
        self.collective_groups = group_names

    def monitor(self, interval=2):
        """Block until an error is signaled, clean up, then raise.

        Args:
            interval: seconds to sleep between polls of the error signal.

        Raises:
            Exception: always, once the error signal is set. The message
                contains the reported error addresses and error message.
        """
        # Check first, then sleep, so an already-set signal is acted on at once.
        while not self._error_is_set():
            time.sleep(interval)
        self._cleanup()
        error_msg = future.get(self.error_signal.error_msg.remote())
        error_address = future.get(self.error_signal.error_address.remote())
        raise Exception(f"Catch an exception in {error_address}, error msg: {error_msg}")

    def _error_is_set(self):
        """Poll the error signal once; a failed poll counts as 'not set yet'."""
        try:
            return bool(future.get(self.error_signal.is_set.remote()))
        except Exception:
            # Deliberately best-effort: if the signal actor cannot be reached,
            # keep polling instead of crashing the monitor.
            return False

    def _cleanup(self):
        """Destroy collective groups and terminate models, continuing past failures.

        Each step is isolated so that one failing group or model neither skips
        the remaining cleanup nor masks the error that ``monitor`` reports.
        """
        logger = logging.getLogger(__name__)
        for group_name in self.collective_groups:
            try:
                col.destroy_collective_group(group_name)
            except Exception:
                logger.exception("Failed to destroy collective group %s", group_name)
        for model in self.remote_models:
            try:
                model.terminate()
            except Exception:
                logger.exception("Failed to terminate model %s", model)
@ray.remote(num_cpus=0)
class ErrorSignalActor:
    """ErrorSignalActor.

    Ray actor that holds an error flag, the most recently reported non-None
    error message, and a de-duplicated list of reported addresses.
    """

    def __init__(self):
        self.error_state = False  # set to True by set(); nothing here resets it
        self.err_msg = None  # last non-None message passed to set()
        self._address_list = []  # unique addresses, in first-seen order

    def set(self, err_msg=None):
        """Raise the error flag and record ``err_msg`` unless it is None."""
        self.error_state = True
        # A None message leaves any previously recorded message in place.
        self.err_msg = self.err_msg if err_msg is None else err_msg

    def set_address(self, address):
        """Record ``address`` once, preserving insertion order."""
        if address in self._address_list:
            return
        self._address_list.append(address)

    def is_set(self):
        """Return whether an error has been signaled."""
        return self.error_state

    def error_msg(self):
        """Return the stored error message (None if no message was given)."""
        return self.err_msg

    def error_address(self):
        """Return the list of recorded addresses."""
        return self._address_list