chatlearn/runtime/executor.py (366 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executor: runs the models of a computation flow and moves data between them through queues."""

import threading
from collections import defaultdict
from itertools import cycle

from ray.util.queue import Queue
import torch

from chatlearn.models.vllm_module_v2 import VLLMModuleV2
from chatlearn.runtime.model_flow import ModelFlow
from chatlearn.utils import future
from chatlearn.utils.constant import CHATLEARN_REGROUP_TAG, INDEX_TAG
from chatlearn.utils.constant import LOG_START
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.logger import logger
from .utils import encode_data, decode_data
from .utils import FlowParser


def split_list(lst, n):
    """Split `lst` into `n` consecutive slices of equal length.

    Args:
        lst: sequence to split; its length must be a multiple of `n`.
        n: number of slices to produce.

    Returns:
        A list of `n` slices, each of length `len(lst) // n`, in original order.

    Raises:
        AssertionError: if `len(lst)` is not divisible by `n`.
    """
    assert len(lst) % n == 0, f"{len(lst)} % {n} != 0"
    k = len(lst) // n
    return [lst[i*k:(i+1)*k] for i in range(n)]


def split_along_batch(tensors, num_splits):
    """Split every value of a dict into `num_splits` parts and regroup them per part.

    Args:
        tensors: mapping from key to either a `torch.Tensor` or a `list`, or None.
        num_splits: number of output dicts.

    Returns:
        A list of `num_splits` dicts. Dict `i` maps each key to the i-th part of
        the corresponding value. When `tensors` is None every dict is empty.

    Raises:
        Exception: if a value is neither a `torch.Tensor` nor a `list`.
        AssertionError: (from `split_list`) if a list value's length is not
            divisible by `num_splits`.
    """
    res = [{} for _ in range(num_splits)]
    if tensors is None:
        return res
    for key in tensors.keys():
        to_batch = tensors[key]
        if isinstance(to_batch, torch.Tensor):
            # NOTE(review): torch's `chunk` is documented as possibly returning fewer
            # than `num_splits` chunks; in that case the trailing dicts simply do not
            # receive this key.
            batched = to_batch.chunk(num_splits)
        elif isinstance(to_batch, list):
            batched = split_list(to_batch, num_splits)
        else:
            raise Exception(f"unknown types key: {key} and {type(to_batch)} to split: {key} {tensors.keys()} {to_batch}")
        for idx, ele in enumerate(batched):
            res[idx][key] = ele
    return res


# pylint: disable=not-callable
class Executor:
    """Drives the models of a computation flow, wiring their inputs and outputs with queues.

    The class calls `self.num_iteration(model)` and `self.batch_per_episode(model)`,
    neither of which is defined in the part of the file visible here.
    NOTE(review): presumably they are supplied by subclasses - confirm.
    """

    def __init__(self, model_flow):
        """Parse the flow function and initialise the executor's bookkeeping.

        Args:
            model_flow: the callable describing the model computation flow; it is
                handed to `_set_flow`, which parses it with `FlowParser`.
        """
        # Sets self._flow, self.model_to_call_funcs and self.models.
        self._set_flow(model_flow)
        self.args = get_args().runtime_args
        # The traced ModelFlow is only built later, in setup().
        self.model_flow = None
        # Models as parsed from the flow; update_models() looks replacements up by name.
        self.local_models = self.models
        self._batch_per_episode = -1
        self.is_eval = False
        self._timers = None
        # model -> endless iterator over its replicas (see _next_model).
        self.model2iter = {}
        # model_node -> {queue index -> {micro-batch index -> data}}; holds items that
        # were read from a queue before they were needed (see get_merged_data).
        self.merged_buffer = defaultdict(dict)
        self._metric_list = []

    def set_timers(self, _timers):
        """Store the timers object later used through the `timers` property."""
        self._timers = _timers

    @property
    def timers(self):
        """The timers object installed with `set_timers` (None until then)."""
        return self._timers

    def _set_flow(self, flow):
        """
        Set computation flow

        Parses `flow` with `FlowParser`, appends the parsed function names to each
        model's `call_funcs`, and records the models in `self.models`.

        Args
        ----
        flow : callable
             a function that defines model computation flow

        Returns
        -------
        Executor
            return self
        """
        self._flow = flow
        self.model_to_call_funcs = FlowParser().parse(flow)
        for model, func_names in self.model_to_call_funcs.items():
            model.call_funcs += func_names
        self.models = list(self.model_to_call_funcs.keys())
        return self

    @property
    def first_node(self):
        """First node of the traced model flow (requires `setup()` to have run)."""
        return self.model_flow.model_nodes[0]

    @property
    def first_model(self):
        """Model attached to `first_node`."""
        return self.first_node.model

    def update_models(self, models):
        """Replace `self.models` with the entries of `models` that match the local models by name.

        Args:
            models: iterable of model objects exposing `.name` and
                `group_dist_actors_by_tp_rank()`.

        Raises:
            KeyError: if a local model's name is missing from `models`.
        """
        # update local model with remote models
        new_models = []
        name_to_new_models = {model.name: model for model in models}
        # Keep the order of self.local_models.
        for model in self.local_models:
            dist_model = name_to_new_models[model.name]
            dist_model.group_dist_actors_by_tp_rank()
            new_models.append(dist_model)
        self.models = new_models
        if self.args is None:
            self.args = get_args().runtime_args

    def setup(self):
        """Trace the flow into a `ModelFlow` and create one lock per model node."""
        # (model_node, remote results) pairs that compute_loop() must wait for.
        self._models_and_results_to_wait = []
        self.model_flow = ModelFlow(self)
        self.model_flow.trace(self.models, self._flow)
        # From here on self.models follows the order of the traced nodes.
        self.models = [model_node.model for model_node in self.model_flow.model_nodes]
        # Used by get_merged_data_locked() to serialise reads per node.
        self.model_locks = {model_node: threading.Lock() for model_node in self.model_flow.model_nodes}

    def _next_model(self, model):
        """Return the next replica of `model`, cycling through `model.replicas` in order."""
        if len(model.replicas) == 1:
            return model.replicas[0]
        if model not in self.model2iter:
            # Created lazily and kept, so successive calls continue the rotation.
            self.model2iter[model] = cycle(iter(model.replicas))
        return next(self.model2iter[model])

    def get_merged_data(self, queues, encode=True, micro_batch_index=None, model_node=None, trainable=False):
        """Take one item from each queue, all belonging to the same micro-batch index.

        Items read from a queue that belong to a different micro-batch are parked in
        `self.merged_buffer[model_node][queue_index]` and served from there later.

        Args:
            queues: list of input queues; items are decoded with `decode_data`
                into `(micro_batch_index, data)`.
            encode: if True, return `encode_data(mb0, data_list)`; otherwise
                return the plain list.
            micro_batch_index: index to collect. If None, the index of the first
                item read is used for all queues.
            model_node: key selecting the reorder buffer.
            trainable: when False, a list payload is reduced to its last element.

        Returns:
            The per-queue data list (ordered like `queues`), optionally encoded
            together with the micro-batch index.
        """
        mb0 = None
        if micro_batch_index is not None:
            mb0 = micro_batch_index
        data_list = [None] * len(queues)
        merged_buffer = self.merged_buffer[model_node]
        for index, queue in enumerate(queues):
            if index not in merged_buffer:
                merged_buffer[index] = {}
            # Already parked by an earlier read: no need to touch the queue.
            if mb0 in merged_buffer[index]:
                data_list[index] = merged_buffer[index].pop(mb0)
                continue
            while True:
                flag = False
                # Spin while the queue is empty, re-checking the buffer on every
                # iteration in case the wanted item shows up there.
                while queue.qsize() == 0:
                    if mb0 in merged_buffer[index]:
                        data_list[index] = merged_buffer[index].pop(mb0)
                        flag = True
                        break
                if flag:
                    break
                encoded_data = queue.get()
                mb, data = decode_data(encoded_data)
                # No index requested: adopt the first one seen.
                if mb0 is None:
                    mb0 = mb
                if isinstance(data, list) and not trainable:
                    data = data[-1]
                if mb == mb0:
                    data_list[index] = data
                    break
                # Belongs to another micro-batch: park it and keep reading.
                merged_buffer[index][mb] = data
        if encode:
            return encode_data(mb0, data_list)
        return data_list

    def get_merged_data_locked(self, queues, encode=True, micro_batch_index=None, model_node=None, trainable=False):
        """Same as `get_merged_data`, executed while holding `self.model_locks[model_node]`."""
        with self.model_locks[model_node]:
            return self.get_merged_data(queues, encode, micro_batch_index, model_node, trainable)

    @staticmethod
    def align_out_queues(queues, encode=False):
        """Make all queues as long as the shortest one by grouping items of the longer ones.

        A queue whose size equals the minimum is passed through unchanged. Any other
        queue is drained and replaced by a new `Queue` in which every item bundles
        `division = size // min_size` consecutive original items as
        `{CHATLEARN_REGROUP_TAG: [...]}`, encoded with its new position as index.

        Args:
            queues: non-empty list of queues.
            encode: despite the name, controls how drained items are read: if True
                each item is passed through `decode_data` and its second element is
                kept; if False the raw item is kept.

        Returns:
            List of queues, in the same order as `queues`.

        Raises:
            AssertionError: if a queue's size is not a multiple of the minimum size.
        """
        # TODO: deal with one2many scene
        out_queues = []
        min_qsize = min([ele.qsize() for ele in queues]) # pylint: disable=consider-using-generator
        for queue in queues:
            num_producers = queue.qsize()
            if num_producers == min_qsize:
                out_queues.append(queue)
                continue
            assert num_producers % min_qsize == 0
            out_queue = Queue()
            res_list = []
            # Drain the queue completely before regrouping.
            while queue.qsize() > 0:
                res = queue.get()
                res = decode_data(res)[1] if encode else res
                res_list.append(res)
            division = num_producers // min_qsize
            in_qsize = len(res_list)
            out_qsize = in_qsize // division
            for q_idx in range(out_qsize):
                start = q_idx * division
                end = start + division
                out_queue.put(encode_data(q_idx, {CHATLEARN_REGROUP_TAG:res_list[start:end]}))
            out_queues.append(out_queue)
        return out_queues

    def get_all_merged_data(self, queues, out_queue, encode=True):
        """Align `queues`, then merge them item by item into `out_queue`.

        Args:
            queues: list of source queues.
            out_queue: queue receiving one merged entry per micro-batch.
            encode: forwarded to `get_merged_data`.
        """
        logger.info(f"{LOG_START} start to align output queues with sizes {[ele.qsize() for ele in queues]}.")
        queues = self.align_out_queues(queues, True)
        logger.info(f"{LOG_START} complete to align output queues, sizes of output_queues are {[ele.qsize() for ele in queues]}.")
        # After alignment all queues have the same size, so the first one is used
        # to decide when everything has been consumed.
        queue0 = queues[0]
        while queue0.qsize() > 0:
            res = self.get_merged_data(queues, encode)
            out_queue.put(res)

    def rebatch_all_merged_data(self, model_node, in_queues, is_eval=False):# pylint: disable=unused-argument
        """Regroup each input queue so its item count matches what `model_node` consumes.

        For every (input node, queue) pair the producer's and the consumer's
        `batch_per_episode` are compared:

        * equal: the queue is reused as is;
        * more producers (many2one): `division` consecutive items are bundled into one;
        * fewer producers (one2many): every item is emitted `division` times, each
          copy tagged with `INDEX_TAG: (position, division)`.

        Args:
            model_node: consumer node.
            in_queues: list of queues, parallel to `model_node.input_nodes`.
            is_eval: unused.

        Returns:
            `in_queues` itself when the node has no input nodes, otherwise a new
            list of queues in the same order.

        Raises:
            AssertionError: if the larger count is not a multiple of the smaller one.
        """
        if not model_node.input_nodes:
            return in_queues
        out_queues = [None] * len(in_queues)
        num_consumers = self.batch_per_episode(model_node.model)
        for index, (input_node, in_queue) in enumerate(zip(model_node.input_nodes, in_queues)):
            num_producers = self.batch_per_episode(input_node.model)
            if num_producers == num_consumers:
                out_queues[index] = in_queue
            else:
                out_queues[index] = Queue()
                res_list = []
                # Drain the input queue, keeping only the decoded payloads.
                while in_queue.qsize() > 0:
                    res = in_queue.get()
                    res = decode_data(res)[1]
                    res_list.append(res)
                if num_producers > num_consumers:
                    # Deal with the case where num_producers > num_consumers
                    assert num_producers % num_consumers == 0, \
                        f"many2one: num_producers: {num_producers}, num_consumers: {num_consumers}, len inqueue: {len(in_queues)}"
                    division = num_producers // num_consumers
                    in_qsize = len(res_list)
                    out_qsize = in_qsize // division
                    for q_idx in range(out_qsize):
                        start = q_idx * division
                        end = start + division
                        out_queues[index].put(encode_data(q_idx, {CHATLEARN_REGROUP_TAG:res_list[start:end]}))
                else:
                    # Deal with the case where num_producers < num_consumers
                    # TODO: add index for one2many case
                    assert num_consumers % num_producers == 0, \
                        f"one2many: num_producers: {num_producers}, num_consumers: {num_consumers}, len inqueue: {len(in_queues)}"
                    division = num_consumers // num_producers
                    in_qsize = len(res_list)
                    out_qsize = in_qsize * division
                    for q_idx in range(out_qsize):
                        # Source item for this output slot; the slice holds exactly one element.
                        start = q_idx // division
                        end = start + 1
                        out_queues[index].put(encode_data(q_idx, {CHATLEARN_REGROUP_TAG: res_list[start:end],
                                                                  INDEX_TAG: (q_idx % division, division)}))
        return out_queues

    def get_next_data(self, in_queue, model_node, micro_batch_index):
        """Fetch the next input for `model_node` as `(micro_batch_index, argument list)`.

        Args:
            in_queue: a single queue or a list of queues.
            model_node: consumer node; supplies the lock, the reorder buffer and
                the `trainable` flag.
            micro_batch_index: index to fetch, or None to take whatever comes first.

        Returns:
            Tuple `(mb, query)` where `query` is a list: one element per queue for
            a non-empty queue list, `[]` for an empty list, and a one-element list
            for a single queue.

        Raises:
            AssertionError: if a non-empty queue list is used with a trainable node.
        """
        if isinstance(in_queue, list):
            if len(in_queue) > 0:
                # this should happen for inference models, will trigger bug for training models
                # since training models accept a list of remote object, which has the same
                # behavior for models accept multiple inputs
                # we need to deal with it later
                assert not model_node.trainable
                data = self.get_merged_data_locked(in_queue, micro_batch_index=micro_batch_index,
                                                   model_node=model_node, trainable=model_node.trainable)
                mb, query = decode_data(data)
            else:
                # No inputs at all: nothing to read, call with no positional arguments.
                mb, query = micro_batch_index, []
        else:
            data = self.get_merged_data_locked([in_queue], micro_batch_index=micro_batch_index,
                                               model_node=model_node, trainable=model_node.trainable)
            # A single queue yields a one-element list; unwrap it before decoding.
            assert len(data['data']) == 1
            data['data'] = data['data'][0]
            mb, query = decode_data(data)
            query = [query]
        return mb, query

    def generate_step_one_model_internal(self, model_node, in_queue, step_num, replica, func_name="forward_step",
                                         to_empty_cache=None, is_eval=False, to_onload=None, to_offload=None,
                                         micro_batch_index=None):
        """Issue the remote call(s) of one step on one replica.

        For a replica whose model is a `VLLMModuleV2`, one input is fetched and a
        single call is made on `replica.vllm_engine`. Otherwise one input is fetched
        per entry of `replica.dp_rank_to_actors` and the call is made on every actor
        of that entry.

        Args:
            model_node: flow node; `model_node.model` is the model being run.
            in_queue: a queue or list of queues to read inputs from.
            step_num: int, index of the current step.
            replica: current model replica.
            func_name: str, name of the remote function to call.
            to_empty_cache: None or boolean; forwarded as a keyword when not None.
            is_eval: forwarded as a keyword when not None.
            to_onload: None or boolean; forwarded as a keyword when not None.
            to_offload: None or boolean; forwarded as a keyword when not None.
            micro_batch_index: index of the micro-batch to fetch, or None.

        Returns:
            List of `(ret, mb)` tuples, one per remote call, where `ret` is the
            value returned by `replica.call_actor_remote_func` and `mb` is the
            micro-batch index of the input that was used.
        """
        model = model_node.model
        kwargs = {}

        replica_num = len(model.replicas)
        output = []
        if isinstance(replica.model, VLLMModuleV2):
            # True for the last `replica_num` steps.
            last_step_start = max(self.num_iteration(model) - replica_num, 0)
            is_last_batch = step_num >= last_step_start
            kwargs["is_last_batch"] = is_last_batch
            if is_eval is not None:
                kwargs["is_eval"] = is_eval
            if to_empty_cache is not None:
                kwargs["to_empty_cache"] = to_empty_cache
            if to_onload is not None:
                kwargs["to_onload"] = to_onload
            if to_offload is not None:
                kwargs["to_offload"] = to_offload
            mb, query = self.get_next_data(in_queue, model_node, micro_batch_index)
            assert isinstance(query, list)
            ret = replica.call_actor_remote_func(replica.vllm_engine, func_name, *query, **kwargs)
            output.append((ret, mb))
        else:
            # True for the last `replica_num` steps.
            last_step_start = max(self.num_iteration(model) - replica_num, 0)
            is_last_batch = step_num >= last_step_start
            kwargs["is_last_batch"] = is_last_batch
            if to_empty_cache is not None:
                kwargs["to_empty_cache"] = to_empty_cache
            if to_onload is not None:
                kwargs["to_onload"] = to_onload
            if to_offload is not None:
                kwargs["to_offload"] = to_offload
            if is_eval is not None:
                kwargs["is_eval"] = is_eval
            for _, actors in replica.dp_rank_to_actors.items():
                # One input per dp rank, shared by all actors of that rank.
                mb, query = self.get_next_data(in_queue, model_node, micro_batch_index)
                assert isinstance(query, list)
                for actor in actors:
                    ret = replica.call_actor_remote_func(actor, func_name, *query, **kwargs)
                    output.append((ret, mb))
        return output

    def generate_step_one_model(self, model_node, replica, in_queue, out_queue, step_num, func_name="forward_step",
                                to_empty_cache=None, is_eval=False, to_onload=None, to_offload=None,
                                micro_batch_index=None):
        """Run one step on one replica and publish the selected outputs to `out_queue`.

        Args:
            model_node: flow node; `model_node.model` is the model being run.
            replica: model replica to run the step on.
            in_queue: a queue or list of queues to read inputs from.
            out_queue: a queue or list of queues; every selected result is put on
                each of them.
            step_num: int, index of the current step.
            func_name: str, name of the remote function to call.
            to_empty_cache: None or boolean.
            is_eval: forwarded to the remote call.
            to_onload: None or boolean.
            to_offload: None or boolean.
            micro_batch_index: index of the micro-batch to fetch, or None.

        Returns:
            Tuple `(out_queue, remote_refs)`; `remote_refs` holds the first element
            of every output tuple, including those not put on the queues.

        Raises:
            AssertionError: if the number of outputs is not divisible by the
                number of dp ranks (only checked when `zero_size == 1` and the
                single-output shortcut does not apply).
        """
        model = model_node.model
        # output is a list of tuple, each tuple is (remote_refs, mb)
        output = self.generate_step_one_model_internal(model_node, in_queue, step_num, replica, func_name, to_empty_cache,
                                                       is_eval, to_onload, to_offload, micro_batch_index)

        num_dp_rank = len(replica.dp_rank_to_actors)
        if model.module_args.zero_size == 1:
            # If (tp > 1 or pp > 1) and ep = 1 for current model, its `output` will be a list whose
            #   length is the number of Actors. In this case, all members in the list
            #   are the same, and we choose output[-1] to put into out_queue.
            # If (tp > 1 or pp > 1) and ep > 1, we choose last output for each dp rank to put into
            #   out_queue.
            if model.module_args.expert_model_parallel_size == 1 and num_dp_rank == 1:
                result = [output[-1]]
            else:
                num_output = len(output)
                assert num_output % num_dp_rank == 0, (
                    f"The number of outputs ({num_output}) must be divisible by "
                    f"the number of dp_ranks ({num_dp_rank}) in a replica."
                )
                interval = num_output // num_dp_rank
                # Last output of each consecutive group of `interval` outputs.
                result = [output[i] for i in range(interval - 1, num_output, interval)]
        else:
            result = output
        if isinstance(out_queue, list):
            for oq in out_queue:
                for res, mb in result:
                    oq.put(encode_data(mb, res))
        else:
            for res, mb in result:
                out_queue.put(encode_data(mb, res))
        # To ensure all Actors are finished synchronously, all remote refs should be returned
        # note that ray wait does not support tuple type, return a list of list
        remote_refs = [item[0] for item in output]
        return out_queue, remote_refs

    def regroup_inqueue(self, model_node, queues, is_eval=False):
        """Regroup the input queues of `model_node` according to the configured policy.

        Only the policy "global_barrier" is implemented; it delegates to
        `rebatch_all_merged_data`.

        Args:
            model_node: consumer node.
            queues: a queue or list of queues; a single queue is wrapped in a list.
            is_eval: forwarded to `rebatch_all_merged_data`.

        Returns:
            List of regrouped queues.

        Raises:
            RuntimeError: if `self.args.policy_to_regroup_queue` is anything else.
        """
        if self.args.policy_to_regroup_queue == "global_barrier":
            # barrier to regroup all queues of producer node
            if not isinstance(queues, list):
                queues = [queues]
            logger.info(f"{LOG_START} regroup_inqueue in_queue {model_node}:  {[ele.qsize() for ele in queues]}")
            out_queues = self.rebatch_all_merged_data(model_node, queues, is_eval=is_eval)
            logger.info(f"{LOG_START} regroup_inqueue out_queues {model_node}:  {[ele.qsize() for ele in out_queues]}")
            return out_queues
        else:
            raise RuntimeError(f"Unsupported policy_to_regroup_queue {self.args.policy_to_regroup_queue}.")

    def compute_loop_one_model(self, model_node, num_batch=None):
        """Run all steps of one model node and register its results for later waiting.

        Args:
            model_node: node to execute.
            num_batch: number of steps; defaults to `self.num_iteration(model)`.
        """
        logger.info(f"{LOG_START} start compute_loop for {model_node}, is_eval={self.is_eval}")
        model = model_node.model
        is_eval = self.is_eval

        if num_batch is None:
            num_batch = self.num_iteration(model)

        func_name = model_node.func_name
        if model_node.remote_objects_to_wait:
            logger.info(f"{LOG_START} start to wait colocate models to finish for {model_node}")
            model_node.wait_colocate_models_to_finish(self.timers, func_name)
            logger.info(f"{LOG_START} complete to wait colocate models to finish for {model_node}")
        replica_num = len(model.replicas)
        # First step index of the final `replica_num` steps.
        last_step_start = max(num_batch - replica_num, 0)
        in_queue = model_node.get_input_queues()

        logger.info(f"{LOG_START} start to regroup in_queue for {model_node}")
        in_queue = self.regroup_inqueue(model_node, in_queue, is_eval=is_eval)
        logger.info(f"{LOG_START} complete to regroup in_queue for {model_node}")

        # A single input is passed as a bare queue rather than a one-element list.
        if isinstance(in_queue, list) and len(in_queue) == 1:
            in_queue = in_queue[0]
        results = []
        logger.info(f"{LOG_START} start to generate_step_one_model for {model_node}")
        self.timers(f"{model.name}").start()
        for step in range(num_batch):
            # Onload is requested during the first `replica_num` steps, cache emptying
            # and offload during the last `replica_num` steps.
            # NOTE(review): since replicas are handed out round-robin, this presumably
            # hits each replica once on its first / last step - confirm.
            to_empty_cache = step >= last_step_start and (model.is_colocate or model.module_args.force_free_memory)
            to_onload = step < replica_num and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
            to_offload = step >= last_step_start and ((model.is_colocate and model.enable_offload) or model.module_args.force_free_memory)
            replica = self._next_model(model)
            _, data = self.generate_step_one_model(model_node, replica, in_queue, model_node.out_queues, step, func_name, to_empty_cache,
                                                   is_eval=is_eval, to_onload=to_onload, to_offload=to_offload)
            results.append(data)
        self.timers(f"{model.name}").stop()
        if model_node.next_colocate_node:
            # before the execution of next colocate model, perform the wait, since we want to empty the cache.
            logger.info(
                f"{LOG_START} Model {model_node.next_colocate_node} will wait model {model} to finish since they are colocated")
            self._models_and_results_to_wait = model_node.next_colocate_node.add_dependent_colocate_model_results(
                model_node, results, self._models_and_results_to_wait)
        elif model.colocate_models or model.trainable:
            # 1. the model may colocate with training/inference, so we should wait until the end of compute_loop
            # 2. the model is trainable and it does not have next_colocate_model, we should make sure it is finished before parameter_sync
            # so we add them to a temp list
            logger.info(f"{LOG_START} Sync {model} in the end of {self.__class__.__name__}")
            self._models_and_results_to_wait.append((model_node, results))

    def compute_loop(self, out_queue, num_batch=None):
        """Execute every node of the flow, wait for pending results and collect the outputs.

        Args:
            out_queue: queue receiving the merged outputs of the return nodes.
            num_batch: number of steps per node, forwarded to
                `compute_loop_one_model`.
        """
        for model_group in self.model_flow.flow_topology:
            for model_node in model_group:
                self.compute_loop_one_model(model_node, num_batch)

        data = [None] * len(self.model_flow.return_model_nodes)
        for model_node in self.model_flow.model_nodes:
            self.timers(f"{model_node.model.name}").start()
            if model_node in self.model_flow.return_model_nodes:
                # let the results order follow model_node order
                data[self.model_flow.return_model_nodes.index(model_node)] = model_node.out_queues[-1]
            self.timers(f"{model_node.model.name}").stop()
        model_names = []
        results = []
        for model, result in self._models_and_results_to_wait:
            model_names.append(model.name)
            results.extend(result)
        if results:
            for model_name in model_names:
                self.timers(f"{model_name}").start()
            # The wait message uses the function name of the first node only.
            func_name = self.model_flow.model_nodes[0].func_name
            future.wait(results, f"{model_names} {func_name}")
            for model_name in model_names:
                self.timers(f"{model_name}").stop()
            self._models_and_results_to_wait = []
        if data:
            self.get_all_merged_data(data, out_queue, encode=False)

    def setup_queues(self):
        """Create the input, inter-node and final output queues for the traced flow.

        Every node in `model_flow.input_consumers` gets its own input queue. Every
        node gets one output queue per output node, plus one extra if it is a
        return node.

        Returns:
            Tuple `(data_queues, out_queue)`: the input queues in
            `input_consumers` order, and a fresh queue for the final results.
        """
        data_queues = []
        out_queue = Queue()
        for model_node in self.model_flow.input_consumers:
            data_queue = Queue()
            data_queues.append(data_queue)
            model_node.set_input_queue(data_queue)
        for model_node in self.model_flow.model_nodes:
            num_out_queue = len(model_node.output_nodes)
            if model_node in self.model_flow.return_model_nodes:
                # Extra queue for the results returned to the caller; compute_loop()
                # reads it as out_queues[-1].
                num_out_queue += 1
            model_node.set_out_queues([Queue() for _ in range(num_out_queue)])
        return data_queues, out_queue

# pylint: disable=not-callable