chatlearn/__init__.py (45 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Top-level ChatLearn package.

Re-exports the module base classes, engines, and helper functions so they can
be imported directly from ``chatlearn``.  The vLLM-backed modules are exposed
only when the ``vllm`` package can be located and ``CURRENT_VLLM_VERSION`` is
one of the values of ``VLLMVersion``.
"""

# Import the submodule explicitly: ``import importlib`` alone does not
# guarantee that ``importlib.util`` is loaded, in which case the
# ``importlib.util.find_spec`` call below raises AttributeError.
import importlib.util

from chatlearn import hooks
from chatlearn.launcher.initialize import init
from chatlearn.models.base_module import BaseModule
from chatlearn.models.deepspeed_module import DeepSpeedModule
from chatlearn.models.megatron_module import MegatronModule
from chatlearn.models.torch_module import TorchModule
from chatlearn.models.fsdp_module import FSDPModule
from chatlearn.runtime.engine import DPOEngine
from chatlearn.runtime.engine import Engine
from chatlearn.runtime.engine import Environment
from chatlearn.runtime.engine import EvalEngine
from chatlearn.runtime.engine import OnlineDPOEngine
from chatlearn.runtime.engine import GRPOEngine
from chatlearn.runtime.engine import GRPOMathEngine
from chatlearn.runtime.engine import RLHFEngine
from chatlearn.runtime.engine import Trainer
from chatlearn.runtime.evaluator import Evaluator
from chatlearn.runtime.model_flow import ControlDependencies
from chatlearn.utils.future import get
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.logger import logger

# ``find_spec`` returns a module spec when vllm is found and None otherwise,
# so ``vllm_exist`` is used purely for its truthiness.
vllm_exist = importlib.util.find_spec("vllm")
if vllm_exist:
    import vllm
    from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion # pylint: disable=ungrouped-imports

    # Only expose the vLLM modules for versions enumerated in VLLMVersion.
    if CURRENT_VLLM_VERSION in [version.value for version in VLLMVersion]:
        from chatlearn.models.vllm_module import VLLMModule
        from chatlearn.models.vllm_module_v2 import VLLMModuleV2

        # for compatibility, remove later
        class RLHFVLLMModule(VLLMModule):
            """RLHFVLLMModule is deprecated, please use VLLMModule"""

            def __init__(self, *args, **kwargs):
                """Forward all arguments to VLLMModule, then log a deprecation warning."""
                super().__init__(*args, **kwargs)
                logger.warning("RLHFVLLMModule is deprecated, please use VLLMModule")


# for compatibility, remove later
class RLHFModule(BaseModule):
    """RLHFModule is deprecated, please use BaseModule"""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to BaseModule, then log a deprecation warning."""
        super().__init__(*args, **kwargs)
        logger.warning("RLHFModule is deprecated, please use BaseModule")


# for compatibility, remove later
class RLHFTorchModule(TorchModule):
    """RLHFTorchModule is deprecated, please use TorchModule"""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to TorchModule, then log a deprecation warning."""
        super().__init__(*args, **kwargs)
        logger.warning("RLHFTorchModule is deprecated, please use TorchModule")


# for compatibility, remove later
class RLHFMegatronModule(MegatronModule):
    """RLHFMegatronModule is deprecated, please use MegatronModule"""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to MegatronModule, then log a deprecation warning."""
        super().__init__(*args, **kwargs)
        logger.warning("RLHFMegatronModule is deprecated, please use MegatronModule")