chatlearn/runtime/model_flow.py (171 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Model FLow""" from collections import defaultdict, deque from chatlearn.utils import future from chatlearn.utils.global_vars import unwrap_func from chatlearn.utils.global_vars import reset_dependencies, set_dependencies, get_dependencies from chatlearn.utils.utils import flatten from .decorator import decorate_class_func class ControlDependencies: """ControlDependencies""" def __init__(self, dependencies): if not isinstance(dependencies, list): dependencies = [dependencies] self.dependencies = dependencies def __enter__(self): set_dependencies(self.dependencies) return self def __exit__(self, exc_type, exc_value, traceback): reset_dependencies() class DummyData: """DummyData to trace ModelGraph""" def __init__(self, from_node=None): self.from_node = from_node self.to_nodes = [] class ModelNode: """ModelNode""" def __init__(self, model, func_name): self.model = model self.name = model.name self.func_name = func_name self.input_nodes = [] self.output_nodes = [] self.out_queues = None self._input_queue = None # next colocate model node to execute self.next_colocate_node = None # model to wait before the execution of current model self.models_to_wait = [] # remote objects to wait before the execution of current model self.remote_objects_to_wait = [] self.dependent_output_nodes = [] self.trainable = False def add_input_node(self, node): if node in self.input_nodes: raise RuntimeError(f"{node} already added to {self} inputs") self.input_nodes.append(node) node.add_output_node(self) def add_output_node(self, model): self.output_nodes.append(model) def set_out_queues(self, queues): self.out_queues = queues def set_input_queue(self, queue): self._input_queue = queue def get_input_queues(self): input_queues = [] if self._input_queue is not None: input_queues.append(self._input_queue) for input_model_node in self.input_nodes: out_index = input_model_node.output_nodes.index(self) input_queues.append(input_model_node.out_queues[out_index]) if len(input_queues) == 1: return input_queues[0] return input_queues def _find_all_parents(self, model, prev_models_results): parents_models = [] parents_results = [] queue = deque([model]) visited = set() while queue: cur_model = queue.pop() if cur_model in visited: continue visited.add(cur_model) for prev_model, results in prev_models_results: if prev_model in cur_model.input_nodes and prev_model not in parents_models: parents_models.append(prev_model) parents_results.append(results) queue.append(prev_model) # reverse return parents_models[::-1], parents_results[::-1] def add_dependent_colocate_model_results(self, model, remote_objects, models_and_results_to_wait): # for models that are not colocated with current model, if their colocated model need to wait # the parent of their colocated model also need to wait dependent_models_not_colocate, dependent_results_not_colocate = self._find_all_parents(model, models_and_results_to_wait) models_and_results_to_wait2 = [(model, results) for model, results in models_and_results_to_wait \ if model not in dependent_models_not_colocate] for prev_model, result in zip(dependent_models_not_colocate, dependent_results_not_colocate): self.models_to_wait.append(prev_model) self.remote_objects_to_wait.extend(result) self.models_to_wait.append(model) self.remote_objects_to_wait.extend(remote_objects) return models_and_results_to_wait2 def wait_colocate_models_to_finish(self, timers, func_name): for model in self.models_to_wait: timers(f"{model.name}").start() future.wait(self.remote_objects_to_wait, f"{[model.name for model in self.models_to_wait]} {func_name}") for model in self.models_to_wait: timers(f"{model.name}").stop() self.remote_objects_to_wait = [] self.models_to_wait = [] def __str__(self): return f"{self.__class__.__name__}({self.model}) {self.func_name}" def __repr__(self): return f'<{self.__class__.__name__}({self.model}) {self.func_name} object at {hex(id(self))}>' class ModelFlow: """ModelFlow""" def __init__(self, cls): self.model_nodes = [] self.return_model_nodes = [] self.cls = cls # models that consumes input data self.input_consumers = [] def fake_compute(self, fn): def inner(*args): assert len(args) > 0 original_fn = unwrap_func(fn) func_name = original_fn.__name__ model_node = ModelNode(args[0], func_name) dist_model = self.name2remote_model[model_node.name] model_node.model = dist_model dist_model.model_node = model_node self.model_nodes.append(model_node) for data in args[1:]: if isinstance(data, DummyData): data.to_nodes.append(model_node) if data.from_node: model_node.add_input_node(data.from_node) dependencies = get_dependencies() if dependencies is not None: for dep in dependencies: dep.from_node.dependent_output_nodes.append(model_node) res = DummyData(model_node) return res return inner def trace(self, models, compute_flow): """ Trace the model compute_flow to get model graph. Args ---- models: List(DistModel) a list of DistModel compute_flow: callable compute_flow function """ local_models = [model.replicas[0].model for model in models] self.name2remote_model = {model.name: model for model in models} for model in local_models: for func_name in self.cls.model_to_call_funcs[model]: decorate_class_func(model.__class__, func_name, self.fake_compute) dummy_data = DummyData() assert compute_flow is not None dummy_output = compute_flow(dummy_data) # convert decorator back for model in local_models: for func_name in self.cls.model_to_call_funcs[model]: setattr(model.__class__, func_name, unwrap_func(getattr(model.__class__, func_name), level=1)) if dummy_output: if isinstance(dummy_output, DummyData): dummy_output = [dummy_output] for do in dummy_output: self.return_model_nodes.append(do.from_node) self.input_consumers = dummy_data.to_nodes self.flow_topology = self.topological_sort() self.model_nodes = flatten(self.flow_topology) for i, current_node in enumerate(self.model_nodes): for j in range(i + 1, len(self.model_nodes)): next_node = self.model_nodes[j] # if current_node and next_node share the same model, then thay are colocated if current_node.model.colocate_with(next_node.model) or current_node.model is next_node.model: current_node.next_colocate_node = next_node break def topological_sort(self): result = [] level_map = defaultdict(list) in_degree = defaultdict(int) # Calculate the in-degree of each vertex for u in self.model_nodes: for v in u.output_nodes: in_degree[v] += 1 for v in u.dependent_output_nodes: in_degree[v] += 1 # Enqueue all the vertices with an in-degree of 0 queue = deque([u for u in self.model_nodes if in_degree[u] == 0]) # Perform topological sorting while queue: current_level = [] for _ in range(len(queue)): current = queue.popleft() current_level.append(current) result.append(current) # Decrement the in-degree of adjacent vertices for v in current.output_nodes + current.dependent_output_nodes: in_degree[v] -= 1 if in_degree[v] == 0: queue.append(v) level_map[len(result)].extend(current_level) # Check if the graph contains a cycle if len(result) != len(self.model_nodes): raise RuntimeError("Please check if the graph contains a cycle") return [v[1] for v in sorted(level_map.items())]