chatlearn/runtime/dist_actor.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dist Actor"""
from collections import defaultdict
import importlib.util
import inspect
from functools import partial

import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from chatlearn.models.base_module import BaseModule
from chatlearn.utils import future
from chatlearn.utils.utils import parse_function_args

vllm_exist = importlib.util.find_spec("vllm")
if vllm_exist:
    from chatlearn.models.vllm_module import VLLMModule
    from chatlearn.models.vllm_module_v2 import VLLMModuleV2
    from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion

RAY_REMOTE = "remote"


class DistActor:
    """Manage the collection of actors that make up one replica of a model."""
def __init__(self, model: BaseModule,
gpu_per_node,
error_signal,
port_manager,
replica_id=0,
storage=None):
self.total_gpu = model.total_gpu
self.total_cpu = model.total_cpu
self.gpu_per_process = model.gpu_per_process
self.num_gpu_per_replica = model.num_gpu_per_replica
self.trainable = model.trainable
self.gpu_per_node = gpu_per_node
self.model = model
self.all_actors = []
self.replica_id = replica_id
self._port_manager = port_manager
self.name = self.model.name
self.error_signal = error_signal
self.storage = storage
# ranks for model update
self.all_ranks = None
self._init_done = False
self._placement_group = None
self.rank_to_actors = {}
@property
def module_args(self):
return self.model.module_args
@property
def runtime_args(self):
return self.model.runtime_args
@property
def master(self):
return self.all_actors[0]
@property
def tailer(self):
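        # the last actor in this replica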
return self.all_actors[-1]
@property
def actor_num(self):
return len(self.all_actors)
def _get_func_args(self, func_name):
func = getattr(self.model, func_name)
return parse_function_args(func)
def preprocess_actors(self):
self.add_remote_func()
def add_remote_func(self):
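        """Expose every public method of the master actor on this DistActor, dispatching calls to all actors."""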
for func_name, _ in inspect.getmembers(self.master):
# ray.actor.ActorMethod
if func_name.startswith('_'):
continue
dist_call = partial(self.call_remote_funcs, func_name)
setattr(self, func_name, dist_call)
def call_actor_remote_func(self, actor, func_name, *args, **kwargs):
func = getattr(actor, func_name)
remote_func = getattr(func, RAY_REMOTE)
res = remote_func(*args, **kwargs)
return res
def call_remote_funcs(self, func_name, *args, **kwargs):
"""
Call remote functions for a collection of actors.
"""
results = []
for actor in self.all_actors:
res = self.call_actor_remote_func(actor, func_name, *args, **kwargs)
results.append(res)
return results
def _create_actor(self, cls, num_gpus, placement_group, group_index, **kwargs):
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=group_index,
)
        # Ray actors default to max_concurrency=1, so only one task executes at a time
actor = ray.remote(num_gpus=num_gpus, num_cpus=0)(cls) \
.options(scheduling_strategy=scheduling_strategy) \
.remote(self.model.name, self.model.global_args, self.replica_id, **kwargs)
actor.set_error_signal.remote(self.error_signal)
actor.set_storage.remote(self.storage)
self.all_actors.append(actor)
return actor
def create_actor(self, num_gpus, placement_group, group_index):
return self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)
def _setup_collective_group(self, rank_offset, world_size, group_name, backend="nccl"):
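        """Register each actor in a collective group, assigning global ranks starting at rank_offset; returns the pending refs."""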
refs = []
all_ranks = []
for i, actor in enumerate(self.all_actors):
rank = i + rank_offset
ref = actor.setup_collective_group.remote(
rank=rank,
world_size=world_size,
backend=backend,
group_name=group_name)
refs.append(ref)
all_ranks.append(rank)
self.rank_to_actors[rank] = actor
self.all_ranks = all_ranks
return refs
def _setup_ranks(self, rank_offset):
all_ranks = []
for i, actor in enumerate(self.all_actors):
rank = i + rank_offset
all_ranks.append(rank)
self.rank_to_actors[rank] = actor
self.all_ranks = all_ranks
def terminate(self):
# terminate when catching exceptions
for actor in self.all_actors:
ray.kill(actor)
@property
def placement_group(self):
return self._placement_group
@placement_group.setter
def placement_group(self, pg):
self._placement_group = pg
def group_dist_actors_by_tp_rank(self):
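        # despite the name, actors are grouped by their data-parallel rank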
self.dp_rank_to_actors = defaultdict(list)
self.data_parallel_size = future.get(self.all_actors[0].get_data_parallel_size.remote())
if self.data_parallel_size is None:
self.data_parallel_size = 1
dp_ranks = future.wait([actor.get_data_parallel_rank.remote() for actor in self.all_actors], return_output=True)
for actor, dp_rank in zip(self.all_actors, dp_ranks):
self.dp_rank_to_actors[dp_rank].append(actor)
def set_dist_env(self, revert_placement=False):
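        # no-op here; torch-based subclasses populate the distributed env vars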
pass
def __str__(self):
return f"{self.__class__.__name__}({self.name})[{self.replica_id}]"
def __repr__(self):
        return f'<{self.__class__.__name__}({self.name})[{self.replica_id}] object at {hex(id(self))}>'


class DistTorchActor(DistActor):
    """A DistActor whose actors form a torch.distributed process group."""
def reorder_actors(self, actors, revert_placement=False):
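        """Sort actors node by node so that ranks follow the visible-GPU order on each node."""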
gpu_per_node = min(self.gpu_per_node, self.model.num_gpu_per_replica)
ordered_actors = []
count = 0
actor_gpus = []
for actor in actors:
gpus = future.get(actor.get_visible_gpus.remote())
count += len(gpus)
actor_gpus.append((actor, gpus))
if count == gpu_per_node:
actor_gpus.sort(key=lambda x: x[1][0])
if revert_placement:
actor_gpus.reverse()
ordered_actors += [a[0] for a in actor_gpus]
actor_gpus = []
count = 0
return ordered_actors
def set_dist_env(self, revert_placement=False):
self.all_actors = self.reorder_actors(self.all_actors, revert_placement)
master_addr = future.get(self.master.get_address.remote())
master_port = future.get(self._port_manager.get_free_port.remote(master_addr))
world_size = self.actor_num
env_config = {"MASTER_ADDR": master_addr, "MASTER_PORT": master_port, "WORLD_SIZE": world_size}
ret = []
for rank, actor in enumerate(self.all_actors):
env_config["RANK"] = rank
if self.model.gpu_per_process == 1:
local_rank = 0
else:
local_rank = rank % self.model.gpu_per_process
env_config["LOCAL_RANK"] = local_rank
ret.append(actor.set_env.remote(env_config))
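        # set_env presumably returns 1 per actor, so a fully successful setup sums to world_size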
status = sum(future.get(ret))
        assert status == world_size


class DistVLLMActor(DistTorchActor):
    """A DistTorchActor made up of a vLLM engine actor and its worker actors."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.vllm_engine = None
def create_actor(self, num_gpus, placement_group, group_index):
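        # worker bootstrap kwargs differ across the supported vLLM versions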
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
kwargs = {
"worker_module_name": "vllm.worker.worker",
"worker_class_name": "Worker",
"worker_class_fn": None,
"trust_remote_code": True,
}
        else:
            kwargs = {
                "vllm_actor_type": "worker"
            }
self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)
def create_engine_actor(self, num_gpus, placement_group, group_index):
self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)
self.model.engine = self.vllm_engine
def call_vllm_engine_remote_funcs(self, func_name, *args, **kwargs):
"""
Call remote functions for vllm_engine.
"""
results = []
res = self.call_actor_remote_func(self.vllm_engine, func_name, *args, **kwargs)
results.append(res)
return results
def call_vllm_engine_and_workers_remote_funcs(self, func_name, *args, **kwargs):
"""
Call remote functions for vllm_engine + workers.
"""
results = []
for actor in self.all_actors:
res = self.call_actor_remote_func(actor, func_name, *args, **kwargs)
results.append(res)
res = self.call_actor_remote_func(self.vllm_engine, func_name, *args, **kwargs)
results.append(res)
return results
def add_remote_func(self):
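        """Expose master-actor methods, routing each call to the engine, the workers, or both."""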
for func_name, _ in inspect.getmembers(self.master):
# ray.actor.ActorMethod
if func_name.startswith('_') or func_name in ["peak_memory"]:
continue
if func_name in ["timer_summary"]:
dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
elif func_name in ["onload", "offload"]:
if func_name == "onload":
new_func_name = "onload_for_workers"
else:
new_func_name = "offload_for_workers"
dist_call = partial(self.call_vllm_engine_remote_funcs, new_func_name)
elif func_name in ["model_setup"]:
dist_call = partial(self.call_vllm_engine_and_workers_remote_funcs, func_name)
elif func_name in ["get_and_clear_metrics"]:
dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
else: # needed to check for other call_funs.
dist_call = partial(self.call_remote_funcs, func_name)
setattr(self, func_name, dist_call)
@property
def master(self):
return self.vllm_engine
def peak_memory(self):
        return self.model.peak_memory()


class DistModel:
    """A logical model composed of one or more DistActor replicas."""
def __init__(self):
self.replicas = []
self.name = None
self.rank_to_actors = {}
self.register_func()
self._is_colocate = False
self._colocate_models = []
def add_replica(self, replica):
self.replicas.append(replica)
self.name = replica.name
@property
def trainable(self):
return self.replicas[0].trainable
@property
def module_args(self):
return self.replicas[0].module_args
@property
def actor_num(self):
return sum(len(dist_actor.all_actors) for dist_actor in self.replicas)
@property
def num_replica(self):
return len(self.replicas)
@property
def total_gpu(self):
return self.replicas[0].total_gpu
@property
def total_cpu(self):
return self.replicas[0].total_cpu
@property
def num_gpu_per_replica(self):
return self.replicas[0].num_gpu_per_replica
@property
def gpu_per_process(self):
return self.replicas[0].gpu_per_process
@property
def is_colocate(self):
return self._is_colocate
@is_colocate.setter
def is_colocate(self, flag):
self._is_colocate = flag
    def get_actor(self, rank):
        """Return the actor that owns the given global rank, if any."""
        for dist_actor in self.replicas:
            if rank in dist_actor.rank_to_actors:
                return dist_actor.rank_to_actors[rank]
def init(self):
refs = []
for dist_actor in self.replicas:
refs.append(dist_actor.init())
future.get(refs)
def register_func(self):
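        """Expose a fixed set of replica-level functions on the DistModel, each fanning out to every replica."""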
for func_name in ["model_setup",
"before_episode",
"after_episode",
"get_and_clear_metrics",
"validate",
"destroy_collective_group",
"terminate",
"peak_memory",
"empty_cache",
"set_start_iteration",
"offload",
"onload",
"eval",
"train",
"set_src_parameter_model",
"set_colocate"]:
dist_call = partial(self.call_replica_func, func_name)
setattr(self, func_name, dist_call)
def call_replica_func(self, func, *args, **kwargs):
refs = []
for dist_actor in self.replicas:
ref = getattr(dist_actor, func)(*args, **kwargs)
if ref is not None:
refs.append(ref)
return refs
def call_replica_serial_func(self, func, *args, **kwargs):
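        # unlike call_replica_func, wait for each replica's result before calling the next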
results = []
for dist_actor in self.replicas:
ref = getattr(dist_actor, func)(*args, **kwargs)
if ref is not None:
res = future.get(ref)
results.append(res)
return results
def set_colocate_models(self, models):
self._colocate_models = models
def colocate_with(self, model):
return model in self._colocate_models
@property
def colocate_models(self):
return self._colocate_models
@property
def all_ranks(self):
return [dist_actor.all_ranks for dist_actor in self.replicas]
@property
def use_vllm_backend(self):
return vllm_exist and isinstance(self.replicas[0].model, (VLLMModule, VLLMModuleV2))
def group_dist_actors_by_tp_rank(self):
for replica in self.replicas:
replica.group_dist_actors_by_tp_rank()
@property
def enable_offload(self):
return self.module_args.free_grad_buffers or self.module_args.offload_weights or \
self.module_args.offload_optimizer_states
def __str__(self):
return f"{self.__class__.__name__}({self.name})"
def __repr__(self):
return f'<{self.__class__.__name__}({self.name}) object at {hex(id(self))}>'