def assemble()

in nni/retiarii/execution/logical_optimizer/logical_plan.py


    def assemble(self, multi_model_placement: Dict[Model, Device]) \
            -> Tuple[Model, Dict[Node, Device]]:
        """
        Given a set of models to be assembled into one physical model and their device placements,
        this function replaces every logical node in this LogicalPlan with an executable physical
        node for the physical model.

        Parameters
        ----------
        multi_model_placement : dict
            a dict mapping each model to its device placement.
            These models will be assembled into the same physical model to run.

        Returns
        -------
        phy_model : Model
            the physical model assembled from the models in `multi_model_placement`,
            with all logical nodes replaced by physical nodes
        node_placements : dict
            the device placement of each node in `phy_model`
        """
        phy_model = Model(_internal=True)
        phy_graph = self.lp_model.root_graph._fork_to(phy_model)
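        # rename the forked root graph to '_model', the root-graph name used by the merged physical model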
        phy_graph._rename_graph(phy_graph.name, "_model")

        # merge sub-graphs
        for model in multi_model_placement:
            if phy_model.evaluator is None and model.evaluator is not None:
                phy_model.evaluator = model.evaluator
            for graph_name in model.graphs:
                if graph_name != model._root_graph_name:
                    new_graph = model.graphs[graph_name]._fork_to(
                        phy_model, name_prefix=f'M_{model.model_id}_')

                    # the `M_<model_id>_` prefix is added to the cell names of hidden nodes in non-root graphs here
                    for new_node in new_graph.hidden_nodes:
                        if isinstance(new_node.operation, Cell):
                            old_cell_name = new_node.operation.cell_name
                            new_node.operation = copy.deepcopy(new_node.operation)
                            new_node.operation.cell_name = f'M_{model.model_id}_{old_cell_name}'

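        # at least one of the merged models must provide an evaluator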
        assert phy_model.evaluator is not None

        # When replacing logical nodes, merge the training configs as the
        # input/output nodes are replaced.
        evaluator_slot = {}  # Model ID -> Slot ID
        input_slot_mapping = {}   # input node -> its slot in the merged input node
        output_slot_mapping = {}  # output node -> its slot in the merged output node
        # Replace all logical nodes to executable physical nodes
        hidden_nodes = phy_graph.hidden_nodes.copy()  # copy: nodes are removed while iterating
        node_placements = {}

        added_models = []

        for node in hidden_nodes:
            if isinstance(node, OriginNode):
                model_id = node.original_graph.model.model_id
                if node.original_graph.model not in multi_model_placement:
                    for edge in node.incoming_edges:
                        edge.remove()
                    for edge in node.outgoing_edges:
                        edge.remove()
                    node.remove()
                    continue

            if isinstance(node, AbstractLogicalNode):
                new_node, placement = node.assemble(multi_model_placement)
                if isinstance(new_node.operation, _IOPseudoOperation):
                    model_id = new_node.graph.model.model_id
                    if model_id not in evaluator_slot:
                        added_models.append(model_id)
                        evaluator_slot[model_id] = len(added_models) - 1
                    slot = evaluator_slot[model_id]
                    # If a model's inputs/outputs are not used in the multi-model,
                    # the codegen and trainer should neither generate nor use them.
                    # "use_input" and "use_output" mark whether an input/output
                    # of a model is used in the multi-model.
                    if new_node.operation.type == '_inputs':
                        input_slot_mapping[new_node] = slot
                    if new_node.operation.type == '_outputs':
                        output_slot_mapping[new_node] = slot

                self.node_replace(node, new_node)

                # the `M_<model_id>_` name prefix of cells in the root graph's hidden nodes is added here
                # FIXME: merge this rename with non-root graph, only do once.
                if isinstance(new_node.operation, Cell):
                    old_cell_name = new_node.operation.cell_name
                    new_node.operation = copy.deepcopy(new_node.operation)
                    new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'

                # inputs must stay on CPU; a ToDevice op moves them to GPU later if necessary
                if isinstance(new_node.operation, _IOPseudoOperation) and new_node.operation.type == '_inputs':
                    # hack: only single-server placement is supported
                    node_placements[new_node] = CPUDevice(node_id=placement.node_id)
                else:
                    node_placements[new_node] = placement

                node.remove()

        # If two connected nodes are placed on different devices, insert a ToDevice op to copy the data
        # TODO: when copying one node to multiple devices, broadcast is more efficient than P2P communication
        existing_edges = phy_graph.edges.copy()
        # Avoid copying the same node multiple times to the same device
        copied_op: Dict[Tuple[Node, Device], Node] = {}
        for edge in existing_edges:
            head_placement = node_placements[edge.head]
            tail_placement = node_placements[edge.tail]
            if head_placement != tail_placement:
                if head_placement.node_id != tail_placement.node_id:
                    raise ValueError('Cross-server placement is not supported.')
                # Same server, different devices
                if (edge.head, tail_placement) in copied_op:
                    to_node = copied_op[(edge.head, tail_placement)]
                else:
                    dst_name = edge.head.name + "_to_" + edge.tail.name
                    to_operation = Operation.new(
                        'ToDevice', {
                            "device": tail_placement, "src": (
                                edge.head.name, edge.head_slot), "dst": dst_name})
                    to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
                    Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
                    copied_op[(edge.head, tail_placement)] = to_node
                    node_placements[to_node] = head_placement
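                # re-point the edge so it now originates from the copy on the tail's device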
                edge.head = to_node
                edge.head_slot = None

        # merge all input nodes into one with multiple slots
        input_nodes = []
        for node in phy_graph.hidden_nodes:
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_inputs':
                input_nodes.append(node)

        for edge in phy_graph.edges:
            if edge.head in input_nodes:
                edge.head_slot = input_slot_mapping[edge.head]
                edge.head = phy_graph.input_node

        # merge all output nodes into one with multiple slots
        output_nodes = []
        for node in phy_graph.hidden_nodes:
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
                output_nodes.append(node)

        for edge in phy_graph.edges:
            if edge.tail in output_nodes:
                edge.tail_slot = output_slot_mapping[edge.tail]
                edge.tail = phy_graph.output_node

        for node in input_nodes:
            node.remove()
        for node in output_nodes:
            node.remove()

        return phy_model, node_placements
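
For orientation, a minimal sketch of how `assemble` might be driven. `logical_plan`, `models`, and `GPUDevice` are illustrative assumptions, not names defined in this file:

    # hypothetical driver: place each candidate model on its own GPU of a single server
    placement = {model: GPUDevice(node_id='server0', gpu_id=i)  # GPUDevice is an assumed device class
                 for i, model in enumerate(models)}             # `models` is an assumed list of Model
    phy_model, node_placements = logical_plan.assemble(placement)
    # phy_model is the merged physical model; node_placements maps each of its
    # nodes to the device the code generator should run it on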