in nni/retiarii/execution/logical_optimizer/logical_plan.py [0:0]
def assemble(self, multi_model_placement: Dict[Model, Device]) \
-> Tuple[Model, Dict[Node, Device]]:
"""
Given a set of models to be assembled into one physical model, together with their device placements,
this function replaces every logical node in this LogicalPlan with an executable physical node
for the physical model.
Parameters
----------
multi_model_placement : dict
a dict mapping each model to its device placement.
These models will be assembled into the same physical model to run.
Returns
-------
phy_model : Model
the physical model assembled from the models in `multi_model_placement`,
with all logical nodes replaced by physical nodes
node_placements : dict
the device placement of each node in `phy_model`
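
Examples
--------
A minimal usage sketch; `model_a`, `model_b` and `dev_0`, `dev_1` are illustrative
names for `Model` and `Device` objects managed by the engine that built this plan::

    phy_model, node_placements = logical_plan.assemble(
        {model_a: dev_0, model_b: dev_1})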
"""
phy_model = Model(_internal=True)
phy_graph = self.lp_model.root_graph._fork_to(phy_model)
phy_graph._rename_graph(phy_graph.name, "_model")
# merge sub-graphs
for model in multi_model_placement:
if phy_model.evaluator is None and model.evaluator is not None:
phy_model.evaluator = model.evaluator
for graph_name in model.graphs:
if graph_name != model._root_graph_name:
new_graph = model.graphs[graph_name]._fork_to(
phy_model, name_prefix=f'M_{model.model_id}_')
# the `M_<model_id>_` prefix is added here to the cell names of hidden nodes in non-root graphs
for new_node in new_graph.hidden_nodes:
if isinstance(new_node.operation, Cell):
old_cell_name = new_node.operation.cell_name
new_node.operation = copy.deepcopy(new_node.operation)
new_node.operation.cell_name = f'M_{model.model_id}_{old_cell_name}'
assert phy_model.evaluator is not None
# When replacing logical nodes, merge the training configs
# as the input/output nodes are replaced.
evaluator_slot = {} # Model ID -> Slot ID
input_slot_mapping = {}
output_slot_mapping = {}
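# input_slot_mapping / output_slot_mapping map each model's pseudo input/output node
# to the slot it will occupy on the merged physical input/output node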
# Replace all logical nodes with executable physical nodes
hidden_nodes = phy_graph.hidden_nodes.copy()
node_placements = {}
added_models = []
for node in hidden_nodes:
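# drop origin nodes (and their incident edges) whose source model is not selected for this physical model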
if isinstance(node, OriginNode):
model_id = node.original_graph.model.model_id
if node.original_graph.model not in multi_model_placement:
for edge in node.incoming_edges:
edge.remove()
for edge in node.outgoing_edges:
edge.remove()
node.remove()
continue
if isinstance(node, AbstractLogicalNode):
new_node, placement = node.assemble(multi_model_placement)
if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id
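# each distinct model gets its own evaluator slot, assigned in the order models are first seen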
if model_id not in evaluator_slot:
added_models.append(model_id)
evaluator_slot[model_id] = len(added_models) - 1
slot = evaluator_slot[model_id]
else:
slot = evaluator_slot[model_id]
# If a model's inputs/outputs are not used in the multi-model,
# the codegen and trainer should not generate or use them.
# "use_input" and "use_output" mark whether an input/output
# of a model is used in the multi-model.
if new_node.operation.type == '_inputs':
input_slot_mapping[new_node] = slot
if new_node.operation.type == '_outputs':
output_slot_mapping[new_node] = slot
self.node_replace(node, new_node)
# the `M_<model_id>_` name prefix is added here to cells in the root graph's hidden nodes
# FIXME: merge this renaming with the non-root-graph case so it is only done once.
if isinstance(new_node.operation, Cell):
old_cell_name = new_node.operation.cell_name
new_node.operation = copy.deepcopy(new_node.operation)
new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'
# inputs should be on CPU; they are moved to GPU later via ToDevice if necessary
if isinstance(new_node.operation, _IOPseudoOperation) and new_node.operation.type == '_inputs':
# hack: only single-server placement is supported
node_placements[new_node] = CPUDevice(node_id=placement.node_id)
else:
node_placements[new_node] = placement
node.remove()
# If two connected nodes are placed on different devices, insert a ToDevice op to copy the data between them
# TODO: when copying one node's output to multiple devices, broadcast is more efficient than P2P communication
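# e.g. an edge A (device 0) -> B (device 1) is rewritten as A -> ToDevice -> B,
# where the inserted ToDevice node carries A's output onto B's device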
existing_edges = phy_graph.edges.copy()
# Avoid copying the same node to the same device more than once
copied_op: Dict[Tuple[Node, Device], Node] = {}
for edge in existing_edges:
head_placement = node_placements[edge.head]
tail_placement = node_placements[edge.tail]
if head_placement != tail_placement:
if head_placement.node_id != tail_placement.node_id:
raise ValueError('Cross-server placement is not supported.')
# Same server different devices
if (edge.head, tail_placement) in copied_op:
to_node = copied_op[(edge.head, tail_placement)]
else:
dst_name = edge.head.name + "_to_" + edge.tail.name
to_operation = Operation.new(
'ToDevice', {
"device": tail_placement, "src": (
edge.head.name, edge.head_slot), "dst": dst_name})
to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node
node_placements[to_node] = head_placement
edge.head = to_node
edge.head_slot = None
# merge all input nodes into one with multiple slots
input_nodes = []
for node in phy_graph.hidden_nodes:
if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_inputs':
input_nodes.append(node)
for edge in phy_graph.edges:
if edge.head in input_nodes:
edge.head_slot = input_slot_mapping[edge.head]
edge.head = phy_graph.input_node
# merge all output nodes into one with multiple slots
output_nodes = []
for node in phy_graph.hidden_nodes:
if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
output_nodes.append(node)
for edge in phy_graph.edges:
if edge.tail in output_nodes:
edge.tail_slot = output_slot_mapping[edge.tail]
edge.tail = phy_graph.output_node
for node in input_nodes:
node.remove()
for node in output_nodes:
node.remove()
return phy_model, node_placements