chatlearn/launcher/initialize.py (41 lines of code) (raw):
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Initialize"""
import os
import sys
import ray
import torch
from cupy.cuda import nccl
from ray.util.collective.collective_group.nccl_util import TORCH_NCCL_DTYPE_MAP
from chatlearn.launcher import dlc_utils
from chatlearn.utils.arguments import parse_args
from chatlearn.utils.global_vars import set_global_variables
from chatlearn.utils.global_vars import set_initialized
from chatlearn.utils.logger import logger
from chatlearn.utils.version import VERSION
def patch_ray():
    """Teach Ray's collective NCCL utilities about ``torch.bfloat16``.

    Adds (or overwrites) the ``torch.bfloat16 -> nccl.NCCL_BFLOAT16`` entry in
    Ray's ``TORCH_NCCL_DTYPE_MAP``. This lets bf16 torch tensors be mapped to an
    NCCL dtype. Presumably the stock map lacks this entry; confirm against the
    installed Ray version.
    """
    TORCH_NCCL_DTYPE_MAP.update({torch.bfloat16: nccl.NCCL_BFLOAT16})


# Executed once at import time (module-level side effect).
patch_ray()
def init_ray(runtime_env_args):
    """Initialize Ray in the ``CHATLEARN`` namespace.

    Builds a Ray ``runtime_env`` from the current process environment and from
    ``runtime_env_args``, then calls ``ray.init`` with it.

    Args:
        runtime_env_args: Object exposing ``pip``, ``working_dir``,
            ``py_modules`` and ``excludes`` attributes. Each truthy one is
            copied into the runtime env under the same key. All four attributes
            must exist, because they are read with a plain ``getattr``.
    """
    env_vars = {}
    # Forward the driver's PYTHONPATH (only when set and non-empty).
    driver_python_path = os.environ.get("PYTHONPATH", "")
    if driver_python_path:
        env_vars["PYTHONPATH"] = driver_python_path
    runtime_env = {"env_vars": env_vars}

    for field in ('pip', 'working_dir', 'py_modules', 'excludes'):
        value = getattr(runtime_env_args, field)
        # Falsy values (None, "", empty list, ...) are not forwarded.
        if value:
            runtime_env[field] = value

    # namespace is needed to get NamedActor
    ray.init(
        runtime_env=runtime_env,
        namespace="CHATLEARN",
        _node_ip_address=dlc_utils.get_addr(),
        log_to_driver=False,
    )
def init(args=None):
    """
    Initialize the ChatLearn environment.

    Steps performed here:
    1. Parse arguments with ``parse_args()`` when ``args`` is ``None``, then
       register them via ``set_global_variables``.
    2. In a DLC environment, start the Ray cluster.
    3. Initialize Ray with ``args.env_args`` and mark ChatLearn as initialized.
    4. In a DLC environment, start the exit listener. Any process whose rank
       is greater than 0 then logs a message and calls ``sys.exit(0)``.

    Args:
        args: Pre-parsed arguments. When ``None``, ``parse_args()`` is used.
    """
    args = parse_args() if args is None else args
    set_global_variables(args)

    if dlc_utils.in_dlc_env():
        dlc_utils.start_ray_cluster()
    init_ray(args.env_args)
    set_initialized()

    # Re-checked (not cached) on purpose, matching the original call sequence.
    if dlc_utils.in_dlc_env():
        exit_listener = dlc_utils.StartExitListener()
        exit_listener.start_exit_listener()
        worker_rank = dlc_utils.get_rank()
        if worker_rank > 0:
            logger.info(f"RANK: {worker_rank}: task finish, exit ...")
            # other workers exit after head exit
            # (presumably start_exit_listener waits for the head; confirm in dlc_utils)
            sys.exit(0)

    logger.info(f"init chatlearn done, version {VERSION}")