in grappler/auto_mixed_precision.cc [531:613]
// Builds the type-attribute-level topology view of `graph`.
//
// Every (node, TypeAttrId) pair returned by `node_type_map.GetTypeAttrs(node)`
// becomes one vertex of the view. Vertices are numbered in discovery order
// (node order in `graph`, then attr order per node).
//
// Edges are derived per vertex V:
//   - For each input port of V (from `node_type_map.GetInputPorts`), the
//     producing node is looked up by name.
//   - The producer's output type attr for that tensor index is taken from
//     `node_type_map.GetOutputTypeAttr`.
//   - An edge U -> V is recorded, where U is the vertex for that
//     (producer node, type attr) pair.
// Fanin and fanout lists are sorted and deduplicated.
//
// The view stores pointers into `graph` (to the GraphDef and to its NodeDefs),
// so `graph` must outlive this object.
//
// Returns:
//   InvalidArgument in these cases:
//     - The view is already initialized.
//     - An input names a node that is not in `graph`, unless
//       `skip_invalid_edges_` is set.
//     - The producer's (node, type attr) pair has no vertex, unless
//       `skip_invalid_edges_` is set.
//   When `skip_invalid_edges_` is set, such edges are dropped and logged at
//   VLOG(3) instead.
//   `graph_` is assigned before any edge is processed, so on an error return
//   the view is left partially initialized.
//   OK otherwise.
Status GraphTypeTopologyView::InitializeFromGraph(
    const GraphDef& graph, const NodeTypeAttrMap& node_type_map) {
  if (graph_ != nullptr) {
    return errors::InvalidArgument(
        "GraphTypeTopologyView is already initialized.");
  }
  graph_ = &graph;
  const int num_nodedefs = graph.node_size();

  // Build maps from name to index.
  node_name_to_index_.rehash(num_nodedefs);
  node_type_attrs_.reserve(num_nodedefs);         // Only approximate.
  node_type_name_to_index_.rehash(num_nodedefs);  // Only approximate.
  for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) {
    const NodeDef& node = graph.node(node_idx);
    // emplace does not overwrite: a repeated node name keeps its first index.
    node_name_to_index_.emplace(node.name(), node_idx);
    for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
      int node_type_idx = node_type_attrs_.size();
      node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr),
                                       node_type_idx);
      node_type_attrs_.emplace_back(&node, type_attr);
    }
  }
  num_nodes_ = node_type_attrs_.size();
  fanins_.resize(num_nodes_);
  fanouts_.resize(num_nodes_);

  // Add graph edges to the adjacency lists.
  for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
    const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx);
    auto input_ports =
        node_type_map.GetInputPorts(*node_type.node, node_type.type_attr);
    fanins_[node_type_idx].reserve(input_ports.size());
    for (int port : input_ports) {
      const string& input = node_type.node->input(port);
      TensorId tensor = ParseTensorName(input);

      // Resolve the producing node by name.
      const auto it = node_name_to_index_.find(tensor.node());
      if (it == node_name_to_index_.end()) {
        const string error_message = absl::StrCat(
            "Non-existent input ", input, " in node ", node_type.node->name());
        if (!skip_invalid_edges_) {
          return errors::InvalidArgument(error_message);
        }
        VLOG(3) << "Skip error: " << error_message;
        continue;
      }

      // Resolve the producer's (node, type attr) vertex for this output.
      const NodeDef& input_node = graph.node(it->second);
      TypeAttrId input_type_attr =
          node_type_map.GetOutputTypeAttr(input_node, tensor.index());
      const auto it2 = node_type_name_to_index_.find(
          NodeTypeKey(input_node.name(), input_type_attr));
      if (it2 == node_type_name_to_index_.end()) {
        const string error_message =
            absl::StrCat("Did not find type attr ",
                         input_type_attr.DebugString(), " in node ",
                         input_node.name());
        if (!skip_invalid_edges_) {
          return errors::InvalidArgument(error_message);
        }
        // Log skipped edges the same way as the missing-node case above.
        VLOG(3) << "Skip error: " << error_message;
        continue;
      }

      const int input_node_type_idx = it2->second;
      fanins_[node_type_idx].push_back(input_node_type_idx);
      fanouts_[input_node_type_idx].push_back(node_type_idx);
    }
    // Dedup the input list while it's still hot in cache.
    SortAndRemoveDuplicates(&fanins_[node_type_idx]);
  }

  // Dedup outputs for all the graph nodes.
  for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
    SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
  }
  return Status::OK();
}