Status Encapsulator::RewritePass()

in ngraph_bridge/encapsulate_clusters.cc [398:554]


// Rewrites `graph` so that every cluster found by AnalysisPass is replaced by
// a single `_nGraphEncapsulate` node.
//
// Must be called exactly once, after AnalysisPass:
//  - returns errors::Internal if AnalysisPass has not run (`analysis_done`);
//  - returns errors::Internal if RewritePass already completed
//    (`rewrite_done`).
//
// Steps:
//  Pass 3: for each cluster in `device_name_map`, build a node named
//          "ngraph_cluster_<idx>". Its inputs come from `cluster_input_map`
//          and its result types from `cluster_output_dt_map`. It carries the
//          attrs `ngraph_cluster`, `Targuments`, `Tresults`,
//          `ngraph_graph_id`, every (key, value) pair in `device_config`, and
//          `_ngraph_static_inputs`. That last attr is computed from the
//          cluster's GraphDef held by ClusterManager.
//  Pass 4: reroute control edges that cross a cluster boundary onto the
//          encapsulate node(s). Redirect data edges from a clustered producer
//          to an unclustered consumer via `output_remap_map`.
//  Pass 6: delete every op node that still carries `_ngraph_cluster`.
//
// Params:
//   graph_id       value stored in each new node's `ngraph_graph_id` attr.
//   device_config  optional extra attrs attached verbatim to every new node.
// Returns Status::OK() on success; otherwise the first error encountered. On
// error the graph may be partially rewritten and `rewrite_done` stays false.
Status Encapsulator::RewritePass(
    int graph_id,
    const std::unordered_map<std::string, std::string>& device_config) {
  if (!analysis_done) {
    return errors::Internal(
        "In Encapsulator, called RewritePass without calling AnalysisPass");
  }
  if (rewrite_done) {
    return errors::Internal(
        "In Encapsulator, called RewritePass more than once");
  }

  // Pass 3: Create encapsulation nodes for all clusters.
  for (const auto& kv : device_name_map) {
    const int cluster_idx = kv.first;
    const auto& device_name = kv.second;
    const string encap_node_name =
        "ngraph_cluster_" + std::to_string(cluster_idx);

    // operator[] (not find) on purpose: a cluster with no recorded inputs
    // yields an empty list.
    const auto& cluster_inputs = cluster_input_map[cluster_idx];
    std::vector<DataType> input_types;
    std::vector<NodeBuilder::NodeOut> inputs;
    input_types.reserve(cluster_inputs.size());
    inputs.reserve(cluster_inputs.size());

    for (const auto& tup : cluster_inputs) {
      int src_node_id;
      int src_output_idx;
      DataType dt;
      std::tie(src_node_id, src_output_idx, dt) = tup;

      input_types.push_back(dt);
      inputs.push_back(
          NodeBuilder::NodeOut(graph->FindNodeId(src_node_id), src_output_idx));
    }

    // The attr is "ngraph_cluster" (no leading underscore), so the new node is
    // not picked up by the "_ngraph_cluster" removal in Pass 6 below.
    NodeBuilder nb = NodeBuilder(encap_node_name, "_nGraphEncapsulate")
                         .Attr("ngraph_cluster", cluster_idx)
                         .Attr("Targuments", input_types)
                         .Attr("Tresults", cluster_output_dt_map[cluster_idx])
                         .Attr("ngraph_graph_id", graph_id)
                         .Device(device_name)
                         .Input(inputs);
    if (!device_config.empty()) {
      NGRAPH_VLOG(3) << "Device config is not empty";
      for (const auto& attr : device_config) {
        // Adding the optional attributes
        NGRAPH_VLOG(3) << "Attaching Attribute " << attr.first << " Val "
                       << attr.second;
        nb.Attr(attr.first, attr.second);
      }
    }

    // Find static inputs from the cluster's own graph and record their
    // indexes as an attribute on the encapsulate node.
    GraphDef* gdef_for_current_encapsulate =
        ClusterManager::GetClusterGraph(cluster_idx);
    if (gdef_for_current_encapsulate == nullptr) {
      return errors::Internal(
          "Did not find encapsulated graph in cluster manager for node ",
          encap_node_name);
    }
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    Graph graph_for_current_encapsulate(OpRegistry::Global());
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        opts, *gdef_for_current_encapsulate, &graph_for_current_encapsulate));

    vector<int> static_input_indexes;
    TF_RETURN_IF_ERROR(
        GetStaticInputs(&graph_for_current_encapsulate, &static_input_indexes));
    nb.Attr("_ngraph_static_inputs", static_input_indexes);

    Node* n = nullptr;
    TF_RETURN_IF_ERROR(nb.Finalize(graph, &n));
    n->set_assigned_device_name(device_name);

    cluster_node_map[cluster_idx] = n;
  }

  // Pass 4: Remap all non-clustered inputs that are reading from
  // encapsulated edges, and all control edges that cross cluster
  // boundaries.

  // Copy the edge pointers, so as not to invalidate the iterator while edges
  // are removed and added below.
  std::vector<Edge*> edges;
  for (auto edge : graph->edges()) {
    edges.push_back(edge);
  }

  for (auto edge : edges) {
    // Initialize explicitly: GetNodeCluster's behaviour on failure is not
    // visible here, so never read these unless the matching *_clustered
    // flag is true.
    int src_cluster_idx = -1;
    const bool src_clustered =
        (GetNodeCluster(edge->src(), &src_cluster_idx) == Status::OK());
    int dst_cluster_idx = -1;
    const bool dst_clustered =
        (GetNodeCluster(edge->dst(), &dst_cluster_idx) == Status::OK());

    // Nothing to rewrite for edges entirely outside any cluster, or entirely
    // inside a single cluster.
    if (!src_clustered && !dst_clustered) {
      continue;
    }
    if (src_clustered && dst_clustered && src_cluster_idx == dst_cluster_idx) {
      continue;
    }

    if (edge->IsControlEdge()) {
      // Replace each clustered endpoint with its encapsulate node. Capture
      // both endpoints before RemoveControlEdge, which invalidates `edge`.
      Node* new_src =
          src_clustered ? cluster_node_map[src_cluster_idx] : edge->src();
      Node* new_dst =
          dst_clustered ? cluster_node_map[dst_cluster_idx] : edge->dst();
      graph->RemoveControlEdge(edge);
      graph->AddControlEdge(new_src, new_dst);
      continue;
    }

    // This is handled at a later stage (TODO(amprocte): explain)
    // NOTE(review): presumably data edges into a clustered consumer are
    // covered by cluster_input_map (used in Pass 3), since that consumer is
    // deleted in Pass 6 — confirm.
    if (dst_clustered) {
      continue;
    }

    auto it = output_remap_map.find(
        std::make_tuple(edge->src()->id(), edge->src_output()));
    if (it == output_remap_map.end()) {
      continue;
    }

    int cluster_idx;
    int cluster_output;
    std::tie(cluster_idx, cluster_output) = it->second;

    // Make the unclustered consumer read from the encapsulate node's output.
    TF_RETURN_IF_ERROR(
        graph->UpdateEdge(cluster_node_map[cluster_idx], cluster_output,
                          edge->dst(), edge->dst_input()));
  }

  // Pass 6: Remove clustered nodes from the graph. Collect first, then
  // remove, so that graph->op_nodes() is not mutated while being iterated.
  std::vector<Node*> nodes_to_remove;
  for (auto node : graph->op_nodes()) {
    int cluster_idx;
    if (GetNodeAttr(node->attrs(), "_ngraph_cluster", &cluster_idx) !=
        Status::OK()) {
      continue;
    }
    nodes_to_remove.push_back(node);
  }

  for (auto node : nodes_to_remove) {
    NGRAPH_VLOG(4) << "Removing: " << node->name();
    graph->RemoveNode(node);
  }

  rewrite_done = true;
  return Status::OK();
}