in onnxruntime/core/framework/orttraining_partial_executor.cc [133:519]
Status PartialExecutor::Execute(const SessionState& session_state, const std::vector<int>& feed_mlvalue_idxs,
const std::vector<OrtValue>& feeds, const std::vector<int>& fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, CustomAllocator>& fetch_allocators,
const logging::Logger& logger) {
const bool is_profiler_enabled = session_state.Profiler().IsEnabled();
TimePoint tp;
TimePoint sync_time_begin;
TimePoint kernel_begin_time;
size_t input_activation_sizes = 0;
size_t input_parameter_sizes = 0;
size_t total_output_sizes = 0;
if (is_profiler_enabled) {
tp = session_state.Profiler().Start();
}
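// Note: the execution frame is obtained from the partial-execution state (state_) rather than
// created per call, so intermediate values can persist across partial runs of the same plan.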
ExecutionFrame& frame = state_.GetExecutionFrame(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
fetch_allocators, session_state);
LOGS(logger, INFO) << "Begin execution";
const SequentialExecutionPlan& seq_exec_plan = *session_state.GetExecutionPlan();
const auto& exec_plan_vec = seq_exec_plan.execution_plan;
VLOGS(logger, 1) << "Size of execution plan vector: " << exec_plan_vec.size();
// Enable TRACE_EXECUTION compile flag to dump execution plan
#if defined(TRACE_EXECUTION)
std::cout << std::make_pair(&seq_exec_plan, &session_state) << std::endl;
#endif
const auto& graph_viewer = session_state.GetGraphViewer();
#ifdef CONCURRENCY_VISUALIZER
// need a unique name for the series; for a subgraph, the (truncated) parent node name is used
char series_name[MaxSeriesNameLengthInChars] = "MainGraph";
if (graph_viewer.IsSubgraph()) {
auto s = graph_viewer.ParentNode()->Name().substr(0, MaxSeriesNameLengthInChars - 1);
std::copy(s.cbegin(), s.cend(), series_name);
}
diagnostic::marker_series series(series_name);
#endif
#ifdef ENABLE_NVTX_PROFILE
auto& profile_context = profile::Context::GetInstance();
const auto tag = profile_context.GetThreadTagOrDefault(std::this_thread::get_id());
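// These NVTX ranges bracket the forward and backward passes; Begin/End are driven by each
// node's Description() inside the execution loop below.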
profile::NvtxRangeCreator forward_range(
"Batch-" + tag + " Forward",
profile::Color::White);
profile::NvtxRangeCreator backward_range(
"Batch-" + tag + " Backward",
profile::Color::Black);
#endif
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
utils::NodeDumpContext dump_context { session_state.GetGraphExecutionCounter(), 0 };
#endif
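// Main execution loop: only the [ProgramCounterStart, ProgramCounterEnd) slice of the
// execution plan is run here, which is what makes this executor "partial".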
for (size_t program_counter = state_.GetProgramCounterStart();
program_counter < state_.GetProgramCounterEnd();
program_counter += 1) {
const auto& node_exec_plan = exec_plan_vec[program_counter];
auto node_index = node_exec_plan.node_index;
const auto& node = *graph_viewer.GetNode(node_exec_plan.node_index);
#ifdef CONCURRENCY_VISUALIZER
series.write_flag(node.Name().c_str());
#endif
#ifdef ENABLE_NVTX_PROFILE
if (node.Description() != "Backward pass" && !forward_range.IsBeginCalled()) {
// Start timing forward pass when encountering the first forward node.
forward_range.Begin();
} else if (node.Description() == "Backward pass" &&
!backward_range.IsBeginCalled() && forward_range.IsBeginCalled()) {
// Start timing the backward pass when the first backward node is encountered;
// the forward range ends at that point.
forward_range.End();
backward_range.Begin();
}
#endif
auto p_op_kernel = session_state.GetKernel(node_index);
// If a kernel has been added to the session state, it must be non-null.
if (p_op_kernel == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Got nullptr from GetKernel for node: ",
node.Name());
}
if (p_op_kernel->KernelDef().OpName() == "YieldOp") {
// Do not execute YieldOp (it is a no-op anyway).
// Decrement the reference count of tensors that are not needed beyond this point.
// REVIEW(codemzs): The current model assumes that intermediate tensors exported
// as graph outputs are owned by ORT; the risk of the caller freeing the tensor or manipulating its
// memory remains while the tensor is used downstream after the export.
VLOGS(logger, 1) << "Releasing node ML values.";
ORT_RETURN_IF_ERROR(ReleaseNodeMLValues(frame, seq_exec_plan, node_exec_plan, logger));
continue;
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
LARGE_INTEGER kernel_start;
QueryPerformanceCounter(&kernel_start);
#endif
// construct OpKernelContext
// TODO: log kernel inputs?
OpKernelContextInternal op_kernel_context(session_state, frame, *p_op_kernel, logger, false);
// Cache lookup. Currently we only cache single-output nodes,
// to keep the memory overhead in check. Hence we only consult the cache
// if the current node has exactly one output.
bool reuse_cached_value = false;
std::string cached_arg_name;
if (cache_ != nullptr) {
if (p_op_kernel->Node().OutputDefs().size() == 1) {
cached_arg_name = p_op_kernel->Node().OutputDefs()[0]->Name();
if (cache_->count(cached_arg_name)) { // found arg in cache_
VLOGS(logger, 1) << "Found OrtValue in cache for arg: " << cached_arg_name;
reuse_cached_value = true;
}
}
}
// TODO: log kernel outputs?
if (is_profiler_enabled) {
sync_time_begin = session_state.Profiler().Start();
}
// Sync before compute: honor fences on explicit inputs, implicit inputs, and outputs
// so the kernel does not start before prior work on other queues/providers has completed.
int queue_id = p_op_kernel->KernelDef().ExecQueueId();
if (seq_exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
}
}
}
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
dump_context.program_counter = program_counter;
utils::DumpNodeInputs(dump_context, op_kernel_context, p_op_kernel->Node(), session_state);
#endif
const std::string node_name_for_profiling = [&]() -> std::string {
if (!is_profiler_enabled) return {};
// Derive something meaningful for profile traces and logs if the node's name is blank in the execution graph.
return node.Name().empty() ? MakeString(node.OpType(), "_", node_index) : node.Name();
}();
if (is_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_for_profiling + "_fence_before",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});
concurrency::ThreadPool::StartProfiling(session_state.GetThreadPool());
// call compute on the kernel
VLOGS(logger, 1) << "Computing kernel: " << node_name_for_profiling;
kernel_begin_time = session_state.Profiler().Start();
// Calculate total input sizes for this operation.
CalculateTotalInputSizes(&op_kernel_context, p_op_kernel,
input_activation_sizes, input_parameter_sizes, node_name_for_profiling);
}
Status compute_status;
{
#ifdef CONCURRENCY_VISUALIZER
diagnostic::span span(series, "%s.%d", node.OpType().c_str(), node.Index());
#endif
#ifdef ENABLE_NVTX_PROFILE
profile::NvtxRangeCreator node_compute_range(
MakeString(node.OpType(), ".", node.Index(), "(", node.Name(), ")"), profile::Color::Yellow);
node_compute_range.Begin();
#endif
ORT_TRY {
#ifdef ENABLE_TRAINING
if (p_op_kernel->KernelDef().AllocateInputsContiguously()) {
ORT_RETURN_IF_ERROR(utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context));
}
#endif
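// If this node's single output was found in the cache, skip Compute and bind the cached
// OrtValue directly as output 0.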
if (!reuse_cached_value) {
compute_status = p_op_kernel->Compute(&op_kernel_context);
} else {
compute_status = op_kernel_context.SetOutputMLValue(0, cache_->at(cached_arg_name));
}
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
compute_status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what());
});
}
#ifdef ENABLE_NVTX_PROFILE
node_compute_range.End();
#endif
}
if (!compute_status.IsOK()) {
std::ostringstream ss;
ss << "Non-zero status code returned while running " << node.OpType() << " node. Name:'" << node.Name()
<< "' Status Message: " << compute_status.ErrorMessage();
// If the computation failed, we can still record the memory consumption.
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
MemoryInfo::MemoryInfoProfile::CreateEvents("dynamic activations_" + std::to_string(MemoryInfo::GetIteration()),
MemoryInfo::MemoryInfoProfile::GetAndIncreasePid(),
MemoryInfo::MapType::DynamicActivation, "", 0);
#endif
const auto msg_string = ss.str();
LOGS(logger, ERROR) << msg_string;
return Status(compute_status.Category(), compute_status.Code(), msg_string);
}
if (is_profiler_enabled) {
// Calculate total output sizes for this operation.
CalculateTotalOutputSizes(&op_kernel_context, total_output_sizes, node_name_for_profiling);
#if defined(TRACE_EXECUTION)
// Trace execution step.
const Node& node = p_op_kernel->Node();
std::cout << "Executed op kernel node " << node_name_for_profiling
<< " Index=" << node.Index()
<< " OpType=" << node.OpType()
<< " Name=" << node.Name()
<< " Activation_Size=" << input_activation_sizes
<< " Parameter_Size=" << input_parameter_sizes
<< " Output_Size=" << total_output_sizes
<< "\n";
#endif
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_for_profiling + "_kernel_time",
kernel_begin_time,
// Log additional operation args / info.
{
{"op_name", p_op_kernel->KernelDef().OpName()},
{"provider", p_op_kernel->KernelDef().Provider()},
{"graph_index", std::to_string(p_op_kernel->Node().Index())},
{"exec_plan_index", std::to_string(node_index)},
{"activation_size", std::to_string(input_activation_sizes)},
{"parameter_size", std::to_string(input_parameter_sizes)},
{"output_size", std::to_string(total_output_sizes)},
{"thread_scheduling_stats",
concurrency::ThreadPool::StopProfiling(
session_state.GetThreadPool())},
});
sync_time_begin = session_state.Profiler().Start();
}
// Sync after compute: signal fences for explicit inputs, implicit inputs, and outputs
// now that the kernel has finished using them.
if (seq_exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->AfterUsedAsOutput(queue_id);
}
}
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
LARGE_INTEGER kernel_stop;
QueryPerformanceCounter(&kernel_stop);
LARGE_INTEGER elapsed;
elapsed.QuadPart = kernel_stop.QuadPart - kernel_start.QuadPart;
elapsed.QuadPart *= 1000000;
elapsed.QuadPart /= perf_freq.QuadPart;
// Log an event
TraceLoggingWrite(telemetry_provider_handle, // handle to my provider
"OpEnd", // Event Name that should uniquely identify your event.
TraceLoggingValue(p_op_kernel->KernelDef().OpName().c_str(), "op_name"),
TraceLoggingValue(elapsed.QuadPart, "time"));
#endif
if (is_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_for_profiling + "_fence_after",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});
}
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
utils::DumpNodeOutputs(dump_context, op_kernel_context, p_op_kernel->Node(), session_state);
#endif
// free ml-values corresponding to this node
VLOGS(logger, 1) << "Releasing node ML values.";
ORT_RETURN_IF_ERROR(ReleaseNodeMLValues(frame, seq_exec_plan, node_exec_plan, logger));
}
#ifdef ENABLE_NVTX_PROFILE
// Make sure the forward Range object has both Begin and End called.
if (!forward_range.IsBeginCalled()) {
forward_range.Begin();
}
if (!forward_range.IsEndCalled()) {
forward_range.End();
}
// Make sure the backward Range object has both Begin and End called.
if (!backward_range.IsBeginCalled()) {
backward_range.Begin();
}
if (!backward_range.IsEndCalled()) {
backward_range.End();
}
#endif
VLOGS(logger, 1) << "Fetching output.";
// frame.GetOutputs will update 'fetches' with the final output values.
ORT_RETURN_IF_ERROR(frame.GetOutputs(fetch_mlvalue_idxs, fetches));
VLOGS(logger, 1) << "Done with execution.";
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
MemoryInfo::MemoryInfoProfile::CreateEvents("dynamic activations_" + std::to_string(MemoryInfo::GetIteration()),
MemoryInfo::MemoryInfoProfile::GetAndIncreasePid(),
MemoryInfo::MapType::DynamicActivation, "", 0);
MemoryInfo::MemoryInfoProfile::Clear();
#endif
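// When every feed is a tensor, record the memory allocation pattern observed during this run
// and cache it in the session state for use by later runs.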
if (frame.HasMemoryPatternPlanner()) {
bool all_tensors = true;
for (const auto& feed : feeds) {
if (!(feed.IsTensor())) {
all_tensors = false;
break;
}
}
if (all_tensors) {
auto mem_patterns = std::make_unique<MemoryPatternGroup>();
ORT_RETURN_IF_ERROR(frame.GeneratePatterns(mem_patterns.get()));
ORT_RETURN_IF_ERROR(session_state.UpdateMemoryPatternGroupCache(feeds, std::move(mem_patterns)));
}
}
if (is_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", tp);
}
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
for (auto i : frame.GetStaticMemorySizeInfo()) {
LOGS(logger, INFO) << "[Memory] ExecutionFrame statically allocates "
<< i.second << " bytes for " << i.first << std::endl;
}
for (auto i : frame.GetDynamicMemorySizeInfo()) {
LOGS(logger, INFO) << "[Memory] ExecutionFrame dynamically allocates "
<< i.second << " bytes for " << i.first << std::endl;
}
#endif
return Status::OK();
}