def trace()

in chatlearn/runtime/model_flow.py [0:0]


    def trace(self, models, compute_flow):
        """
        Trace the model compute_flow to get model graph.

        The call functions of every local model class are temporarily
        decorated with ``self.fake_compute`` and ``compute_flow`` is invoked
        once on a ``DummyData`` placeholder. The decoration is always undone
        afterwards, even when ``compute_flow`` raises, so the model classes
        are never left patched.

        The method populates ``self.name2remote_model``,
        ``self.return_model_nodes``, ``self.input_consumers``,
        ``self.flow_topology`` and ``self.model_nodes``, and links every node
        to the next node (in flattened topological order) it is colocated
        with through ``next_colocate_node``.

        Args
        ----
        models: List(DistModel)
            a list of DistModel
        compute_flow: callable
            compute_flow function

        Raises
        ------
        AssertionError
            if ``compute_flow`` is None
        """
        local_models = [model.replicas[0].model for model in models]
        self.name2remote_model = {model.name: model for model in models}
        # Check the flow before patching anything, so a failed assertion
        # cannot leave the model classes decorated.
        assert compute_flow is not None

        for model in local_models:
            for func_name in self.cls.model_to_call_funcs[model]:
                decorate_class_func(model.__class__, func_name, self.fake_compute)

        try:
            dummy_data = DummyData()
            dummy_output = compute_flow(dummy_data)
        finally:
            # convert decorator back; done in ``finally`` so that an exception
            # raised while tracing does not leak the patched class functions
            for model in local_models:
                for func_name in self.cls.model_to_call_funcs[model]:
                    decorated_func = getattr(model.__class__, func_name)
                    setattr(model.__class__, func_name, unwrap_func(decorated_func, level=1))

        if dummy_output:
            # a single DummyData result is handled like a one-element list
            if isinstance(dummy_output, DummyData):
                dummy_output = [dummy_output]
            for do in dummy_output:
                self.return_model_nodes.append(do.from_node)

        self.input_consumers = dummy_data.to_nodes
        self.flow_topology = self.topological_sort()
        self.model_nodes = flatten(self.flow_topology)
        for i, current_node in enumerate(self.model_nodes):
            for j in range(i + 1, len(self.model_nodes)):
                next_node = self.model_nodes[j]
                # if current_node and next_node share the same model (or their
                # models are colocated), then they are colocated; only the
                # first such later node is recorded
                if current_node.model.colocate_with(next_node.model) or current_node.model is next_node.model:
                    current_node.next_colocate_node = next_node
                    break