in lib/bef_executor_driver/bef_executor_driver.cc [72:254]
// Runs the functions contained in a BEF file on a freshly created TFRT runtime.
//
// Steps, in order:
//   1. Opens `run_config.input_filename` and sets up an MLIR diagnostic
//      verifier handler on a local MLIRContext. Runtime errors decoded from
//      the BEF file are emitted as MLIR errors on that context.
//   2. Creates the host allocator selected by `run_config.host_allocator_type`
//      and the work queue selected by `run_config.work_queue_type`.
//   3. Creates a CoreRuntime. Then, for each library in
//      `run_config.shared_libs`, loads it and calls its `RegisterKernels`
//      symbol if it has one.
//   4. Opens the BEF file. Runs `run_config.test_init_function` (if the file
//      has it), then each function named in `run_config.functions`. When that
//      list is empty, runs every function in the file. The init function is
//      never run twice.
//
// `create_execution_context` must be non-empty. It is forwarded to
// RunBefFunction for every function that is run.
//
// Returns 1 if any setup step fails: the input cannot be opened, the allocator
// type is unsupported, the work queue or core runtime cannot be created, a
// library fails to load, or a requested function is missing. Otherwise returns
// `mlir::failed(source_mgr_handler.verify())`, i.e. 0 when diagnostic
// verification succeeds and 1 when it fails.
int RunBefExecutor(
    const RunBefConfig& run_config,
    const std::function<llvm::Expected<ExecutionContext>(
        HostContext*, ResourceContext*)>& create_execution_context) {
  assert(create_execution_context);
  TFRT_TRACE_SCOPE(Default, "Bef Executor");
  metrics::AddTFRTVersionMetric();

  // Set up the input file.
  std::string error_message;
  auto file = mlir::openInputFile(run_config.input_filename, &error_message);
  if (!file) {
    llvm::errs() << error_message << "\n";
    return 1;
  }

  // Diagnostics emitted on `context` go to `source_mgr_handler`. Its verify()
  // result becomes this function's exit code at the end.
  mlir::MLIRContext context;
  llvm::SourceMgr source_mgr;
  mlir::SourceMgrDiagnosticVerifierHandler source_mgr_handler(source_mgr,
                                                              &context);

  // Converts a location decoded from the BEF file into an mlir::Location:
  // unknown if absent, file:line:col if available, otherwise an opaque name.
  auto get_loc = [&](Optional<DecodedLocation> loc) -> mlir::Location {
    if (!loc) return mlir::UnknownLoc::get(&context);
    if (loc->is<FileLineColLocation>()) {
      auto file_loc = loc->get<FileLineColLocation>();
      return mlir::FileLineColLoc::get(&context, file_loc.filename,
                                       file_loc.line, file_loc.column);
    }
    return mlir::NameLoc::get(
        mlir::StringAttr::get(&context, loc->get<OpaqueLocation>().loc));
  };
  // Re-emits each decoded runtime diagnostic as an MLIR error on `context`.
  auto decoded_diagnostic_handler = [&](const DecodedDiagnostic& diag) {
    emitError(get_loc(diag.location)) << "runtime error: " << diag.message;
  };

  assert(GetNumReferenceCountedObjects() == 0 &&
         "We have reference-counted objects before we started to do anything");

  std::unique_ptr<HostAllocator> host_allocator;
  switch (run_config.host_allocator_type) {
    case HostAllocatorType::kMalloc:
      host_allocator = CreateMallocAllocator();
      tfrt::outs() << "Choosing malloc.\n";
      break;
    case HostAllocatorType::kTestFixedSizeMalloc:
      host_allocator = tfrt::CreateFixedSizeAllocator();
      tfrt::outs() << "Choosing fixed size malloc.\n";
      break;
    case HostAllocatorType::kProfiledMalloc:
      host_allocator = CreateMallocAllocator();
      host_allocator = CreateProfiledAllocator(std::move(host_allocator));
      tfrt::outs() << "Choosing profiled allocator based on malloc.\n";
      break;
    case HostAllocatorType::kLeakCheckMalloc:
      host_allocator = CreateMallocAllocator();
      host_allocator = CreateLeakCheckAllocator(std::move(host_allocator));
      tfrt::outs() << "Choosing memory leak check allocator.\n";
      break;
  }
  // An out-of-range allocator type matches no case above and would otherwise
  // hand a null allocator to CoreRuntime::Create. This check is used instead
  // of a `default:` label so the compiler's -Wswitch coverage check still
  // applies.
  if (!host_allocator) {
    llvm::errs() << run_config.program_name
                 << ": unsupported host allocator type\n";
    return 1;
  }
  tfrt::outs().flush();

  // BEF data must be aligned to GetRequiredBefAlignment().
  // mlir::openInputFile() should return a 4KB-aligned buffer when the file is
  // memory-mapped. A misaligned buffer (e.g. one read from stdin through a
  // pipe) is copied into an aligned BefBuffer instead.
  //
  // `file` must stay alive until `bef` is destroyed: in the aligned case
  // `buffer_arr` points directly into the memory owned by `file`.
  auto buffer = file->getBuffer();
  llvm::ArrayRef<uint8_t> buffer_arr;
  BefBuffer aligned_bef_buffer;
  if (reinterpret_cast<uintptr_t>(buffer.data()) % GetRequiredBefAlignment() !=
      0) {
    aligned_bef_buffer.resize(buffer.size());
    std::memcpy(aligned_bef_buffer.data(), buffer.data(), buffer.size());
    buffer_arr = llvm::ArrayRef<uint8_t>(
        reinterpret_cast<const uint8_t*>(aligned_bef_buffer.data()),
        aligned_bef_buffer.size());
  } else {
    buffer_arr = llvm::ArrayRef<uint8_t>(
        reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size());
  }

  std::unique_ptr<ConcurrentWorkQueue> work_queue =
      CreateWorkQueue(run_config.work_queue_type);
  if (work_queue == nullptr) {
    llvm::errs() << run_config.program_name
                 << ": couldn't create work queue type "
                 << run_config.work_queue_type << "\n";
    return 1;
  }
  tfrt::outs() << "Choosing " << work_queue->name() << " work queue.\n";
  tfrt::outs().flush();

  assert(AsyncValue::GetNumAsyncValueInstances() == 0 &&
         "We have async values allocated before we started to do anything");
  // Declared before `core_rt` and `bef`, so it runs after both are destroyed
  // on every return path below.
  auto async_value_guard = llvm::make_scope_exit([]() {
    assert(AsyncValue::GetNumAsyncValueInstances() == 0 &&
           "All async values should be cleaned up at the end");
    assert(GetNumReferenceCountedObjects() == 0 &&
           "We have live reference-counted objects before exit.");
  });

  auto core_rt =
      CoreRuntime::Create(decoded_diagnostic_handler, std::move(host_allocator),
                          std::move(work_queue));
  if (!core_rt) {
    // llvm::toString() takes ownership of the Error and consumes it, so it is
    // marked as handled before it is destroyed.
    llvm::errs() << run_config.program_name
                 << ": couldn't create core runtime: "
                 << llvm::toString(core_rt.takeError()) << "\n";
    return 1;
  }
  auto* host = core_rt.get()->GetHostContext();

  // If there are any libraries specified, load them and see if they have a
  // kernel registration function.
  for (const auto& lib_name : run_config.shared_libs) {
    std::string err;
    auto dyn_lib =
        llvm::sys::DynamicLibrary::getPermanentLibrary(lib_name.c_str(), &err);
    if (!dyn_lib.isValid()) {
      llvm::errs() << run_config.program_name << ": couldn't load library "
                   << lib_name << ": " << err << "\n";
      return 1;
    }
    // The library may export a `RegisterKernels(KernelRegistry*)` entry point.
    // Libraries without one are loaded but register nothing.
    if (auto kernel_reg = dyn_lib.SearchForAddressOfSymbol("RegisterKernels")) {
      reinterpret_cast<void (*)(KernelRegistry*)>(kernel_reg)(
          host->GetMutableRegistry());
    }
  }

  auto bef(BEFFile::Open(buffer_arr, host->GetKernelRegistry(),
                         decoded_diagnostic_handler, host->allocator()));
  if (!bef) {
    // Open() was given `decoded_diagnostic_handler`. Presumably the failure
    // reason was reported through it, so diagnostic verification decides the
    // exit code.
    return mlir::failed(source_mgr_handler.verify());
  }

  llvm::SmallVector<const Function*, 8> function_list;
  if (run_config.functions.empty()) {
    // No functions specified in the command line. Try to run all functions in
    // the input BEF file.
    bef->GetFunctionList(&function_list);
  } else {
    function_list.reserve(run_config.functions.size());
    for (const auto& fn_name : run_config.functions) {
      auto* fn = bef->GetFunction(fn_name);
      if (!fn) {
        llvm::errs() << run_config.program_name << ": couldn't find function "
                     << fn_name << "\n";
        return 1;
      }
      function_list.push_back(fn);
    }
  }

  // Run the init function first, if it exists.
  auto test_init_function = bef->GetFunction(run_config.test_init_function);
  if (test_init_function) {
    RunBefFunction(host, *test_init_function, create_execution_context,
                   run_config.print_error_code);
  }

  // Loop over each of the functions, running each as a standalone testcase.
  // The init function is skipped here because it already ran above.
  for (auto* fn : function_list) {
    if (fn != test_init_function) {
      RunBefFunction(host, *fn, create_execution_context,
                     run_config.print_error_code);
    }
  }

  bef.reset();

  // Verify the diagnostic handler to make sure that each of the diagnostics
  // matched.
  return mlir::failed(source_mgr_handler.verify());
}