chatlearn/schedule/model_manager.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""model manager"""
import concurrent.futures
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import time
import ray
import ray.experimental.state.api
from chatlearn.data.storage import Storage
from chatlearn.launcher import dlc_utils
from chatlearn import FSDPModule
from chatlearn.models.torch_module import TorchModule
from chatlearn.models.vllm_module_v2 import VLLMModuleV2
from chatlearn.runtime.decorator import decorate_class_func
from chatlearn.runtime.decorator import timeit, preprocess_compute, monitor_error
from chatlearn.runtime.dist_actor import DistActor, DistTorchActor, DistVLLMActor, DistModel
from chatlearn.synchronizer.parameter_sync import ParameterSyncGroup, ParameterSyncGroupwithHEP
from chatlearn.synchronizer.parameter_sync_fsdp import FSDP2VllmParameterSyncGroup
from chatlearn.utils.constant import LOG_START
from chatlearn.utils.error_monitor import ErrorMonitor, ErrorSignalActor
from chatlearn.utils.logger import logger
from chatlearn.utils.global_vars import set_decorated, is_decorated
from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, get_megatron_version
from .port_manager import PortManager
from ..utils import future
class ModelManager:
"""ModelManager"""
def __init__(self, models, resouce_manager, global_args):
self.local_models = models
self.resouce_manager = resouce_manager
self.dist_models = []
self.env_args = global_args.env_args
self.runtime_args = global_args.runtime_args
self.converted = False
        # ports for DLC jobs; the first two ports are reserved for `ray start`
self.free_ports = dlc_utils.get_free_ports()[2:]
self._port_manager = PortManager.remote(self.free_ports)
self.error_signal = ErrorSignalActor.remote()
self._storage = Storage.remote()
self.parameter_sync_groups = {}
self._parameter_sync_model_pair = []
self.model_packs = []
self.placement_groups = []
def _get_total_gpu_required(self):
total_gpu = 0
remote_states = set()
for group in self.runtime_args.colocation:
colocate_models = [self._name2distmodel[name] for name in group]
max_gpu = max(m.total_gpu for m in colocate_models)
total_gpu += max_gpu
for name in group:
remote_states.add(name)
for model in self.dist_models:
# place non-colocate models
if model.name not in remote_states:
max_gpu = model.total_gpu
total_gpu += max_gpu
return total_gpu
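    # Illustrative accounting for _get_total_gpu_required (hypothetical sizes, not from any real config):
    # a colocation group ["policy", "reference"] needing 8 and 4 GPUs contributes max(8, 4) = 8,
    # and a non-colocated "reward" model needing 4 GPUs adds 4, so total_gpu_required == 12.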
def remote(self) -> list:
"""
convert model to remote
"""
logger.info(f"{LOG_START} model_manager start to convert model to remote")
t1 = time.time()
if self.converted:
return self.dist_models
self._name2distmodel = {}
remote_states = set()
for model in self.local_models:
# create dist model object for each local model
dist_model = self._to_dist_model(model)
self.dist_models.append(dist_model)
self._name2distmodel[model.name] = dist_model
total_gpu_required = self._get_total_gpu_required()
if total_gpu_required > self.resouce_manager.total_gpu:
raise RuntimeError(f"The number of required gpus for current job is {total_gpu_required}, " + \
f"while the number of applied gpus is {self.resouce_manager.total_gpu}")
if self.resouce_manager.total_gpu > total_gpu_required:
logger.warning(f"The number of applied gpus is {self.resouce_manager.total_gpu}, " + \
f"while the number of required gpus is {total_gpu_required}, " + \
f"there is {self.resouce_manager.total_gpu - total_gpu_required} wasted gpus")
t2 = time.time()
logger.info(f"{LOG_START} model_manager convert model to remote, get_total_gpu_required(s):{(t2-t1)}")
env_list = []
for group in self.runtime_args.colocation:
colocate_models = [self._name2distmodel[name] for name in group]
self.place_models_to_remote_devices(colocate_models, env_list)
if len(colocate_models) > 1:
set_colocate = []
for model in colocate_models:
model.is_colocate = True
set_colocate.extend(model.set_colocate(True))
future.wait(set_colocate)
for name in group:
remote_states.add(name)
t3 = time.time()
logger.info(f"{LOG_START} model_manager convert model to remote, set_colocate(s):{(t3-t2)}")
for model in self.dist_models:
# place non-colocate models
if model.name not in remote_states:
self.place_models_to_remote_devices([model], env_list)
self.set_dist_env_concurrent(env_list)
self.converted = True
t4 = time.time()
logger.info(f"{LOG_START} model_manager convert model to remote, place_models_to_remote_devices(s):{(t4-t3)}")
return self.dist_models
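    # Minimal usage sketch (hypothetical driver code; the surrounding runtime engine normally drives this):
    #   manager = ModelManager(models, resource_manager, global_args)
    #   dist_models = manager.remote()                      # build DistModels and place them on devices
    #   manager.set_parameter_sync(src_model, dst_model)    # register a sync pair
    #   manager.build_parameter_group()                     # then build the sync groups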
def build_parameter_group(self):
# set ParameterSyncGroup
megatron_version = get_megatron_version()
for src_model, dst_model in self._parameter_sync_model_pair:
logger.info(
f"start build parameter sync group bewteen {src_model.name} and {dst_model.name}")
group_name = self._get_group_name(src_model, dst_model)
sync_frequency = self._get_sync_frequency(dst_model)
if isinstance(self._name2distmodel[src_model.name].replicas[0].model, FSDPModule):
sync_group = FSDP2VllmParameterSyncGroup(
self._name2distmodel[src_model.name],
self._name2distmodel[dst_model.name],
group_name,
sync_frequency,
self.error_signal
)
elif megatron_version == MegatronVersion.V4:
logger.info("QWEN_VERSION has been set to qwen_moe_v1, where HEP is enabled.")
sync_group = ParameterSyncGroupwithHEP(
self._name2distmodel[src_model.name],
self._name2distmodel[dst_model.name],
group_name,
sync_frequency,
self.error_signal
)
else:
sync_group = ParameterSyncGroup(
self._name2distmodel[src_model.name],
self._name2distmodel[dst_model.name],
group_name,
sync_frequency,
self.error_signal
)
self.parameter_sync_groups[group_name] = sync_group
def start_error_monitor(self):
group_names = list(self.parameter_sync_groups.keys())
self.error_monitor = ErrorMonitor.remote(self.error_signal, self.dist_models, group_names)
self.error_monitor.monitor.remote()
def _get_group_name(self, src_model, dst_model):
return src_model.name + "2" + dst_model.name
def _get_sync_frequency(self, model):
return model.parameter_sync_frequency
def set_parameter_sync(self, src_model, tgt_model):
group_name = self._get_group_name(src_model, tgt_model)
if group_name in self.parameter_sync_groups:
logger.warning(f"{group_name} already set, ignore")
else:
sync_frequency = self._get_sync_frequency(tgt_model)
assert sync_frequency >= 0, \
f"parameter sync frequency from {src_model.name} to {tgt_model.name} expected tp be greater than 0, while {sync_frequency}."
logger.info(f"sync parameters from {src_model.name} to {tgt_model.name} every {sync_frequency} episodes.")
self._parameter_sync_model_pair.append((src_model, tgt_model))
def warmup_collective_topology(self):
for _, sync_group in self.parameter_sync_groups.items():
sync_group.warmup_groups()
def sync_parameters(self, episode_offset=0, requires_grad=None, validate=False, dryrun=False):
"""
if requires_grad is False, all parameters will be syncronized,
this happends when broadcast parameters in the beginning of training,
set the parameters of inference same as training
"""
for _, sync_group in self.parameter_sync_groups.items():
if sync_group.frequency and \
episode_offset % sync_group.frequency == 0:
sync_group: ParameterSyncGroup = sync_group
# src_model, dst_model type: DistModel
src_model, dst_model = sync_group.src_model, sync_group.dst_model
future.wait(src_model.onload(
to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False))
future.wait(dst_model.onload(
to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False))
sync_group.sync(requires_grad, validate, dryrun=dryrun)
future.wait(src_model.offload())
future.wait(dst_model.offload())
def set_func_decorator(self, model):
if is_decorated(model.name):
return
call_funcs = model.call_funcs
model_cls = model.__class__
for func_name in call_funcs:
trainable = func_name in model.trainable_funcs
            decorate_class_func(model_cls, func_name, preprocess_compute, trainable)
for func_name in ["save_checkpoint", "model_setup"] + call_funcs:
decorate_class_func(model_cls, func_name, timeit, func_name)
# public user function
# TODO: use decorator to annotate
for func_name in ["save_checkpoint", "model_setup", "onload", "offload", "build_dataset",
"_build_dataloader", "generate_vllm", "init"] + call_funcs:
decorate_class_func(model_cls, func_name, monitor_error, func_name)
set_decorated(model.name)
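    # Illustrative effect of set_func_decorator (hypothetical model with call_funcs == ["forward_step"]):
    # forward_step is wrapped with preprocess_compute, timeit, and monitor_error;
    # model_setup / save_checkpoint get only timeit and monitor_error;
    # onload, offload, build_dataset, etc. get only monitor_error.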
def _to_dist_model(self, model):
"""
Convert one model to DistActor and place it to devices
Args:
model: BaseModule
"""
self.set_func_decorator(model)
model.finalize()
def actor_type():
if isinstance(model, VLLMModuleV2):
return DistVLLMActor
if isinstance(model, TorchModule):
return DistTorchActor
return DistActor
dist_model = DistModel()
for replica_id in range(model.num_replica):
dist_actor = actor_type()(model, self.resouce_manager.gpu_per_node, self.error_signal, self._port_manager,
replica_id, self._storage)
dist_model.add_replica(dist_actor)
return dist_model
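    # Example outcome of _to_dist_model (hypothetical module): a VLLMModuleV2 with num_replica == 2
    # becomes a DistModel holding two DistVLLMActor replicas; a plain TorchModule yields DistTorchActor
    # replicas instead, and any other BaseModule falls back to DistActor.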
def _find_param_recv_models(self, models):
"""
find models that recv parameters
"""
if len(models) < 2:
return []
model_names = [model.name for model in models]
models_to_revert = []
for model in models:
for src, tgt in self._parameter_sync_model_pair:
if src.name in model_names and model.name == tgt.name:
models_to_revert.append(model)
return models_to_revert
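    # Worked example for _find_param_recv_models (hypothetical names): with a registered sync pair
    # ("ppo_policy" -> "policy") and colocated models [ppo_policy, policy], only `policy` is returned,
    # since it is the parameter receiver whose placement is later reversed in place_gpu_models.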
def find_model_packing_strategy(self, models, total_gpu):
"""
Find model packing strategies that can pack all models into total_gpu
try to balance the models among devices, i.e., each device holds similar number of model parts
e.g., given models A:8, B:4, C:4, total_gpu: 8
then the pack strategy is [(A), (B,C)]
"""
sorted_models = sorted(models, key=lambda x: (x.trainable, x.total_gpu), reverse=True)
assert sorted_models[0].total_gpu <= total_gpu
final_packs = []
# key is the remaining gpu
unfinished_packs = defaultdict(list)
for model in sorted_models:
gpu = model.total_gpu
if gpu == total_gpu:
final_packs.append([model])
else:
if gpu in unfinished_packs:
# find a pack
packs = unfinished_packs[gpu].pop(0)
if len(unfinished_packs[gpu]) == 0:
unfinished_packs.pop(gpu)
packs.append(model)
final_packs.append(packs)
else:
near_gpus = [d for d in unfinished_packs if d > gpu]
if near_gpus:
near_gpu = sorted(near_gpus)[0]
packs = unfinished_packs[near_gpu].pop(0)
                        # clean up the (possibly emptied) bucket we just popped from
                        if len(unfinished_packs[near_gpu]) == 0:
                            unfinished_packs.pop(near_gpu)
packs.append(model)
# update the remaining gpu number
unfinished_packs[near_gpu - gpu].append(packs)
else:
# add model and wait for packing
unfinished_packs[total_gpu - gpu].append([model])
for gpu, packs_list in unfinished_packs.items():
if packs_list:
final_packs.extend(packs_list)
return final_packs
def place_gpu_models(self, gpu_models, env_list=None):
""" place DistModel to gpu
GPU models: Lis[DistModel]
"""
if not gpu_models:
return
max_gpu = max(m.total_gpu for m in gpu_models)
# create placement groups
placement_group = self.resouce_manager.create_placement_group(max_gpu)
for i, _ in enumerate(placement_group.bundle_specs):
self.placement_groups.append((placement_group, i))
models_str = ','.join([model.name for model in gpu_models])
logger.info(f"create placement_group {placement_group.bundle_specs} for model {models_str} done")
for model in gpu_models:
# TODO: for colocate gpu_per_process > 1, support later
assert model.gpu_per_process == 1
self.model_packs = self.find_model_packing_strategy(gpu_models, max_gpu)
for model in gpu_models:
pack = []
for pack in self.model_packs:
if model in pack:
break
colocate_models = []
for model2 in gpu_models:
if model2 is not model and model2 not in pack:
colocate_models.append(model2)
model.set_colocate_models(colocate_models)
def _get_model_replica_from_pack(gpu_index, model_pack):
            # for gpu ranks in [N * model.num_gpu_per_replica, (N + 1) * model.num_gpu_per_replica),
            # this function returns the same (N-th) replica of the matching model
gpu_offset = 0
for model in model_pack:
if gpu_index < gpu_offset + model.total_gpu:
# compute the model rank
model_rank = gpu_index - gpu_offset
replica_id = model_rank // model.num_gpu_per_replica
return model.replicas[replica_id]
gpu_offset += model.total_gpu
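        # Worked example for _get_model_replica_from_pack (hypothetical pack): with
        # model_pack = [A (total_gpu=8, num_gpu_per_replica=4), B (total_gpu=4, num_gpu_per_replica=4)],
        # gpu_index 5 -> A's replica 1 (rank 5 within A, 5 // 4 == 1); gpu_index 9 -> B's replica 0.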
        # 1. list the model replicas to place on each device
        # 2. for device i holding N replicas, each ray actor gets num_gpus = 1.0 / N
        # a replica here is a DistActor
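        # Fractional-GPU illustration (hypothetical colocation): if GPU i hosts replicas of two models,
        # each actor is created with num_gpus = 0.5; for a vLLM replica whose vllm_engine is not yet set,
        # that share is halved again (0.25 each), covering both the engine actor and the worker actor.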
gpu_to_replicas = []
for i in range(max_gpu):
colocate_models = []
for model_pack in self.model_packs:
replica = _get_model_replica_from_pack(i, model_pack)
if replica is not None:
colocate_models.append(replica)
gpu_to_replicas.append(colocate_models)
# For each gpu rank, create actor for each replica
for i, replicas in enumerate(gpu_to_replicas):
group = i // self.resouce_manager.gpu_per_node
for replica in replicas:
num_gpus = 1.0 / len(replicas)
if isinstance(replica.model, VLLMModuleV2) and replica.vllm_engine is None:
num_gpus = num_gpus / 2
replica.create_engine_actor(num_gpus, placement_group, group)
# we do not want to add engine actor to all_actors
replica.all_actors.pop()
replica.create_actor(num_gpus, placement_group, group)
models_to_revert = self._find_param_recv_models(gpu_models)
for model in gpu_models:
            # Reverse the placement of models that receive parameters, so that paired src/dst
            # ranks do not land on the same GPU.
            # NCCL limit: NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device
            # TODO: One GPU task still not work
            reverse_gpu_placement = model in models_to_revert
if env_list is None:
for replica in model.replicas:
replica.set_dist_env(reverse_gpu_placement)
else:
env_list.append((model, reverse_gpu_placement))
def place_cpu_models(self, cpu_models):
if not cpu_models:
return
num_cpus = []
for model in cpu_models:
for _ in range(model.module_args.num_replica):
num_cpus.append(model.module_args.cpu_per_process)
if not self.placement_groups:
placement_group = self.resouce_manager.create_placement_group(num_gpus=0, num_cpus=num_cpus, \
strategy=self.runtime_args.cpu_schedule_strategy)
models_str = ','.join([model.name for model in cpu_models])
logger.info(f"create placement_group {placement_group.bundle_specs} for model {models_str} done")
placement_groups = []
for i, _ in enumerate(placement_group.bundle_specs):
placement_groups.append((placement_group, i))
else:
placement_groups = self.placement_groups
i = 0
for cpu_model in cpu_models:
for replica in cpu_model.replicas:
pg, index = placement_groups[i]
replica.create_actor(0, pg, index)
i = i + 1
if i >= len(placement_groups):
i = 0
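    # Round-robin illustration for place_cpu_models (hypothetical counts): with 3 placement-group
    # bundles and 5 CPU replicas, the replicas are assigned to bundles 0, 1, 2, 0, 1 in that order.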
def place_models_to_remote_devices(self, models, env_list=None):
cpu_models = [model for model in models if model.total_gpu == 0]
gpu_models = [model for model in models if model.total_gpu > 0]
self.place_gpu_models(gpu_models, env_list)
self.place_cpu_models(cpu_models)
# DistActor.preprocess_actors will add remote call for each function in Actor
for model in models:
for replica in model.replicas:
replica.preprocess_actors()
def _set_dist_env(self, model, reverse):
for replica in model.replicas:
replica.set_dist_env(reverse)
def set_dist_env_concurrent(self, env_list):
num = len(env_list)
if num == 0:
return
with ThreadPoolExecutor(max_workers=num) as executor:
futures = []
for model, reverse in env_list:
# set env
futures.append(executor.submit(self._set_dist_env, model, reverse))
for _future in concurrent.futures.as_completed(futures):
try:
_future.result()
                except Exception as e:
                    raise RuntimeError(f"Set dist env generated an exception: {e}") from e
concurrent.futures.wait(futures)
def clean(self):
for group in self.parameter_sync_groups.values():
group.destroy_collective_group()
for dist_model in self._name2distmodel.values():
for dist_actor in dist_model.replicas:
for actor in dist_actor.all_actors:
try:
ray.kill(actor)
except Exception:
logger.info("Encountering exceptions in cleaning actors, but ok")
continue
ray.kill(self._storage)
ray.kill(self.error_signal)
self.resouce_manager.remove_placement_groups()