Status NGraphEncapsulateOp::GetExecutable()

in ngraph_bridge/kernels/ngraph_encapsulate_op.cc [361:453]


// Returns the compiled nGraph executable for the given TensorFlow inputs,
// translating and compiling the cluster graph on a cache miss.
//
// The cache key ("signature") is built from the shape of every input followed
// by the serialized contents of every input flagged in m_input_is_static, so a
// separate executable is produced for each distinct combination of input
// shapes and static-input values. Executables are stored in m_ng_exec_map;
// m_lru tracks recency with the most recently used signature at the front.
// On a miss, least-recently-used entries are evicted until the cache holds
// fewer than m_function_cache_depth_in_items entries, then the new executable
// is inserted.
//
// m_function_cache_depth_in_items is overwritten from the
// NGRAPH_TF_FUNCTION_CACHE_ITEM_DEPTH environment variable (parsed with atoi)
// whenever that variable is set; it is re-read on every cache miss.
//
// @param tf_input_tensors  Inputs for this invocation. Indexed in parallel with
//                          m_input_is_static — assumes both have the same
//                          length (TODO confirm at call sites).
// @param ng_exec           [out] Set to the cached or freshly compiled
//                          executable when OK is returned.
// @return Status::OK() on success; the error from TensorToStream or
//         TranslateGraph if they fail; errors::Internal wrapping the
//         exception message if backend compilation throws.
Status NGraphEncapsulateOp::GetExecutable(
    const std::vector<Tensor>& tf_input_tensors,
    std::shared_ptr<Executable>& ng_exec) {
  auto backend = BackendManager::GetBackend();

  // Compute the signature. Layout: "<dims of input 0>;<dims of input 1>;.../"
  // followed by "<serialized static input>;" for each static input. The two
  // passes are required so all shapes precede the "/" separator.
  const std::size_t num_inputs = tf_input_tensors.size();
  // Non-static entries stay nullptr; TranslateGraph receives one slot per
  // input.
  std::vector<const Tensor*> static_input_map(num_inputs, nullptr);
  std::vector<TensorShape> input_shapes;
  input_shapes.reserve(num_inputs);
  std::stringstream signature_ss;
  for (const Tensor& input_tensor : tf_input_tensors) {
    input_shapes.push_back(input_tensor.shape());
    for (const auto& x : input_tensor.shape()) {
      signature_ss << x.size << ",";
    }
    signature_ss << ";";
  }
  signature_ss << "/";

  for (std::size_t i = 0; i < num_inputs; ++i) {
    if (m_input_is_static[i]) {
      static_input_map[i] = &tf_input_tensors[i];
      TF_RETURN_IF_ERROR(
          tf_utils::TensorToStream(signature_ss, tf_input_tensors[i]));
      signature_ss << ";";
    }
  }

  const string signature = signature_ss.str();
  NGRAPH_VLOG(5) << "Computed signature: " << signature;
  auto it = m_ng_exec_map.find(signature);
  NGRAPH_VLOG(4) << "NGraphEncapsulateOp::Compute got inputs for cluster "
                 << m_cluster_id;

  if (it != m_ng_exec_map.end()) {
    // Cache hit: reuse the executable and mark the signature as most recently
    // used. The list is non-empty here because every map entry has an LRU
    // entry.
    if (signature != m_lru.front()) {
      m_lru.remove(signature);
      m_lru.push_front(signature);
    }
    ng_exec = it->second;
    return Status::OK();
  }

  // Cache miss: translate the TensorFlow graph to nGraph and compile it.
  // Snapshot memory usage first so the cost of this entry can be logged.
  long vm = 0, rss = 0, vm0 = 0, rss0 = 0;
  utils::MemoryProfile(vm0, rss0);

  NGRAPH_VLOG(1) << "Compilation cache miss: " << m_name;
  std::shared_ptr<ngraph::Function> ng_function;
  TF_RETURN_IF_ERROR(Builder::TranslateGraph(input_shapes, static_input_map,
                                             &m_graph, m_name, ng_function));
  utils::DumpNGGraph(ng_function, m_name);

  const char* cache_depth_specified =
      std::getenv("NGRAPH_TF_FUNCTION_CACHE_ITEM_DEPTH");
  if (cache_depth_specified != nullptr) {
    m_function_cache_depth_in_items = atoi(cache_depth_specified);
  }
  // Evict least-recently-used entries until there is room for the new one.
  // The emptiness check prevents calling back()/pop_back() on an empty list
  // (undefined behavior) when the depth is 0 — e.g. an explicit "0" or an
  // unparsable value, for which atoi returns 0. Looping rather than evicting a
  // single entry also shrinks the cache if the depth was lowered since the
  // previous miss. Evicting before compiling lets the evicted executable be
  // released before the new one is built (unless another owner still holds
  // it).
  while (!m_lru.empty() &&
         m_ng_exec_map.size() >= m_function_cache_depth_in_items) {
    m_ng_exec_map.erase(m_lru.back());
    m_lru.pop_back();
  }

  try {
    ng_exec = backend->Compile(ng_function);
  } catch (const std::exception& ex) {
    return errors::Internal("Failed to compile function " + m_name + ": ",
                            ex.what());
  }

  m_ng_exec_map[signature] = ng_exec;
  m_lru.push_front(signature);

  // Log the memory delta attributable to this cache entry.
  utils::MemoryProfile(vm, rss);
  auto delta_vm_mem = vm - vm0;
  auto delta_res_mem = rss - rss0;
  NGRAPH_VLOG(1) << "NGRAPH_TF_CACHE_PROFILE: OP_ID: " << m_cluster_id
                 << " Cache length: " << m_ng_exec_map.size()
                 << " Cluster: " << m_name << " Delta VM: " << delta_vm_mem
                 << " Delta RSS: " << delta_res_mem
                 << " KB Total RSS: " << rss / (1024 * 1024) << " GB "
                 << " VM: " << vm / (1024 * 1024) << " GB" << endl;
  return Status::OK();
}