in tensorflow/tensorflow/core/grappler/grappler_item_builder.cc [284:661]
// Builds a GrapplerItem named `id` from `meta_graph`, driven by `cfg`.
//
// The item's graph starts as a copy of `meta_graph.graph_def()` and is then
// rewritten in place: placeholder shapes are updated, asset constants are
// pointed at their (possibly overridden) file paths, recorded
// "_output_shapes" are dropped, device placement / colocation constraints are
// removed when requested by `cfg`, default attributes are instantiated, and
// the graph is run through RuntimeGraphOptimizer and (optionally) pruned.
//
// Feed nodes come from `cfg.feed_nodes` and from the inputs of every signature
// def. Fetch nodes come from `cfg.fetch_nodes`, falling back to the "train_op"
// collection when none were given, plus the outputs of every signature def.
// Init ops, queue runners, keep ops and saver information are collected from
// the corresponding collections of `meta_graph`.
//
// Returns nullptr (after logging the reason in most cases) when `id` is empty,
// when no fetch node can be determined, when a feed/fetch/init node is missing
// from the final graph, when asset files are not accessible, or when any of
// the preprocessing steps fails.
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
    const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
  if (id.empty()) {
    LOG(ERROR) << "id must be non-empty.";
    return nullptr;
  }
  std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
  new_item->id = id;
  new_item->graph = meta_graph.graph_def();

  // All collection lookups below go through this map; find() is used so each
  // collection is looked up only once.
  const auto& collections = meta_graph.collection_def();

  // Fill in feed nodes from config, if any provided.
  for (const auto& feed_node : cfg.feed_nodes) {
    const string feed_name = NodeName(feed_node);
    new_item->feed.emplace_back(feed_name, Tensor());
  }
  for (const auto& fetch_node : cfg.fetch_nodes) {
    new_item->fetch.emplace_back(NodeName(fetch_node));
  }

  // Attempt to detect the fetch node(s) if they were not set explicitly.
  if (new_item->fetch.empty()) {
    const auto train_op_it = collections.find("train_op");
    if (train_op_it != collections.end() &&
        train_op_it->second.has_node_list()) {
      for (const auto& node : train_op_it->second.node_list().value()) {
        new_item->fetch.push_back(NodeName(node));
      }
    }
  }

  // Detect feed and fetch nodes from signature defs. Signatures may share same
  // inputs or outputs, so the sets below are used to add each node only once.
  std::unordered_set<string> signature_feed_nodes;
  std::unordered_set<string> signature_fetch_nodes;
  for (const auto& name_and_signature : meta_graph.signature_def()) {
    for (const auto& name_and_input : name_and_signature.second.inputs()) {
      const TensorInfo& input = name_and_input.second;
      if (input.has_coo_sparse()) {
        // Define the shapes following the comment of CooSparse.
        // TODO(yuefengz): we probably want to use different dim values for the
        // three tensors of a SparseTensor.
        int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
        TensorShape shape_1d({dim});
        TensorShape shape_2d({dim, dim});

        const string values_name =
            NodeName(input.coo_sparse().values_tensor_name());
        if (gtl::InsertIfNotPresent(&signature_feed_nodes, values_name)) {
          Tensor value_tensor(input.dtype(), shape_1d);
          InitializeTensor(input.dtype(), &value_tensor);
          new_item->feed.emplace_back(values_name, value_tensor);
        }

        // The indices and dense shape tensors are always DT_INT64, so they
        // must be initialized with their own dtype rather than with the dtype
        // of the sparse values.
        const string indices_name =
            NodeName(input.coo_sparse().indices_tensor_name());
        if (gtl::InsertIfNotPresent(&signature_feed_nodes, indices_name)) {
          Tensor indices_tensor(DT_INT64, shape_2d);
          InitializeTensor(DT_INT64, &indices_tensor);
          new_item->feed.emplace_back(indices_name, indices_tensor);
        }

        const string dense_shape_name =
            NodeName(input.coo_sparse().dense_shape_tensor_name());
        if (gtl::InsertIfNotPresent(&signature_feed_nodes,
                                    dense_shape_name)) {
          Tensor dense_shape_tensor(DT_INT64, shape_1d);
          InitializeTensor(DT_INT64, &dense_shape_tensor);
          new_item->feed.emplace_back(dense_shape_name, dense_shape_tensor);
        }
      } else {
        const string input_name = NodeName(input.name());
        if (gtl::InsertIfNotPresent(&signature_feed_nodes, input_name)) {
          TensorShape shape;
          TensorShapeProto shape_proto;
          Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
                                            &shape_proto, &shape);
          if (!s.ok()) {
            LOG(ERROR) << "Invalid shape for signature input " << input.name()
                       << ": " << s << ", skipping this input";
            return nullptr;
          }
          Tensor fake_input(input.dtype(), shape);
          InitializeTensor(input.dtype(), &fake_input);
          new_item->feed.emplace_back(input_name, fake_input);
        }
      }
    }
    for (const auto& name_and_output : name_and_signature.second.outputs()) {
      const TensorInfo& output = name_and_output.second;
      if (output.has_coo_sparse()) {
        const string values_name =
            NodeName(output.coo_sparse().values_tensor_name());
        if (gtl::InsertIfNotPresent(&signature_fetch_nodes, values_name)) {
          new_item->fetch.push_back(values_name);
        }
        const string indices_name =
            NodeName(output.coo_sparse().indices_tensor_name());
        if (gtl::InsertIfNotPresent(&signature_fetch_nodes, indices_name)) {
          new_item->fetch.push_back(indices_name);
        }
        const string dense_shape_name =
            NodeName(output.coo_sparse().dense_shape_tensor_name());
        if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
                                    dense_shape_name)) {
          new_item->fetch.push_back(dense_shape_name);
        }
      } else {
        const string output_name = NodeName(output.name());
        if (gtl::InsertIfNotPresent(&signature_fetch_nodes, output_name)) {
          new_item->fetch.push_back(output_name);
        }
      }
    }
  }

  // Reject items with unnamed feed or fetch nodes, or without any fetch node.
  for (const auto& feed : new_item->feed) {
    if (feed.first.empty()) {
      LOG(ERROR) << "Invalid feed node name skipping this input";
      return nullptr;
    }
    VLOG(1) << "Will use feed node " << feed.first;
  }
  for (const auto& fetch : new_item->fetch) {
    if (fetch.empty()) {
      LOG(ERROR) << "Invalid fetch node name skipping this input";
      return nullptr;
    }
    VLOG(1) << "Will use fetch node " << fetch;
  }
  if (new_item->fetch.empty()) {
    LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
    return nullptr;
  }

  // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
  // The reason why they are difficult to handle is because they may not intend
  // to initialize all variables that are required to run fetch nodes. We may
  // have to run restore op first.

  // Try to find initializers from variables and tables as init ops.
  for (const string& var_collection :
       {"variables", "local_variables", "model_variables",
        "trainable_variables"}) {
    const auto vars_it = collections.find(var_collection);
    if (vars_it == collections.end()) {
      continue;
    }
    for (const auto& raw_var : vars_it->second.bytes_list().value()) {
      VariableDef var;
      // A malformed entry cannot be trusted to carry a valid initializer name,
      // so it is reported and skipped instead of being silently used.
      if (!var.ParseFromString(raw_var)) {
        LOG(ERROR) << "Could not parse a variable definition in collection "
                   << var_collection << ", skipping it";
        continue;
      }
      if (!var.initializer_name().empty()) {
        new_item->init_ops.push_back(NodeName(var.initializer_name()));
      }
    }
  }

  const auto table_init_it = collections.find("table_initializer");
  if (table_init_it != collections.end() &&
      table_init_it->second.has_node_list()) {
    for (const auto& node : table_init_it->second.node_list().value()) {
      new_item->init_ops.push_back(NodeName(node));
      // Tables are initialized from files, which can take a long time. Add
      // 30 minutes to the initialization time for each table to avoid
      // timing out.
      // TODO(bsteiner): adjust the timeout based on the file size.
      new_item->expected_init_time += 30 * 60;
    }
  }

  // We keep the mapping from asset node to asset files. This should have been
  // used as feed but since asset node is usually a constant node, we will fill
  // the values of these constant nodes with their actual asset file paths.
  std::unordered_map<string, string> asset_node_to_value;

  // Assets file may have changed their directory, we assemble their new paths
  // if assets_directory_override is set. We also make sure we still can
  // access these asset files.
  if (!cfg.assets_directory_override.empty()) {
    const auto assets_it = collections.find("saved_model_assets");
    if (assets_it != collections.end()) {
      for (const auto& any_asset : assets_it->second.any_list().value()) {
        AssetFileDef asset_file_def;
        if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
                 .ok()) {
          LOG(ERROR) << "Failed to parse AssetFile.";
          continue;
        }
        string asset_filepath = io::JoinPath(cfg.assets_directory_override,
                                             asset_file_def.filename());
        if (!FilesExist({asset_filepath}, nullptr)) {
          LOG(ERROR) << "Can't access one or more of the asset files "
                     << asset_filepath << ", skipping this input";
          return nullptr;
        }
        asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
            asset_filepath;
      }
    }
  } else {
    const auto paths_it = collections.find("asset_filepaths");
    if (paths_it != collections.end()) {
      const auto& raw_paths = paths_it->second.bytes_list().value();
      std::vector<string> paths(raw_paths.begin(), raw_paths.end());
      if (!FilesExist(paths, nullptr)) {
        LOG(ERROR) << "Can't access one or more of the asset files, skipping "
                      "this input";
        return nullptr;
      }
    }
  }

  const auto queue_runners_it = collections.find("queue_runners");
  if (queue_runners_it != collections.end()) {
    for (const auto& raw : queue_runners_it->second.bytes_list().value()) {
      QueueRunnerDef queue_runner;
      if (!queue_runner.ParseFromString(raw)) {
        LOG(ERROR) << "Could not parse queue_runners, skipping this input";
        return nullptr;
      }
      if (queue_runner.cancel_op_name().empty()) {
        LOG(ERROR) << "Queue without a cancel op, skipping this input";
        return nullptr;
      }
      new_item->queue_runners.push_back(queue_runner);
    }
  }

  // Add each node referenced in a collection to the list of nodes to keep.
  for (const auto& col : collections) {
    const CollectionDef& collection = col.second;
    for (const string& node : collection.node_list().value()) {
      new_item->keep_ops.push_back(NodeName(node));
    }
  }

  for (auto& node : *new_item->graph.mutable_node()) {
    if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
      Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
                                        new_item.get(), &node);
      if (!s.ok()) return nullptr;
    } else if (IsConstant(node)) {
      auto it = asset_node_to_value.find(node.name());
      if (it != asset_node_to_value.end()) {
        // Look up and compare against end() on the same mutable map so the
        // iterators are guaranteed to be of the same kind.
        auto* node_attrs = node.mutable_attr();
        auto iter = node_attrs->find("value");
        if (iter == node_attrs->end()) {
          LOG(ERROR) << "Value attribute expected in const op for asset files";
          return nullptr;
        }
        if (!iter->second.has_tensor() ||
            iter->second.tensor().string_val_size() != 1) {
          // This aborts the whole item, so it is reported as an error like
          // every other failure path in this function.
          LOG(ERROR) << "Unexpected AttrValue proto: "
                     << iter->second.DebugString();
          return nullptr;
        }
        LOG(INFO) << "Using asset file " << it->second << " for node "
                  << node.name();
        *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
      }
    }

    // Erase the recorded result of any previous shape inference to start again
    // from scratch.
    node.mutable_attr()->erase("_output_shapes");

    // Delete user specified placement if requested.
    if (cfg.ignore_user_placement) {
      node.clear_device();
    }
    // Delete colocation constraints if requested.
    if (cfg.ignore_colocation) {
      auto attr = node.mutable_attr();
      auto it = attr->find("_class");
      if (it != attr->end()) {
        attr->erase(it);
      }
    }
  }

  const auto savers_it = collections.find("savers");
  if (savers_it != collections.end()) {
    for (const auto& raw : savers_it->second.bytes_list().value()) {
      SaverDef saver;
      // Skip bad savers since we don't need saves/restores to be able to run a
      // graph.
      if (!saver.ParseFromString(raw)) {
        continue;
      }
      if (saver.filename_tensor_name().empty()) {
        continue;
      }
      new_item->save_op = saver.save_tensor_name();
      new_item->restore_op = saver.restore_op_name();
      new_item->save_restore_loc_tensor = saver.filename_tensor_name();
      // Only use the first saver since it's not clear what to do if there's
      // more than one.
      break;
    }
  } else {
    // No "savers" collection: fall back to the saver stored directly on the
    // MetaGraphDef.
    const SaverDef& saver = meta_graph.saver_def();
    new_item->save_op = saver.save_tensor_name();
    new_item->restore_op = saver.restore_op_name();
    new_item->save_restore_loc_tensor = saver.filename_tensor_name();
  }

  // Instantiate all the missing attributes with their default values.
  Status attr_status = AddDefaultAttrsToGraphDef(
      &new_item->graph,
      FunctionLibraryDefinition(OpRegistry::Global(),
                                new_item->graph.library()),
      0, true);
  if (!attr_status.ok()) {
    LOG(ERROR) << "Failed to instantiate default attribute values: "
               << attr_status.error_message();
    return nullptr;
  }

  // Optimize the graph (function inlining, l1 optimizations, etc).
  VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
          << new_item->graph.node_size();
  Status optimize_status =
      RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
  if (!optimize_status.ok()) {
    LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
    return nullptr;
  }
  VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
          << new_item->graph.node_size();

  if (cfg.prune_graph) {
    VLOG(1) << "Pruning graph...";
    auto status = PruneGraph(new_item.get());
    if (!status.ok()) {
      LOG(ERROR) << "Pruning failed: " << status.error_message();
      return nullptr;
    }
    VLOG(1) << "Number of nodes in graph after pruning: "
            << new_item->graph.node_size();
  }

  // Validate feed, fetch and init nodes: each of them must still be present in
  // the graph after optimization and pruning.
  std::unordered_set<string> nodes;
  for (const auto& node : new_item->graph.node()) {
    nodes.insert(node.name());
  }
  for (const auto& feed : new_item->feed) {
    if (nodes.find(feed.first) == nodes.end()) {
      LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& fetch : new_item->fetch) {
    if (nodes.find(fetch) == nodes.end()) {
      LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& init : new_item->init_ops) {
    if (nodes.find(init) == nodes.end()) {
      LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
      return nullptr;
    }
  }
  return new_item;
}