# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DLC utils"""
import atexit
from collections import defaultdict
import json
import os
import sys
import time
import concurrent.futures
import threading
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from chatlearn.utils import utils
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.logger import logger
from chatlearn.utils.global_vars import _EXIT_ACTOR_NAME
from chatlearn.utils.log_monitor import LogMonitor, is_proc_alive, LogActor
from chatlearn.utils.utils import execute, get_ray_status
DLC_PORT_KEY = "CUSTOM_PORTS"
JOB_NAME_KEY = "JOB_NAME"
RANK_KEY = "RANK"
MASTER_ROLE = "master"
WORKER_ROLE = "worker"
PORT_SEP = ";"
LOCAL_MASTER_KEY = "LOCAL_MASTER_ADDR"
_warn_once = False
WORKER_SLEEP_SECOND = 2
_LOG_ACTOR_NAME = "CHATLEARN_LOG_ACTOR"
_EXIT_SIGNAL = False
def is_local():
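    """Return True when running in local mock mode, i.e. LOCAL_MASTER_ADDR is set."""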
return LOCAL_MASTER_KEY in os.environ
def in_dlc_env():
    """Check whether the current process is running in a DLC environment."""
    if is_local():
        # DLC is mocked in local clusters
return True
args = get_args()
    if args.env_args.platform.lower() != "dlc":
return False
global _warn_once
for key in [DLC_PORT_KEY, JOB_NAME_KEY, RANK_KEY]:
if key not in os.environ:
if not _warn_once:
logger.warning(f"cannot find {key} in DLC env, please check whether the job is submitted in DLC "
f"or whether customPortList/createSvcForAllWorkers is set")
logger.warning("fallback to local mode")
_warn_once = True
return False
return True
def get_dlc_env(key):
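    """Return the value of a required DLC environment variable."""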
assert key in os.environ, f"cannot find {key} in DLC env"
return os.environ[key]
def get_job_name():
return get_dlc_env(JOB_NAME_KEY)
def get_master_addr():
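    """Return the address of the master (rank 0) node."""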
if is_local():
return os.environ[LOCAL_MASTER_KEY]
job_name = get_job_name()
return f"{job_name}-{MASTER_ROLE}-0"
def get_rank():
return int(get_dlc_env(RANK_KEY))
def get_addr():
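    """Return the addressable hostname of the current node.

    Rank 0 maps to the master pod; rank i (i >= 1) maps to worker pod i - 1.
    """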
if is_local():
return utils.get_host_addr()
rank = get_rank()
job_name = get_job_name()
if rank == 0:
role = MASTER_ROLE
index = 0
else:
role = WORKER_ROLE
index = rank - 1
return f"{job_name}-{role}-{index}"
def get_free_ports():
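    """Return the user-configured DLC ports, excluding ports reserved by Ray."""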
    # ports configured for the DLC job
    assert DLC_PORT_KEY in os.environ, f"cannot find {DLC_PORT_KEY} in DLC env"
free_ports = [int(port) for port in os.environ[DLC_PORT_KEY].strip().split(PORT_SEP)]
    # remove ports that are reserved by ray
# 'client_server': 10001, 'dashboard': 8265, 'dashboard_agent_grpc': 49948, 'dashboard_agent_http': 52365,
# 'metrics_export': 63529, 'redis_shards': 'random', 'worker_ports': '9998 ports from 10002 to 19999'
def _valid_port(port):
if port in [10001, 8265, 49948, 52365, 63529]:
return False
if 10002 <= port <= 19999:
return False
return True
free_ports = [port for port in free_ports if _valid_port(port)]
return free_ports
def start_ray_cluster():
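    """Start the Ray head node on rank 0, or join the existing cluster as a worker node."""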
free_ports = get_free_ports()
port = free_ports[0]
node_manager_port = free_ports[1]
master_addr = get_master_addr()
rank = get_rank()
system_config = json.dumps({"object_timeout_milliseconds": 30000})
if rank == 0:
cmd = f"RAY_prestart_worker_first_driver=0 ray start --head --port={port} --node-ip-address={master_addr} " + \
f"--node-manager-port {node_manager_port} --node-name={master_addr} --system-config='{system_config}' " + \
"--dashboard-host=0.0.0.0 --dashboard-port=8265"
else:
cmd = f"ray start --address={master_addr}:{port} --node-manager-port {node_manager_port} " + \
f"--node-name={get_addr()} --dashboard-host=0.0.0.0 --dashboard-port=8265"
logger.info(f"execute {cmd}")
state, _ = execute(cmd)
if not state:
sys.exit(1)
def filter_known_msg(msg):
    """Return True if msg is a known, benign error from `ray status`."""
    return "StatusCode.DEADLINE_EXCEEDED" in msg
@ray.remote
class ExitActor:
"""ExitActor"""
def __init__(self):
self._node_and_err_msg = defaultdict(list)
def notify(self):
return 1
def add_error_node_and_msg(self, ip, msg):
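        """Record an error message reported by the node at the given ip."""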
self._node_and_err_msg[ip].append(msg)
def get_error_node_and_msg(self):
return self._node_and_err_msg
def get_error_msg(self, ip):
return self._node_and_err_msg[ip]
def execute_with_timeout(func, args, timeout):
    """Run func(*args) in a worker thread and return its result, or None on timeout or error."""
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(func, *args)
        try:
            return future.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
            future.cancel()
            logger.warning("Function execution timed out.")
        except Exception:
            # e.g. the named actor has not been created yet
            return None
class StartExitListener:
"""StartExitListener"""
def __init__(self):
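        # The Ray session directory is two levels above the object store socket; logs live under <session>/logs.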
log_dir = os.path.dirname(os.path.dirname(ray.nodes()[0]['ObjectStoreSocketName']))
self.log_dir = os.path.join(log_dir, 'logs')
print(self.log_dir, flush=True)
log_actor = None
# Only run the actor on the master node.
if get_rank() == 0:
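            # Pin the detached log actor to the head node so that worker nodes can fetch it by name.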
log_actor = LogActor.options(
name=_LOG_ACTOR_NAME,
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
                    soft=False,
), lifetime="detached"
).remote()
else:
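            # Worker nodes wait until the head node has created the named log actor.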
while log_actor is None:
try:
log_actor = ray.get_actor(_LOG_ACTOR_NAME)
except Exception:
                    print(f'get actor {_LOG_ACTOR_NAME} failed, retrying ...', flush=True)
time.sleep(2)
self.log_monitor = LogMonitor(
self.log_dir,
is_proc_alive,
log_actor
)
self._start_exit_actor = None
self.quit_event = threading.Event()
self.log_monitor_thread = threading.Thread(target=self.log_monitor.run, args=(self.quit_event,))
self.log_monitor_thread.daemon = True
self.log_monitor_thread.start()
def stop(self):
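        """Stop the log monitor thread and shut down the local Ray instance."""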
self.quit_event.set()
self.log_monitor_thread.join(2)
ray.shutdown()
logger.info("Execute ray.shutdown before the program exits. Done ...")
def start_exit_listener(self):
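        """On rank 0, create the detached exit actor; on worker nodes, poll the cluster state
        and exit when the head node exits or an error is reported for this node.
        """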
atexit.register(self.stop)
address = get_addr()
if get_rank() == 0:
self._start_exit_actor = ExitActor.options(name=_EXIT_ACTOR_NAME, lifetime="detached").remote()
else:
# wait for the head node to be created
head_created = False
while True:
cluster_state, msg = get_ray_status()
if cluster_state:
if msg is None:
head_created = True
                    elif not filter_known_msg(msg):
                        logger.warning(f"ray status got unknown msg {msg}, ignoring ...")
else:
if head_created:
logger.info(f"ray status got msg {msg}")
logger.info("head has exited, exit worker ...")
break
logger.info("wait for head to be created.")
if self._start_exit_actor is None:
self._start_exit_actor = execute_with_timeout(ray.get_actor, [_EXIT_ACTOR_NAME], 3)
if self._start_exit_actor is not None:
try:
error_msg_list = ray.get(self._start_exit_actor.get_error_msg.remote(address))
except ray.exceptions.RayActorError:
logger.info("start_exit_actor has been killed")
break
if error_msg_list:
msg = '\n'.join(error_msg_list)
raise Exception(msg)
time.sleep(WORKER_SLEEP_SECOND)
print("Exit worker", flush=True)