std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef()

in tensorflow/tensorflow/core/grappler/grappler_item_builder.cc [284:661]


std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
    const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
  if (id.empty()) {
    LOG(ERROR) << "id must be non-empty.";
    return nullptr;
  }
  std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
  new_item->id = id;
  new_item->graph = meta_graph.graph_def();

  // Fill in feed nodes from config, if any provided.
  for (const auto& feed_node : cfg.feed_nodes) {
    const string feed_name = NodeName(feed_node);
    new_item->feed.emplace_back(feed_name, Tensor());
  }
  for (const auto& fetch_node : cfg.fetch_nodes) {
    new_item->fetch.emplace_back(NodeName(fetch_node));
  }

  // Attempt to detect the fetch node(s) if they were not set explicitly.
  if (new_item->fetch.empty() &&
      meta_graph.collection_def().count("train_op") > 0) {
    const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
    if (nodes.has_node_list()) {
      for (const auto& node : nodes.node_list().value()) {
        new_item->fetch.push_back(NodeName(node));
      }
    }
  }

  // Detect feed and fetch nodes from signature defs. Signatures may share same
  // inputs or outputs.
  std::unordered_set<string> signature_feed_nodes;
  std::unordered_set<string> signature_fetch_nodes;
  for (const auto& name_and_signature : meta_graph.signature_def()) {
    for (const auto& name_and_input : name_and_signature.second.inputs()) {
      const TensorInfo& input = name_and_input.second;
      if (input.has_coo_sparse()) {
        // Define the shapes following the comment of CooSparse.
        // TODO(yuefengz): we probably want to use different dim values for the
        // three tensors of a SparseTensor.
        int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
        TensorShape shape_1d({dim});
        TensorShape shape_2d({dim, dim});

        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().values_tensor_name()))) {
          Tensor value_tensor(input.dtype(), shape_1d);
          InitializeTensor(input.dtype(), &value_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
        }
        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().indices_tensor_name()))) {
          Tensor indices_tensor(DT_INT64, shape_2d);
          InitializeTensor(input.dtype(), &indices_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().indices_tensor_name()),
              indices_tensor);
        }
        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
          Tensor dense_shape_tensor(DT_INT64, shape_1d);
          InitializeTensor(input.dtype(), &dense_shape_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().dense_shape_tensor_name()),
              dense_shape_tensor);
        }
      } else {
        if (gtl::InsertIfNotPresent(&signature_feed_nodes,
                                    NodeName(input.name()))) {
          TensorShape shape;
          TensorShapeProto shape_proto;
          Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
                                            &shape_proto, &shape);
          if (!s.ok()) {
            LOG(ERROR) << "Invalid shape for signature input " << input.name()
                       << ": " << s << ", skipping this input";
            return nullptr;
          }

          Tensor fake_input(input.dtype(), shape);
          InitializeTensor(input.dtype(), &fake_input);
          new_item->feed.emplace_back(NodeName(input.name()), fake_input);
        }
      }
    }
    for (const auto& name_and_output : name_and_signature.second.outputs()) {
      const TensorInfo& output = name_and_output.second;
      if (output.has_coo_sparse()) {
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().values_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().values_tensor_name()));
        }
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().indices_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().indices_tensor_name()));
        }
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().dense_shape_tensor_name()));
        }
      } else {
        if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
                                    NodeName(output.name()))) {
          new_item->fetch.push_back(NodeName(output.name()));
        }
      }
    }
  }

  for (const auto& feed : new_item->feed) {
    if (feed.first.empty()) {
      LOG(ERROR) << "Invalid feed node name skipping this input";
      return nullptr;
    } else {
      VLOG(1) << "Will use feed node " << feed.first;
    }
  }

  for (const auto& fetch : new_item->fetch) {
    if (fetch.empty()) {
      LOG(ERROR) << "Invalid fetch node name skipping this input";
      return nullptr;
    } else {
      VLOG(1) << "Will use fetch node " << fetch;
    }
  }

  if (new_item->fetch.empty()) {
    LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
    return nullptr;
  }

  // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
  // The reason why they are difficult to handle is because they may not intend
  // to initialize all variables that are required to run fetch nodes. We may
  // have to run restore op first.

  // Try to find initializers from variables and tables as init ops.
  for (const string& var_collection :
       {"variables", "local_variables", "model_variables",
        "trainable_variables"}) {
    if (meta_graph.collection_def().count(var_collection) == 0) {
      continue;
    }
    const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
    for (const auto& raw_var : vars.bytes_list().value()) {
      VariableDef var;
      var.ParseFromString(raw_var);
      if (!var.initializer_name().empty()) {
        new_item->init_ops.push_back(NodeName(var.initializer_name()));
      }
    }
  }

  if (meta_graph.collection_def().count("table_initializer") > 0) {
    const CollectionDef& inits =
        meta_graph.collection_def().at("table_initializer");
    if (inits.has_node_list()) {
      for (const auto& node : inits.node_list().value()) {
        new_item->init_ops.push_back(NodeName(node));
        // Tables are initialized from files, which can take a long time. Add
        // 30 minutes to the initialization time for each table to avoid
        // timing out.
        // TODO(bsteiner): adjust the timeout based on the file size.
        new_item->expected_init_time += 30 * 60;
      }
    }
  }

  // We keep the mapping from asset node to asset files. This should have been
  // used as feed but since asset node is usually a constant node, we will fill
  // the values of these constant nodes with their actual asset file paths.
  std::unordered_map<string, string> asset_node_to_value;

  // Assets file may have changed their directory, we assemble their new paths
  // if assets_directory_override is set. We also make sure we still can
  // access these asset files.
  if (!cfg.assets_directory_override.empty()) {
    if (meta_graph.collection_def().count("saved_model_assets") > 0) {
      const CollectionDef& collection =
          meta_graph.collection_def().at("saved_model_assets");
      const auto& any_assets = collection.any_list().value();
      for (const auto& any_asset : any_assets) {
        AssetFileDef asset_file_def;
        if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
                 .ok()) {
          LOG(ERROR) << "Failed to parse AssetFile.";
          continue;
        }
        string asset_filepath = io::JoinPath(cfg.assets_directory_override,
                                             asset_file_def.filename());
        if (!FilesExist({asset_filepath}, nullptr)) {
          LOG(ERROR) << "Can't access one or more of the asset files "
                     << asset_filepath << ", skipping this input";
          return nullptr;
        }
        asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
            asset_filepath;
      }
    }
  } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
    const CollectionDef& file_paths =
        meta_graph.collection_def().at("asset_filepaths");
    std::vector<string> paths;
    for (const auto& raw_path : file_paths.bytes_list().value()) {
      paths.push_back(raw_path);
    }
    if (!FilesExist(paths, nullptr)) {
      LOG(ERROR) << "Can't access one or more of the asset files, skipping "
                    "this input";
      return nullptr;
    }
  }

  if (meta_graph.collection_def().count("queue_runners") > 0) {
    const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
    for (const auto& raw : vars.bytes_list().value()) {
      QueueRunnerDef queue_runner;
      if (!queue_runner.ParseFromString(raw)) {
        LOG(ERROR) << "Could not parse queue_runners, skipping this input";
        return nullptr;
      }
      if (queue_runner.cancel_op_name().empty()) {
        LOG(ERROR) << "Queue without a cancel op, skipping this input";
        return nullptr;
      }
      new_item->queue_runners.push_back(queue_runner);
    }
  }

  // Add each node referenced in a collection to the list of nodes to keep.
  for (const auto& col : meta_graph.collection_def()) {
    const CollectionDef& collection = col.second;
    for (const string& node : collection.node_list().value()) {
      new_item->keep_ops.push_back(NodeName(node));
    }
  }

  for (auto& node : *new_item->graph.mutable_node()) {
    if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
      Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
                                        new_item.get(), &node);
      if (!s.ok()) return nullptr;
    } else if (IsConstant(node)) {
      auto it = asset_node_to_value.find(node.name());
      if (it != asset_node_to_value.end()) {
        auto iter = node.mutable_attr()->find("value");
        if (iter == node.attr().end()) {
          LOG(ERROR) << "Value attribute expected in const op for asset files";
          return nullptr;
        }
        if (!iter->second.has_tensor() ||
            iter->second.tensor().string_val_size() != 1) {
          LOG(INFO) << "Unexpected AttrValue proto: "
                    << iter->second.DebugString();
          return nullptr;
        }
        LOG(INFO) << "Using asset file " << it->second << " for node "
                  << node.name();
        *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
      }
    }

    // Erase the recorded result of any previous shape inference to start again
    // from scratch.
    node.mutable_attr()->erase("_output_shapes");

    // Delete user specified placement if requested.
    if (cfg.ignore_user_placement) {
      node.clear_device();
    }
    // Delete colocation constraints if requested.
    if (cfg.ignore_colocation) {
      auto attr = node.mutable_attr();
      auto it = attr->find("_class");
      if (it != attr->end()) {
        attr->erase(it);
      }
    }
  }

  if (meta_graph.collection_def().count("savers") > 0) {
    const CollectionDef& savers = meta_graph.collection_def().at("savers");
    for (const auto& raw : savers.bytes_list().value()) {
      SaverDef saver;
      // Skip bad savers since we don't need saves/restores to be able to run a
      // graph.
      if (!saver.ParseFromString(raw)) {
        continue;
      }
      if (saver.filename_tensor_name().empty()) {
        continue;
      }
      new_item->save_op = saver.save_tensor_name();
      new_item->restore_op = saver.restore_op_name();
      new_item->save_restore_loc_tensor = saver.filename_tensor_name();
      // Only use the first saver since it's not clear what to do if there's
      // more than one.
      break;
    }
  } else {
    const SaverDef& saver = meta_graph.saver_def();
    new_item->save_op = saver.save_tensor_name();
    new_item->restore_op = saver.restore_op_name();
    new_item->save_restore_loc_tensor = saver.filename_tensor_name();
  }

  // Instantiate all the missing attributes with their default values.
  Status attr_status = AddDefaultAttrsToGraphDef(
      &new_item->graph,
      FunctionLibraryDefinition(OpRegistry::Global(),
                                new_item->graph.library()),
      0, true);
  if (!attr_status.ok()) {
    LOG(ERROR) << "Failed to instantiate default attribute values: "
               << attr_status.error_message();
    return nullptr;
  }

  // Optimize the graph (function inlining, l1 optimizations, etc).
  VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
          << new_item->graph.node_size();
  Status optimize_status =
      RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
  if (!optimize_status.ok()) {
    LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
    return nullptr;
  }
  VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
          << new_item->graph.node_size();

  if (cfg.prune_graph) {
    VLOG(1) << "Pruning graph...";
    auto status = PruneGraph(new_item.get());
    if (!status.ok()) {
      LOG(ERROR) << "Pruning failed: " << status.error_message();
      return nullptr;
    }
    VLOG(1) << "Number of nodes in graph after pruning: "
            << new_item->graph.node_size();
  }

  // Validate feed, fetch and init nodes
  std::unordered_set<string> nodes;
  for (const auto& node : new_item->graph.node()) {
    nodes.insert(node.name());
  }
  for (const auto& feed : new_item->feed) {
    if (nodes.find(feed.first) == nodes.end()) {
      LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& fetch : new_item->fetch) {
    if (nodes.find(fetch) == nodes.end()) {
      LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& init : new_item->init_ops) {
    if (nodes.find(init) == nodes.end()) {
      LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
      return nullptr;
    }
  }
  return new_item;
}