in chatlearn/runtime/model_flow.py [0:0]
def trace(self, models, compute_flow):
"""
Trace the model compute_flow to get model graph.
Args
----
models: List(DistModel)
a list of DistModel
compute_flow: callable
compute_flow function
"""
    local_models = [model.replicas[0].model for model in models]
    self.name2remote_model = {model.name: model for model in models}
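    # temporarily decorate each registered call func with fake_compute so that
    # invoking it records graph nodes instead of running real computation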
    for model in local_models:
        for func_name in self.cls.model_to_call_funcs[model]:
            decorate_class_func(model.__class__, func_name, self.fake_compute)
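    # drive the flow once with dummy data to capture the call graph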
    dummy_data = DummyData()
    assert compute_flow is not None
    dummy_output = compute_flow(dummy_data)
    # restore the original functions by removing the tracing decorator
    for model in local_models:
        for func_name in self.cls.model_to_call_funcs[model]:
            setattr(model.__class__, func_name, unwrap_func(getattr(model.__class__, func_name), level=1))
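    # record which model nodes produce the flow's final outputs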
    if dummy_output:
        if isinstance(dummy_output, DummyData):
            dummy_output = [dummy_output]
        for do in dummy_output:
            self.return_model_nodes.append(do.from_node)
    self.input_consumers = dummy_data.to_nodes
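    # order the traced nodes by their data dependencies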
    self.flow_topology = self.topological_sort()
    self.model_nodes = flatten(self.flow_topology)
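    # link each node to the next node that shares its model or placement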
    for i, current_node in enumerate(self.model_nodes):
        for j in range(i + 1, len(self.model_nodes)):
            next_node = self.model_nodes[j]
            # if current_node and next_node are colocated or share the same model, link them
            if current_node.model.colocate_with(next_node.model) or current_node.model is next_node.model:
                current_node.next_colocate_node = next_node
                break
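
For context, a minimal sketch of the kind of compute_flow this method traces. The model handles (policy, reward) and their forward_step calls are illustrative assumptions, not taken from this file:

def compute_flow(batch):
    # during tracing, each model call is intercepted by fake_compute,
    # so only the dependency batch -> policy -> reward is recorded
    policy_out = policy.forward_step(batch)
    return reward.forward_step(policy_out)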