chatlearn/utils/global_vars.py (55 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Process-wide global variables and their accessors."""

# Arguments object registered by set_global_variables().
_GLOBAL_ARGS = None
# NOTE(review): not read or written by any function in this module.
_EXIT_ACTOR = None
_EXIT_ACTOR_NAME = "ChatLearnExitActor"
# Model names registered via set_decorated(); reset to an empty set by
# set_global_variables().
_DECORATED_MODELS = None
# Maps a wrapper function to the function it wraps (see set_wrap_func()).
_DECORATED_OUTER_TO_INNER = {}
# Set once via set_dependencies(); cleared by reset_dependencies().
_DEPENDENCIES = None
_VLLM_ACTORS = None


def _ensure_var_is_initialized(var, name):
    """Assert that ``var`` is not None.

    Args:
        var: value to check.
        name: human-readable name used in the error message.

    Raises:
        AssertionError: if ``var`` is None.
    """
    assert var is not None, '{} is not initialized.'.format(name)


def get_args():
    """Return the arguments registered by set_global_variables().

    Raises:
        AssertionError: if set_global_variables() has not been called.
    """
    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
    return _GLOBAL_ARGS


def is_initialized():
    """Return whether the global args have been marked as initialized.

    Returns False when no args are registered yet; otherwise returns the
    ``initialized`` attribute of the registered args object.
    """
    if _GLOBAL_ARGS is None:
        return False
    return _GLOBAL_ARGS.initialized


def set_initialized():
    """Mark the registered global args as initialized.

    Raises:
        AssertionError: if set_global_variables() has not been called.
    """
    # Guard so a premature call reports "args is not initialized." instead of
    # an AttributeError on None.
    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
    _GLOBAL_ARGS.initialized = True


def set_global_variables(args):
    """Register ``args`` as the global arguments and reset decorated models.

    Args:
        args: arguments object; must not be None.

    Raises:
        AssertionError: if ``args`` is None.
    """
    global _GLOBAL_ARGS, _DECORATED_MODELS
    assert args is not None
    _GLOBAL_ARGS = args
    # Every (re)initialization starts with an empty decorated-model registry.
    _DECORATED_MODELS = set()


def set_decorated(model_name):
    """Record ``model_name`` as decorated.

    Raises:
        AssertionError: if set_global_variables() has not been called.
    """
    _ensure_var_is_initialized(_DECORATED_MODELS, 'decorated_models')
    _DECORATED_MODELS.add(model_name)


def is_decorated(model_name):
    """Return True if ``model_name`` was recorded by set_decorated().

    Raises:
        AssertionError: if set_global_variables() has not been called.
    """
    _ensure_var_is_initialized(_DECORATED_MODELS, 'decorated_models')
    return model_name in _DECORATED_MODELS


def unwrap_func(func, level=None):
    """Follow wrapper -> wrapped links registered by set_wrap_func().

    Args:
        func: function to unwrap.
        level: number of wrapper layers to strip. If None, unwrap all the
            way to the innermost registered function. A level of 0 (or any
            non-positive value) returns ``func`` unchanged.

    Returns:
        The unwrapped function. Unwrapping stops early if a function with no
        registered inner function is reached.
    """
    if func not in _DECORATED_OUTER_TO_INNER:
        return func
    if level is not None:
        if level <= 0:
            return func
        level -= 1
    return unwrap_func(_DECORATED_OUTER_TO_INNER[func], level)


def set_wrap_func(func, new_func):
    """Register ``new_func`` as a wrapper around ``func``.

    Raises:
        AssertionError: if ``new_func`` is already registered as a wrapper.
    """
    assert new_func not in _DECORATED_OUTER_TO_INNER, \
        'wrapper function is already registered.'
    _DECORATED_OUTER_TO_INNER[new_func] = func


def set_dependencies(dependencies):
    """Store ``dependencies`` globally.

    Raises:
        AssertionError: if dependencies are already set; call
            reset_dependencies() first.
    """
    global _DEPENDENCIES
    assert _DEPENDENCIES is None, \
        'dependencies are already set; call reset_dependencies() first.'
    _DEPENDENCIES = dependencies


def reset_dependencies():
    """Clear the globally stored dependencies."""
    global _DEPENDENCIES
    _DEPENDENCIES = None


def get_dependencies():
    """Return the stored dependencies, or None if unset."""
    return _DEPENDENCIES


def set_vllm_actors(actors):
    """Store ``actors`` as the global vLLM actors (overwrites any previous value)."""
    global _VLLM_ACTORS
    _VLLM_ACTORS = actors


def get_vllm_actors():
    """Return the stored vLLM actors, or None if unset."""
    return _VLLM_ACTORS