in ngraph_bridge/encapsulate_clusters.cc [398:554]
// Rewrites `graph` so that each cluster recorded in this Encapsulator is
// replaced by a single "_nGraphEncapsulate" node.
//
// Must be called exactly once, and only after AnalysisPass; otherwise an
// Internal error is returned. On success `rewrite_done` is set. If an error
// is returned part-way through, the graph may be left partially rewritten.
//
// Args:
//   graph_id:      stored on every encapsulate node as "ngraph_graph_id".
//   device_config: optional extra attributes (name -> string value) attached
//                  verbatim to every encapsulate node; may be empty.
//
// Passes performed here:
//   Pass 3: create one encapsulate node per cluster in `device_name_map`.
//   Pass 4: reattach control edges that cross a cluster boundary, and data
//           edges leaving a cluster, to the encapsulate nodes.
//   Pass 6: delete the original clustered nodes.
Status Encapsulator::RewritePass(
    int graph_id,
    const std::unordered_map<std::string, std::string>& device_config) {
  if (!analysis_done) {
    return errors::Internal(
        "In Encapsulator, called RewritePass without calling AnalysisPass");
  }
  if (rewrite_done) {
    return errors::Internal(
        "In Encapsulator, called RewritePass more than once");
  }

  // Pass 3: Create encapsulation nodes for all clusters.
  for (const auto& kv : device_name_map) {
    const int cluster_idx = kv.first;
    const auto& device_name = kv.second;
    const std::string encap_node_name =
        "ngraph_cluster_" + std::to_string(cluster_idx);

    // Collect the dtype and producer of each cluster input, in the order
    // stored in cluster_input_map.
    std::vector<DataType> input_types;
    std::vector<NodeBuilder::NodeOut> inputs;
    for (const auto& tup : cluster_input_map[cluster_idx]) {
      int src_node_id;
      int src_output_idx;
      DataType dt;
      std::tie(src_node_id, src_output_idx, dt) = tup;
      input_types.push_back(dt);
      inputs.push_back(
          NodeBuilder::NodeOut(graph->FindNodeId(src_node_id), src_output_idx));
    }

    NodeBuilder nb = NodeBuilder(encap_node_name, "_nGraphEncapsulate")
                         .Attr("ngraph_cluster", cluster_idx)
                         .Attr("Targuments", input_types)
                         .Attr("Tresults", cluster_output_dt_map[cluster_idx])
                         .Attr("ngraph_graph_id", graph_id)
                         .Device(device_name)
                         .Input(inputs);

    // Attach the optional, caller-supplied attributes.
    if (!device_config.empty()) {
      NGRAPH_VLOG(3) << "Device config is not empty";
      for (const auto& attr : device_config) {
        NGRAPH_VLOG(3) << "Attaching Attribute " << attr.first << " Val "
                       << attr.second;
        nb.Attr(attr.first, attr.second);
      }
    }

    // Find the static inputs of this cluster's subgraph (held by
    // ClusterManager) and record their indexes as an attribute.
    GraphDef* gdef_for_current_encapsulate =
        ClusterManager::GetClusterGraph(cluster_idx);
    if (gdef_for_current_encapsulate == nullptr) {
      return errors::Internal(
          "Did not find encapsulated graph in cluster manager for node ",
          encap_node_name);
    }
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    Graph graph_for_current_encapsulate(OpRegistry::Global());
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        opts, *gdef_for_current_encapsulate, &graph_for_current_encapsulate));
    std::vector<int> static_input_indexes;
    TF_RETURN_IF_ERROR(
        GetStaticInputs(&graph_for_current_encapsulate, &static_input_indexes));
    nb.Attr("_ngraph_static_inputs", static_input_indexes);

    Node* n = nullptr;
    TF_RETURN_IF_ERROR(nb.Finalize(graph, &n));
    n->set_assigned_device_name(device_name);
    cluster_node_map[cluster_idx] = n;
  }

  // Looks up the encapsulate node created in Pass 3 for `cluster_idx`.
  // Uses find() rather than operator[] so a missing cluster yields an error
  // instead of a default-inserted nullptr being handed to the Graph API.
  auto get_encap_node = [this](int cluster_idx, Node** encap) -> Status {
    auto it = cluster_node_map.find(cluster_idx);
    if (it == cluster_node_map.end() || it->second == nullptr) {
      return errors::Internal("No encapsulate node found for cluster ",
                              cluster_idx);
    }
    *encap = it->second;
    return Status::OK();
  };

  // Pass 4: Remap all non-clustered inputs that are reading from
  // encapsulated edges, and all control edges that cross cluster
  // boundaries.
  // Copy the edge pointers, so as not to invalidate the iterator.
  std::vector<Edge*> edges;
  for (auto edge : graph->edges()) {
    edges.push_back(edge);
  }
  for (Edge* edge : edges) {
    // Initialized so the indexes are never read indeterminate; they are only
    // meaningful when the matching *_clustered flag is true.
    int src_cluster_idx = -1;
    const bool src_clustered =
        GetNodeCluster(edge->src(), &src_cluster_idx).ok();
    int dst_cluster_idx = -1;
    const bool dst_clustered =
        GetNodeCluster(edge->dst(), &dst_cluster_idx).ok();

    // Edges wholly inside one cluster go away with the cluster's nodes in
    // Pass 6. Both endpoints must really be clustered for this to apply.
    if (src_clustered && dst_clustered && src_cluster_idx == dst_cluster_idx) {
      continue;
    }

    if (edge->IsControlEdge()) {
      if (!src_clustered && !dst_clustered) {
        continue;
      }
      // Replace each clustered endpoint by its cluster's encapsulate node.
      // Endpoints are read before the edge is removed, since removal
      // invalidates `edge`.
      Node* new_src = edge->src();
      Node* new_dst = edge->dst();
      if (src_clustered) {
        TF_RETURN_IF_ERROR(get_encap_node(src_cluster_idx, &new_src));
      }
      if (dst_clustered) {
        TF_RETURN_IF_ERROR(get_encap_node(dst_cluster_idx, &new_dst));
      }
      graph->RemoveControlEdge(edge);
      graph->AddControlEdge(new_src, new_dst);
      continue;
    }

    // Data edges into a clustered node are handled at a later stage
    // (TODO(amprocte): explain).
    if (dst_clustered) {
      continue;
    }
    auto it = output_remap_map.find(
        std::make_tuple(edge->src()->id(), edge->src_output()));
    if (it == output_remap_map.end()) {
      continue;
    }
    int cluster_idx;
    int cluster_output;
    std::tie(cluster_idx, cluster_output) = it->second;
    Node* encap = nullptr;
    TF_RETURN_IF_ERROR(get_encap_node(cluster_idx, &encap));
    TF_RETURN_IF_ERROR(graph->UpdateEdge(encap, cluster_output, edge->dst(),
                                         edge->dst_input()));
  }

  // Pass 6: Remove clustered nodes from the graph. Nodes are collected first
  // so the graph is not mutated while op_nodes() is being iterated.
  std::vector<Node*> nodes_to_remove;
  for (auto node : graph->op_nodes()) {
    int cluster_idx;
    if (GetNodeAttr(node->attrs(), "_ngraph_cluster", &cluster_idx).ok()) {
      nodes_to_remove.push_back(node);
    }
  }
  for (Node* node : nodes_to_remove) {
    NGRAPH_VLOG(4) << "Removing: " << node->name();
    graph->RemoveNode(node);
  }

  rewrite_done = true;
  return Status::OK();
}