chatlearn/runtime/executor.py:
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executor"""
import threading
from collections import defaultdict
from itertools import cycle
from ray.util.queue import Queue
import torch
from chatlearn.models.vllm_module_v2 import VLLMModuleV2
from chatlearn.runtime.model_flow import ModelFlow
from chatlearn.utils import future
from chatlearn.utils.constant import CHATLEARN_REGROUP_TAG, INDEX_TAG
from chatlearn.utils.constant import LOG_START
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.logger import logger
from .utils import encode_data, decode_data
from .utils import FlowParser
def split_list(lst, n):
    """Split ``lst`` into ``n`` equally sized contiguous chunks."""
    assert len(lst) % n == 0, f"cannot evenly split a list of length {len(lst)} into {n} chunks"
    k = len(lst) // n
    return [lst[i*k:(i+1)*k] for i in range(n)]
def split_along_batch(tensors, num_splits):
    """Split a dict of batched tensors/lists into ``num_splits`` micro-batch dicts."""
    res = [{} for _ in range(num_splits)]
    if tensors is None:
        return res
for key in tensors.keys():
to_batch = tensors[key]
if isinstance(to_batch, torch.Tensor):
batched = to_batch.chunk(num_splits)
elif isinstance(to_batch, list):
batched = split_list(to_batch, num_splits)
else:
raise Exception(f"unknown types key: {key} and {type(to_batch)} to split: {key} {tensors.keys()} {to_batch}")
for idx, ele in enumerate(batched):
res[idx][key] = ele
return res
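# A minimal sketch (not executed by the runtime) of how split_along_batch behaves, assuming a
# batch dict that mixes a torch.Tensor with a plain list:
#
#     batch = {"tokens": torch.arange(8).reshape(4, 2), "prompts": ["a", "b", "c", "d"]}
#     micro_batches = split_along_batch(batch, num_splits=2)
#     # micro_batches[0] == {"tokens": <tensor of shape (2, 2)>, "prompts": ["a", "b"]}
#     # micro_batches[1] == {"tokens": <tensor of shape (2, 2)>, "prompts": ["c", "d"]}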
# pylint: disable=not-callable
class Executor:
"""Executor"""
    def __init__(self, model_flow):
        """
        Executor
        Args
        ----
        model_flow : callable
            a function that defines the model computation flow
        """
self._set_flow(model_flow)
self.args = get_args().runtime_args
self.model_flow = None
self.local_models = self.models
self._batch_per_episode = -1
self.is_eval = False
self._timers = None
self.model2iter = {}
self.merged_buffer = defaultdict(dict)
self._metric_list = []
def set_timers(self, _timers):
self._timers = _timers
@property
def timers(self):
return self._timers
def _set_flow(self, flow):
"""
        Set computation flow
Args
----
flow : callable
a function that defines model computation flow
Returns
-------
Executor
return self
"""
self._flow = flow
self.model_to_call_funcs = FlowParser().parse(flow)
for model, func_names in self.model_to_call_funcs.items():
model.call_funcs += func_names
self.models = list(self.model_to_call_funcs.keys())
return self
@property
def first_node(self):
return self.model_flow.model_nodes[0]
@property
def first_model(self):
return self.first_node.model
def update_models(self, models):
# update local model with remote models
new_models = []
name_to_new_models = {model.name: model for model in models}
for model in self.local_models:
dist_model = name_to_new_models[model.name]
dist_model.group_dist_actors_by_tp_rank()
new_models.append(dist_model)
self.models = new_models
if self.args is None:
self.args = get_args().runtime_args
def setup(self):
self._models_and_results_to_wait = []
self.model_flow = ModelFlow(self)
self.model_flow.trace(self.models, self._flow)
self.models = [model_node.model for model_node in self.model_flow.model_nodes]
self.model_locks = {model_node: threading.Lock() for model_node in self.model_flow.model_nodes}
def _next_model(self, model):
if len(model.replicas) == 1:
return model.replicas[0]
if model not in self.model2iter:
self.model2iter[model] = cycle(iter(model.replicas))
return next(self.model2iter[model])
    def get_merged_data(self, queues, encode=True, micro_batch_index=None, model_node=None, trainable=False):
        """Pop one item with the same micro-batch index from every queue and merge them into a list.
        Items that arrive with a different micro-batch index are parked in ``self.merged_buffer``
        until that index is requested.
        """
        mb0 = None
if micro_batch_index is not None:
mb0 = micro_batch_index
data_list = [None] * len(queues)
merged_buffer = self.merged_buffer[model_node]
for index, queue in enumerate(queues):
if index not in merged_buffer:
merged_buffer[index] = {}
if mb0 in merged_buffer[index]:
data_list[index] = merged_buffer[index].pop(mb0)
continue
            while True:
                # If the queue is momentarily empty, another consumer may already have parked
                # the item for our micro-batch index in the shared buffer; poll it while waiting.
                flag = False
                while queue.qsize() == 0:
                    if mb0 in merged_buffer[index]:
                        data_list[index] = merged_buffer[index].pop(mb0)
                        flag = True
                        break
                if flag:
                    break
encoded_data = queue.get()
mb, data = decode_data(encoded_data)
if mb0 is None:
mb0 = mb
if isinstance(data, list) and not trainable:
data = data[-1]
if mb == mb0:
data_list[index] = data
break
merged_buffer[index][mb] = data
if encode:
return encode_data(mb0, data_list)
return data_list
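    # A hedged illustration of the alignment done by get_merged_data, assuming two input queues
    # whose items arrive in different micro-batch orders ("executor" and "node" are placeholders):
    #
    #     q0, q1 = Queue(), Queue()
    #     q0.put(encode_data(0, {"x": 0})); q0.put(encode_data(1, {"x": 1}))
    #     q1.put(encode_data(1, {"y": 1})); q1.put(encode_data(0, {"y": 0}))
    #     executor.get_merged_data([q0, q1], micro_batch_index=0, model_node=node)
    #     # -> encode_data(0, [{"x": 0}, {"y": 0}]); the mismatched item encode_data(1, {"y": 1})
    #     #    from q1 is parked in merged_buffer until micro-batch 1 is requested.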
def get_merged_data_locked(self, queues, encode=True, micro_batch_index=None, model_node=None, trainable=False):
with self.model_locks[model_node]:
return self.get_merged_data(queues, encode, micro_batch_index, model_node, trainable)
    @staticmethod
    def align_out_queues(queues, encode=False):
        """Align queues filled at different granularities to the same (minimum) size by
        regrouping surplus items under CHATLEARN_REGROUP_TAG."""
        # TODO: handle the one2many case
out_queues = []
min_qsize = min([ele.qsize() for ele in queues]) # pylint: disable=consider-using-generator
for queue in queues:
num_producers = queue.qsize()
if num_producers == min_qsize:
out_queues.append(queue)
continue
assert num_producers % min_qsize == 0
out_queue = Queue()
res_list = []
while queue.qsize() > 0:
res = queue.get()
res = decode_data(res)[1] if encode else res
res_list.append(res)
division = num_producers // min_qsize
in_qsize = len(res_list)
out_qsize = in_qsize // division
for q_idx in range(out_qsize):
start = q_idx * division
end = start + division
out_queue.put(encode_data(q_idx, {CHATLEARN_REGROUP_TAG:res_list[start:end]}))
out_queues.append(out_queue)
return out_queues
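    # A small sketch of align_out_queues, assuming one producer queue holds 4 items and another
    # holds 2 (with encode=True), so the first queue is regrouped pairwise:
    #
    #     aligned = Executor.align_out_queues([queue_with_4_items, queue_with_2_items], encode=True)
    #     # aligned[0] now holds 2 items, each of the form
    #     #     encode_data(q_idx, {CHATLEARN_REGROUP_TAG: [item_{2*q_idx}, item_{2*q_idx+1}]})
    #     # aligned[1] is returned unchanged because it already has the minimum size.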
def get_all_merged_data(self, queues, out_queue, encode=True):
logger.info(f"{LOG_START} start to align output queues with sizes {[ele.qsize() for ele in queues]}.")
queues = self.align_out_queues(queues, True)
logger.info(f"{LOG_START} complete to align output queues, sizes of output_queues are {[ele.qsize() for ele in queues]}.")
queue0 = queues[0]
while queue0.qsize() > 0:
res = self.get_merged_data(queues, encode)
out_queue.put(res)
    def rebatch_all_merged_data(self, model_node, in_queues, is_eval=False):  # pylint: disable=unused-argument
        """Rebatch the input queues of ``model_node`` so that the number of items per queue
        matches the number of batches this model consumes per episode."""
        if not model_node.input_nodes:
            return in_queues
out_queues = [None] * len(in_queues)
num_consumers = self.batch_per_episode(model_node.model)
for index, (input_node, in_queue) in enumerate(zip(model_node.input_nodes, in_queues)):
num_producers = self.batch_per_episode(input_node.model)
if num_producers == num_consumers:
out_queues[index] = in_queue
else:
out_queues[index] = Queue()
res_list = []
while in_queue.qsize() > 0:
res = in_queue.get()
res = decode_data(res)[1]
res_list.append(res)
if num_producers > num_consumers:
# Deal with the case where num_producers > num_consumers
assert num_producers % num_consumers == 0, \
f"many2one: num_producers: {num_producers}, num_consumers: {num_consumers}, len inqueue: {len(in_queues)}"
division = num_producers // num_consumers
in_qsize = len(res_list)
out_qsize = in_qsize // division
for q_idx in range(out_qsize):
start = q_idx * division
end = start + division
out_queues[index].put(encode_data(q_idx, {CHATLEARN_REGROUP_TAG:res_list[start:end]}))
else:
# Deal with the case where num_producers < num_consumers
# TODO: add index for one2many case
assert num_consumers % num_producers == 0, \
f"one2many: num_producers: {num_producers}, num_consumers: {num_consumers}, len inqueue: {len(in_queues)}"
division = num_consumers // num_producers
in_qsize = len(res_list)
out_qsize = in_qsize * division
for q_idx in range(out_qsize):
start = q_idx // division
end = start + 1
out_queues[index].put(encode_data(q_idx, {CHATLEARN_REGROUP_TAG: res_list[start:end],
INDEX_TAG: (q_idx % division, division)}))
return out_queues
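    # Worked examples (assumed counts, for illustration only) of the two rebatching branches above:
    #
    #   many2one: num_producers=8, num_consumers=4 -> division=2, and consumer batch q_idx receives
    #             {CHATLEARN_REGROUP_TAG: res_list[2*q_idx : 2*q_idx + 2]}.
    #   one2many: num_producers=2, num_consumers=4 -> division=2, and consumer batch q_idx receives
    #             {CHATLEARN_REGROUP_TAG: [res_list[q_idx // 2]], INDEX_TAG: (q_idx % 2, 2)},
    #             i.e. each producer batch is shared by two consumers, with INDEX_TAG recording
    #             which slice of it the consumer is expected to use downstream.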
def get_next_data(self, in_queue, model_node, micro_batch_index):
if isinstance(in_queue, list):
if len(in_queue) > 0:
                # This branch is only expected for inference models; it would break training models,
                # because a training model receives a list of remote objects, which is indistinguishable
                # from a model that takes multiple inputs. This needs to be handled properly later.
assert not model_node.trainable
data = self.get_merged_data_locked(in_queue, micro_batch_index=micro_batch_index,
model_node=model_node, trainable=model_node.trainable)
mb, query = decode_data(data)
else:
mb, query = micro_batch_index, []
else:
data = self.get_merged_data_locked([in_queue], micro_batch_index=micro_batch_index,
model_node=model_node, trainable=model_node.trainable)
assert len(data['data']) == 1
data['data'] = data['data'][0]
mb, query = decode_data(data)
query = [query]
return mb, query
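    # A minimal sketch of get_next_data for the single-queue case (values are illustrative only):
    #
    #     in_queue.put(encode_data(3, {"prompts": ["hi"]}))
    #     mb, query = executor.get_next_data(in_queue, model_node, micro_batch_index=None)
    #     # mb == 3 and query == [{"prompts": ["hi"]}]; the payload is always wrapped in a list
    #     # so it can be unpacked as positional args via *query in the remote calls below.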
def generate_step_one_model_internal(self, model_node, in_queue, step_num, replica, func_name="forward_step", to_empty_cache=None,
is_eval=False, to_onload=None, to_offload=None, micro_batch_index=None):
"""
Args:
model: DistModel
in_queue: Queue
step_num: int
replica: current model replica of DistModel
func_name: str
to_empty_cache: None or boolean
"""
model = model_node.model
kwargs = {}
replica_num = len(model.replicas)
output = []
if isinstance(replica.model, VLLMModuleV2):
last_step_start = max(self.num_iteration(model) - replica_num, 0)
is_last_batch = step_num >= last_step_start
kwargs["is_last_batch"] = is_last_batch
if is_eval is not None:
kwargs["is_eval"] = is_eval
if to_empty_cache is not None:
kwargs["to_empty_cache"] = to_empty_cache
if to_onload is not None:
kwargs["to_onload"] = to_onload
if to_offload is not None:
kwargs["to_offload"] = to_offload
mb, query = self.get_next_data(in_queue, model_node, micro_batch_index)
assert isinstance(query, list)
ret = replica.call_actor_remote_func(replica.vllm_engine, func_name, *query, **kwargs)
output.append((ret, mb))
else:
last_step_start = max(self.num_iteration(model) - replica_num, 0)
is_last_batch = step_num >= last_step_start
kwargs["is_last_batch"] = is_last_batch
if to_empty_cache is not None:
kwargs["to_empty_cache"] = to_empty_cache
if to_onload is not None:
kwargs["to_onload"] = to_onload
if to_offload is not None:
kwargs["to_offload"] = to_offload
if is_eval is not None:
kwargs["is_eval"] = is_eval
for _, actors in replica.dp_rank_to_actors.items():
mb, query = self.get_next_data(in_queue, model_node, micro_batch_index)
assert isinstance(query, list)
for actor in actors:
ret = replica.call_actor_remote_func(actor, func_name, *query, **kwargs)
output.append((ret, mb))
return output
def generate_step_one_model(self, model_node, replica, in_queue, out_queue, step_num, func_name="forward_step",
to_empty_cache=None, is_eval=False, to_onload=None, to_offload=None, micro_batch_index=None):
"""
Args:
model: DistModel
in_queue: Queue
out_queue: Queue
step_num: int
func_name: str
to_empty_cache: None or boolean
"""
model = model_node.model
# output is a list of tuple, each tuple is (remote_refs, mb)
output = self.generate_step_one_model_internal(model_node, in_queue, step_num, replica, func_name, to_empty_cache,
is_eval, to_onload, to_offload, micro_batch_index)
num_dp_rank = len(replica.dp_rank_to_actors)
if model.module_args.zero_size == 1:
            # If (tp > 1 or pp > 1) and ep == 1 for the current model, `output` is a list whose
            # length equals the number of actors; all entries are identical, so we only put
            # output[-1] into out_queue.
            # If (tp > 1 or pp > 1) and ep > 1, we put the last output of each dp rank into
            # out_queue (see the worked sketch after this method).
if model.module_args.expert_model_parallel_size == 1 and num_dp_rank == 1:
result = [output[-1]]
else:
num_output = len(output)
assert num_output % num_dp_rank == 0, (
f"The number of outputs ({num_output}) must be divisible by "
f"the number of dp_ranks ({num_dp_rank}) in a replica."
)
interval = num_output // num_dp_rank
result = [output[i] for i in range(interval - 1, num_output, interval)]
else:
result = output
if isinstance(out_queue, list):
for oq in out_queue:
for res, mb in result:
oq.put(encode_data(mb, res))
else:
for res, mb in result:
out_queue.put(encode_data(mb, res))
        # To ensure all actors have finished, all remote refs should be returned so the caller can wait on them.
        # Note that ray.wait does not support tuples, so we return only the refs (item[0] of each tuple) as a list.
remote_refs = [item[0] for item in output]
return out_queue, remote_refs
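    # A worked sketch (with assumed parallel sizes) of the result selection above: suppose a replica
    # has num_dp_rank = 2 and 8 actors in total, so `output` holds 8 (remote_ref, mb) tuples, 4 per
    # dp rank. With zero_size == 1, interval = 8 // 2 = 4 and result = [output[3], output[7]], i.e.
    # the last output of each dp rank is forwarded to out_queue, while every remote ref is still
    # returned so the caller can wait for all actors to finish.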
def regroup_inqueue(self, model_node, queues, is_eval=False):
if self.args.policy_to_regroup_queue == "global_barrier":
# barrier to regroup all queues of producer node
if not isinstance(queues, list):
queues = [queues]
logger.info(f"{LOG_START} regroup_inqueue in_queue {model_node}: {[ele.qsize() for ele in queues]}")
out_queues = self.rebatch_all_merged_data(model_node, queues, is_eval=is_eval)
logger.info(f"{LOG_START} regroup_inqueue out_queues {model_node}: {[ele.qsize() for ele in out_queues]}")
return out_queues
else:
raise RuntimeError(f"Unsupported policy_to_regroup_queue {self.args.policy_to_regroup_queue}.")
def compute_loop_one_model(self, model_node, num_batch=None):
logger.info(f"{LOG_START} start compute_loop for {model_node}, is_eval={self.is_eval}")
model = model_node.model
is_eval = self.is_eval
if num_batch is None:
num_batch = self.num_iteration(model)
func_name = model_node.func_name
if model_node.remote_objects_to_wait:
logger.info(f"{LOG_START} start to wait colocate models to finish for {model_node}")
model_node.wait_colocate_models_to_finish(self.timers, func_name)
logger.info(f"{LOG_START} complete to wait colocate models to finish for {model_node}")
replica_num = len(model.replicas)
last_step_start = max(num_batch - replica_num, 0)
in_queue = model_node.get_input_queues()
logger.info(f"{LOG_START} start to regroup in_queue for {model_node}")
in_queue = self.regroup_inqueue(model_node, in_queue, is_eval=is_eval)
logger.info(f"{LOG_START} complete to regroup in_queue for {model_node}")
if isinstance(in_queue, list) and len(in_queue) == 1:
in_queue = in_queue[0]
results = []
logger.info(f"{LOG_START} start to generate_step_one_model for {model_node}")
self.timers(f"{model.name}").start()
for step in range(num_batch):
to_empty_cache = step >= last_step_start and (model.is_colocate or model.module_args.force_free_memory)
to_onload = step < replica_num and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
to_offload = step >= last_step_start and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
replica = self._next_model(model)
_, data = self.generate_step_one_model(model_node, replica, in_queue, model_node.out_queues, step, func_name, to_empty_cache,
is_eval=is_eval, to_onload=to_onload, to_offload=to_offload)
results.append(data)
self.timers(f"{model.name}").stop()
if model_node.next_colocate_node:
            # Before executing the next colocated model, wait for this one to finish so that its cache can be emptied.
logger.info(
f"{LOG_START} Model {model_node.next_colocate_node} will wait model {model} to finish since they are colocated")
self._models_and_results_to_wait = model_node.next_colocate_node.add_dependent_colocate_model_results(
model_node, results, self._models_and_results_to_wait)
elif model.colocate_models or model.trainable:
            # 1. The model may be colocated with training/inference models, so we should wait until the end of compute_loop.
            # 2. The model is trainable and has no next_colocate_node, so we must ensure it finishes before parameter_sync.
            # In both cases, add the pending results to a temporary list to be waited on later.
logger.info(f"{LOG_START} Sync {model} in the end of {self.__class__.__name__}")
self._models_and_results_to_wait.append((model_node, results))
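    # An illustrative walk-through (assumed sizes) of the onload/offload windows above: with
    # num_batch = 8 and replica_num = 2, last_step_start = 6, so for a colocated model with
    # enable_offload set, steps 0-1 onload weights (the first step of each replica), steps 6-7
    # (the last step of each replica) empty the cache and offload, and steps 2-5 run as usual.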
def compute_loop(self, out_queue, num_batch=None):
for model_group in self.model_flow.flow_topology:
for model_node in model_group:
self.compute_loop_one_model(model_node, num_batch)
data = [None] * len(self.model_flow.return_model_nodes)
for model_node in self.model_flow.model_nodes:
self.timers(f"{model_node.model.name}").start()
if model_node in self.model_flow.return_model_nodes:
# let the results order follow model_node order
data[self.model_flow.return_model_nodes.index(model_node)] = model_node.out_queues[-1]
self.timers(f"{model_node.model.name}").stop()
model_names = []
results = []
for model, result in self._models_and_results_to_wait:
model_names.append(model.name)
results.extend(result)
if results:
for model_name in model_names:
self.timers(f"{model_name}").start()
func_name = self.model_flow.model_nodes[0].func_name
future.wait(results, f"{model_names} {func_name}")
for model_name in model_names:
self.timers(f"{model_name}").stop()
self._models_and_results_to_wait = []
if data:
self.get_all_merged_data(data, out_queue, encode=False)
def setup_queues(self):
data_queues = []
out_queue = Queue()
for model_node in self.model_flow.input_consumers:
data_queue = Queue()
data_queues.append(data_queue)
model_node.set_input_queue(data_queue)
for model_node in self.model_flow.model_nodes:
num_out_queue = len(model_node.output_nodes)
if model_node in self.model_flow.return_model_nodes:
num_out_queue += 1
model_node.set_out_queues([Queue() for _ in range(num_out_queue)])
return data_queues, out_queue