Status PartialExecutor::Execute()

in onnxruntime/core/framework/orttraining_partial_executor.cc [133:519]


// Runs the steps of the sequential execution plan in the half-open range
// [state_.GetProgramCounterStart(), state_.GetProgramCounterEnd()).
//
// For every step in that range:
//   * the kernel for the node is fetched from session_state (a null kernel is an error);
//   * a "YieldOp" kernel is not computed: the node's ML values are released and the
//     loop moves on to the next step;
//   * if cache_ is set and the node has exactly one output whose name is present in
//     cache_, the cached OrtValue is set as output 0 instead of calling Compute();
//   * when the plan reports fences for the node, the input / implicit-input / output
//     fences are notified before and after the kernel runs;
//   * the node's ML values are released.
// After the loop the requested fetches are read from the execution frame and, if the
// frame has a memory pattern planner and every feed is a tensor, the generated memory
// patterns are handed to the session state's pattern cache.
//
// Parameters:
//   session_state       provides kernels, the execution plan, the graph viewer, the
//                       profiler and the thread pool.
//   feed_mlvalue_idxs,
//   feeds,
//   fetch_mlvalue_idxs,
//   fetches,
//   fetch_allocators    forwarded to state_.GetExecutionFrame(); 'fetches' is additionally
//                       passed to frame.GetOutputs() once the loop has finished, and
//                       'feeds' to SessionState::UpdateMemoryPatternGroupCache().
//   logger              destination for LOGS / VLOGS output.
//
// Returns Status::OK() on success. Returns a failure status when a kernel is missing,
// when a kernel's Compute (or the cached-output assignment) reports a non-OK status or
// throws std::exception (the status is re-wrapped with the node's op type and name),
// or when any helper invoked through ORT_RETURN_IF_ERROR fails.
Status PartialExecutor::Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
                                const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
                                std::vector<OrtValue>& fetches,
                                const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
                                const logging::Logger& logger) {
  const bool is_profiler_enabled = session_state.Profiler().IsEnabled();
  // Profiler time points: whole call, fence synchronisation, and kernel compute.
  TimePoint tp;
  TimePoint sync_time_begin;
  TimePoint kernel_begin_time;
  // Size counters handed to CalculateTotalInputSizes / CalculateTotalOutputSizes;
  // only touched when the profiler is enabled.
  size_t input_activation_sizes = 0;
  size_t input_parameter_sizes = 0;
  size_t total_output_sizes = 0;

  if (is_profiler_enabled) {
    tp = session_state.Profiler().Start();
  }

  // The frame is owned by state_; it is obtained (not constructed locally) on every call.
  ExecutionFrame& frame = state_.GetExecutionFrame(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
                                                   fetch_allocators, session_state);

  LOGS(logger, INFO) << "Begin execution";
  const SequentialExecutionPlan& seq_exec_plan = *session_state.GetExecutionPlan();
  const auto& exec_plan_vec = seq_exec_plan.execution_plan;
  VLOGS(logger, 1) << "Size of execution plan vector: " << exec_plan_vec.size();

// Enable TRACE_EXECUTION compile flag to dump execution plan
#if defined(TRACE_EXECUTION)
  std::cout << std::make_pair(&seq_exec_plan, &session_state) << std::endl;
#endif

  const auto& graph_viewer = session_state.GetGraphViewer();

#ifdef CONCURRENCY_VISUALIZER
  // Name of the marker series: "MainGraph" for the top-level graph, otherwise the
  // (truncated) name of the subgraph's parent node.
  char series_name[MaxSeriesNameLengthInChars] = "MainGraph";
  if (graph_viewer.IsSubgraph()) {
    auto s = graph_viewer.ParentNode()->Name().substr(0, MaxSeriesNameLengthInChars - 1);
    // Terminate right after the copied characters; otherwise a parent name shorter than
    // "MainGraph" would leave the tail of the default text in the buffer. The substr()
    // above caps the length at MaxSeriesNameLengthInChars - 1, so the write is in bounds.
    *std::copy(s.cbegin(), s.cend(), series_name) = '\0';
  }

  diagnostic::marker_series series(series_name);
#endif

#ifdef ENABLE_NVTX_PROFILE
  auto& profile_context = profile::Context::GetInstance();
  const auto tag = profile_context.GetThreadTagOrDefault(std::this_thread::get_id());
  profile::NvtxRangeCreator forward_range(
      "Batch-" + tag + " Forward",
      profile::Color::White);
  profile::NvtxRangeCreator backward_range(
      "Batch-" + tag + " Backward",
      profile::Color::Black);
#endif

#ifdef DEBUG_NODE_INPUTS_OUTPUTS
  utils::NodeDumpContext dump_context { session_state.GetGraphExecutionCounter(), 0 };
#endif


  for (size_t program_counter = state_.GetProgramCounterStart();
       program_counter < state_.GetProgramCounterEnd();
       program_counter += 1) {
    const auto& node_exec_plan = exec_plan_vec[program_counter];
    auto node_index = node_exec_plan.node_index;
    const auto& node = *graph_viewer.GetNode(node_index);

#ifdef CONCURRENCY_VISUALIZER
    series.write_flag(node.Name().c_str());
#endif

#ifdef ENABLE_NVTX_PROFILE
    if (node.Description() != "Backward pass" && !forward_range.IsBeginCalled()) {
      // Start timing forward pass when encountering the first forward node.
      forward_range.Begin();
    } else if (node.Description() == "Backward pass" &&
               !backward_range.IsBeginCalled() && forward_range.IsBeginCalled()) {
      // Start timing backward pass when encountering the first backward node.
      // In the meanwhile, forward range ends.
      forward_range.End();
      backward_range.Begin();
    }
#endif

    auto p_op_kernel = session_state.GetKernel(node_index);
    // if a kernel has been added in the session state, it better be NON-null.
    if (p_op_kernel == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Got nullptr from GetKernel for node: ",
                             node.Name());
    }

    if (p_op_kernel->KernelDef().OpName() == "YieldOp") {
      // Do not execute YieldOp (it is a no-op anyway).
      // Decrement the reference count of tensors that are not needed beyond this point.
      // REVIEW(codemzs): The current model assumes the intermediate tensors that are exported
      // as graph outputs are owned by ORT, the risk of caller freeing the tensor or manipulating tensor
      // memory lingers while the tensor is used downstream after the export.
      VLOGS(logger, 1) << "Releasing node ML values.";
      ORT_RETURN_IF_ERROR(ReleaseNodeMLValues(frame, seq_exec_plan, node_exec_plan, logger));
      continue;
    }

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
    LARGE_INTEGER kernel_start;
    QueryPerformanceCounter(&kernel_start);
#endif
    // construct OpKernelContext
    // TODO: log kernel inputs?
    OpKernelContextInternal op_kernel_context(session_state, frame, *p_op_kernel, logger, false);

    // Cache lookup. Currently we only cache single-output nodes,
    // to keep memory overhead impact in check. Hence we only look in cache
    // if the current node has one output.
    bool reuse_cached_value = false;
    std::string cached_arg_name;
    if (cache_ != nullptr) {
      if (p_op_kernel->Node().OutputDefs().size() == 1) {
        cached_arg_name = p_op_kernel->Node().OutputDefs()[0]->Name();
        if (cache_.get()->count(cached_arg_name)) {  // found arg in cache_
          VLOGS(logger, 1) << "Found OrtValue in cache for arg: " << cached_arg_name;
          reuse_cached_value = true;
        }
      }
    }

    // TODO: log kernel outputs?
    if (is_profiler_enabled) {
      sync_time_begin = session_state.Profiler().Start();
    }

    // sync before compute
    int queue_id = p_op_kernel->KernelDef().ExecQueueId();
    if (seq_exec_plan.NodeHasFence(node_index)) {
      for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
        Fence_t fence = op_kernel_context.InputFence(input_index);
        if (fence) {
          // Inputs the kernel definition places in CPU memory are synchronised against
          // the CPU execution provider rather than the node's own provider.
          auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
          if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
            execution_provider_type = kCpuExecutionProvider;
          }
          fence->BeforeUsingAsInput(execution_provider_type, queue_id);
        }
      }

      for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
        Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
        if (fence) {
          auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
          // NOTE(review): InputMemoryType() is queried with the implicit-input index here,
          // i.e. the same index space as the explicit inputs above — confirm this is intended.
          if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
            execution_provider_type = kCpuExecutionProvider;
          }
          fence->BeforeUsingAsInput(execution_provider_type, queue_id);
        }
      }

      for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
        Fence_t fence = op_kernel_context.OutputFence(output_index);
        if (fence) {
          fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
        }
      }
    }
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
    dump_context.program_counter = program_counter;
    utils::DumpNodeInputs(dump_context, op_kernel_context, p_op_kernel->Node(), session_state);
#endif

    // Left empty when profiling is off so no string is built on the normal path.
    const std::string node_name_for_profiling = [&]() -> std::string {
      if (!is_profiler_enabled) return {};
      // Derive something meaningful for profile traces and logs if node name field is blank in execution graph
      return node.Name().empty() ? MakeString(node.OpType(), "_", node_index) : node.Name();
    }();

    if (is_profiler_enabled) {
      // Closes the "fence before" interval opened at sync_time_begin above.
      session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
                                                     node_name_for_profiling + "_fence_before",
                                                     sync_time_begin,
                                                     {{"op_name", p_op_kernel->KernelDef().OpName()}});
      concurrency::ThreadPool::StartProfiling(session_state.GetThreadPool());
      // call compute on the kernel
      VLOGS(logger, 1) << "Computing kernel: " << node_name_for_profiling;

      kernel_begin_time = session_state.Profiler().Start();

      // Calculate total input sizes for this operation.
      CalculateTotalInputSizes(&op_kernel_context, p_op_kernel,
                               input_activation_sizes, input_parameter_sizes, node_name_for_profiling);
    }

    Status compute_status;
    {
#ifdef CONCURRENCY_VISUALIZER
      diagnostic::span span(series, "%s.%d", node.OpType().c_str(), node.Index());
#endif
#ifdef ENABLE_NVTX_PROFILE
      profile::NvtxRangeCreator node_compute_range(
          MakeString(node.OpType(), ".", node.Index(), "(", node.Name(), ")"), profile::Color::Yellow);
      node_compute_range.Begin();
#endif
      ORT_TRY {
#ifdef ENABLE_TRAINING
        if (p_op_kernel->KernelDef().AllocateInputsContiguously()) {
          ORT_RETURN_IF_ERROR(utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context));
        }
#endif
        if (!reuse_cached_value) {
          compute_status = p_op_kernel->Compute(&op_kernel_context);
        } else {
          // Cache hit: publish the cached value as output 0 instead of running the kernel.
          compute_status = op_kernel_context.SetOutputMLValue(0, cache_.get()->at(cached_arg_name));
        }
      }
      ORT_CATCH(const std::exception& ex) {
        // Exceptions are converted into a status so they are reported like any other failure below.
        ORT_HANDLE_EXCEPTION([&]() {
          compute_status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what());
        });
      }

#ifdef ENABLE_NVTX_PROFILE
      node_compute_range.End();
#endif
    }

    if (!compute_status.IsOK()) {
      std::ostringstream ss;
      ss << "Non-zero status code returned while running " << node.OpType() << " node. Name:'" << node.Name()
         << "' Status Message: " << compute_status.ErrorMessage();
//If the computation failed, we still can record the memory consumption
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
      MemoryInfo::MemoryInfoProfile::CreateEvents("dynamic activations_" + std::to_string(MemoryInfo::GetIteration()),
                                                  MemoryInfo::MemoryInfoProfile::GetAndIncreasePid(),
                                                  MemoryInfo::MapType::DynamicActivation, "", 0);
#endif
      const auto msg_string = ss.str();
      LOGS(logger, ERROR) << msg_string;
      // Keep the original category and code; only the message is enriched with node details.
      return Status(compute_status.Category(), compute_status.Code(), msg_string);
    }

    if (is_profiler_enabled) {
      // Calculate total output sizes for this operation.
      CalculateTotalOutputSizes(&op_kernel_context, total_output_sizes, node_name_for_profiling);

#if defined(TRACE_EXECUTION)
      // Trace execution step. 'node' is the loop's node for this step.
      std::cout << "Executed op kernel node " << node_name_for_profiling
                << " Index=" << node.Index()
                << " OpType=" << node.OpType()
                << " Name=" << node.Name()
                << " Activation_Size=" << input_activation_sizes
                << " Parameter_Size=" << input_parameter_sizes
                << " Output_Size=" << total_output_sizes
                << "\n";
#endif

      session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
                                                     node_name_for_profiling + "_kernel_time",
                                                     kernel_begin_time,
                                                     // Log additional operation args / info.
                                                     {
                                                         {"op_name", p_op_kernel->KernelDef().OpName()},
                                                         {"provider", p_op_kernel->KernelDef().Provider()},
                                                         {"graph_index", std::to_string(p_op_kernel->Node().Index())},
                                                         {"exec_plan_index", std::to_string(node_index)},
                                                         {"activation_size", std::to_string(input_activation_sizes)},
                                                         {"parameter_size", std::to_string(input_parameter_sizes)},
                                                         {"output_size", std::to_string(total_output_sizes)},
                                                         {"thread_scheduling_stats",
                                                          concurrency::ThreadPool::StopProfiling(
                                                              session_state.GetThreadPool())},
                                                     });
      // Re-arm the timer for the "fence after" interval recorded further down.
      sync_time_begin = session_state.Profiler().Start();
    }

    // sync after compute for outputs
    if (seq_exec_plan.NodeHasFence(node_index)) {
      for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
        Fence_t fence = op_kernel_context.InputFence(input_index);
        if (fence) {
          fence->AfterUsedAsInput(queue_id);
        }
      }

      for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
        Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
        if (fence) {
          fence->AfterUsedAsInput(queue_id);
        }
      }

      for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
        Fence_t fence = op_kernel_context.OutputFence(output_index);
        if (fence) {
          fence->AfterUsedAsOutput(queue_id);
        }
      }
    }
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
    LARGE_INTEGER kernel_stop;
    QueryPerformanceCounter(&kernel_stop);
    LARGE_INTEGER elapsed;
    // Scale the tick delta by 1,000,000 before dividing by the counter frequency.
    elapsed.QuadPart = kernel_stop.QuadPart - kernel_start.QuadPart;
    elapsed.QuadPart *= 1000000;
    elapsed.QuadPart /= perf_freq.QuadPart;
    // Log an event
    TraceLoggingWrite(telemetry_provider_handle,  // handle to my provider
                      "OpEnd",                    // Event Name that should uniquely identify your event.
                      TraceLoggingValue(p_op_kernel->KernelDef().OpName().c_str(), "op_name"),
                      TraceLoggingValue(elapsed.QuadPart, "time"));
#endif
    if (is_profiler_enabled) {
      session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
                                                     node_name_for_profiling + "_fence_after",
                                                     sync_time_begin,
                                                     {{"op_name", p_op_kernel->KernelDef().OpName()}});
    }

#ifdef DEBUG_NODE_INPUTS_OUTPUTS
    utils::DumpNodeOutputs(dump_context, op_kernel_context, p_op_kernel->Node(), session_state);
#endif

    // free ml-values corresponding to this node
    VLOGS(logger, 1) << "Releasing node ML values.";
    ORT_RETURN_IF_ERROR(ReleaseNodeMLValues(frame, seq_exec_plan, node_exec_plan, logger));
  }

#ifdef ENABLE_NVTX_PROFILE
  // Make sure forward Range object call Begin and End.
  if (!forward_range.IsBeginCalled()) {
    forward_range.Begin();
  }
  if (!forward_range.IsEndCalled()) {
    forward_range.End();
  }
  // Make sure backward Range object call Begin and End.
  if (!backward_range.IsBeginCalled()) {
    backward_range.Begin();
  }
  if (!backward_range.IsEndCalled()) {
    backward_range.End();
  }
#endif

  VLOGS(logger, 1) << "Fetching output.";
  // ExecutionFrame::GetOutputs populates 'fetches' for the requested fetch_mlvalue_idxs.
  ORT_RETURN_IF_ERROR(frame.GetOutputs(fetch_mlvalue_idxs, fetches));
  VLOGS(logger, 1) << "Done with execution.";

#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
  MemoryInfo::MemoryInfoProfile::CreateEvents("dynamic activations_" + std::to_string(MemoryInfo::GetIteration()),
                                              MemoryInfo::MemoryInfoProfile::GetAndIncreasePid(),
                                              MemoryInfo::MapType::DynamicActivation, "", 0);
  MemoryInfo::MemoryInfoProfile::Clear();
#endif

  if (frame.HasMemoryPatternPlanner()) {
    // Memory patterns are only generated and cached when every feed is a tensor.
    bool all_tensors = true;
    for (const auto& feed : feeds) {
      if (!(feed.IsTensor())) {
        all_tensors = false;
        break;
      }
    }

    if (all_tensors) {
      auto mem_patterns = std::make_unique<MemoryPatternGroup>();
      ORT_RETURN_IF_ERROR(frame.GeneratePatterns(mem_patterns.get()));
      ORT_RETURN_IF_ERROR(session_state.UpdateMemoryPatternGroupCache(feeds, std::move(mem_patterns)));
    }
  }

  if (is_profiler_enabled) {
    // NOTE(review): the event is recorded as "SequentialExecutor::Execute" although this is
    // PartialExecutor; the name is left as-is because profile consumers may key on it — confirm.
    session_state.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", tp);
  }

#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
  // Iterate by const reference so each (name, size) entry is not copied just to be logged.
  for (const auto& entry : frame.GetStaticMemorySizeInfo()) {
    LOGS(logger, INFO) << "[Memory] ExecutionFrame statically allocates "
                       << entry.second << " bytes for " << entry.first << std::endl;
  }

  for (const auto& entry : frame.GetDynamicMemorySizeInfo()) {
    LOGS(logger, INFO) << "[Memory] ExecutionFrame dynamically allocates "
                       << entry.second << " bytes for " << entry.first << std::endl;
  }
#endif

  return Status::OK();
}