in include/tfrt/core_runtime/dispatch_utils.h [375:467]
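// Runs the dispatch phase of an op whose result-metadata handling has been
// decided: result_mds holds metadata computed by a metadata function, while a
// non-null result_md_avs means each result's metadata must instead be
// resolved by the dispatch itself. If every argument tensor and the in-chain
// are available, the op is dispatched synchronously; otherwise dispatch is
// deferred until they resolve.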
void ExecuteWithResultMetadataResolved(
const ExecutionContext& exec_ctx, MutableArrayRef<TensorHandle> arguments,
const OpAttrsRef& attrs, size_t num_results,
const llvm::SmallVector<TensorMetadata, 4>& result_mds,
llvm::SmallVectorImpl<AsyncValueRef<TensorMetadata>>* result_md_avs,
llvm::SmallVectorImpl<AsyncValueRef<Tensor>>* result_tensor_avs,
AsyncValueRef<Chain>* chain, bool update_chain,
typename OpHandlerTraits::OpEntryTy op_entry,
typename OpHandlerTraits::OpHandlerInfoTy op_handler_info) {
  // If a metadata function did not supply the result metadata, we need to
  // resolve each result TensorHandle's metadata from the op results instead.
if (result_md_avs) {
result_md_avs->reserve(num_results);
for (size_t i = 0; i != num_results; ++i) {
result_md_avs->push_back(
MakeUnconstructedAsyncValueRef<TensorMetadata>(exec_ctx.host()));
}
}
  // Keep track of all the unresolved values to see if we can dispatch the
  // kernel immediately. If not, we will "and then" on these unresolved values.
llvm::SmallVector<AsyncValue*, 4> async_args;
async_args.reserve(arguments.size() + 1);
llvm::SmallVector<RCReference<AsyncValue>, 4> arg_tensors;
arg_tensors.reserve(arguments.size());
assert((!update_chain || (chain && *chain)) &&
"the op requires an in chain.");
if (chain && *chain) {
if (!chain->IsAvailable()) async_args.push_back(chain->GetAsyncValue());
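    // When the op updates the chain, hand the caller a fresh unconstructed
    // Chain that the dispatch fulfills once the op's side effects complete.
    // An unavailable original chain was already recorded in async_args above.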
if (update_chain) {
// TODO(fishx): Avoid this heap allocation.
*chain = MakeUnconstructedAsyncValueRef<Chain>(exec_ctx.host());
}
}
for (auto& argument : arguments) {
AsyncValue* async_tensor = argument.GetAsyncTensor();
    // Keep track of unavailable arguments so we can "and then" them. Error
    // values are not concrete either, so errors also take the slow path.
if (!async_tensor->IsConcrete()) {
async_args.push_back(async_tensor);
}
arg_tensors.push_back(argument.ReleaseTensorRef());
}
if (async_args.empty()) {
    // All input tensors and the input chain are available, so we can
    // dispatch the kernel synchronously right away.
llvm::SmallVector<RCReference<AsyncValue>, 4> result_tensors;
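    // empty_md_avs is a placeholder for the case where the result metadata
    // was already resolved upstream (result_md_avs == nullptr).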
llvm::SmallVector<AsyncValueRef<TensorMetadata>, 0> empty_md_avs;
internal::AsyncOpDispatcher<OpHandlerTraits>::RunDispatchFunctionSync(
op_entry, op_handler_info, arg_tensors, attrs, num_results, result_mds,
result_md_avs ? *result_md_avs : empty_md_avs, &result_tensors,
update_chain ? chain : nullptr, exec_ctx);
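    // The synchronous dispatch has now produced the result tensors and, when
    // requested, fulfilled the metadata async values and the out-chain.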
result_tensor_avs->reserve(num_results);
// Fulfill the result async values with the results of the op.
for (size_t i = 0; i != num_results; ++i) {
result_tensor_avs->push_back(
AsyncValueRef<Tensor>(std::move(result_tensors[i])));
}
return;
}
  // At least one input (argument tensor or in-chain) is still unavailable, so
  // we need to run the kernel once all of them resolve.
internal::AsyncOpDispatcher<OpHandlerTraits> op_dispatcher(
exec_ctx, attrs.freeze(), std::move(arg_tensors),
update_chain ? chain->CopyRef() : AsyncValueRef<Chain>(), result_mds,
std::move(op_entry), std::move(op_handler_info));
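  // The dispatcher takes ownership of the frozen attributes, the argument
  // tensor refs, and the op entry, keeping them alive until dispatch runs.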
  // The results have to be returned immediately, but we don't know what
  // concrete Tensor type they will eventually be fulfilled with. Create
  // IndirectAsyncValues to handle this.
result_tensor_avs->reserve(num_results);
op_dispatcher.result_ind_avs_ref().reserve(num_results);
for (size_t i = 0; i != num_results; ++i) {
auto tensor = MakeIndirectAsyncValue(exec_ctx.host());
op_dispatcher.result_ind_avs_ref().push_back(tensor);
result_tensor_avs->push_back(AsyncValueRef<Tensor>(std::move(tensor)));
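    // Also hand the dispatcher a reference to each metadata async value so it
    // can fulfill them when the dispatch produces the results.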
if (result_md_avs) {
op_dispatcher.result_missing_md_avs_ref().push_back(
(*result_md_avs)[i].CopyRef());
}
}
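  // Defer dispatch: RunWhenReady invokes the lambda once every value in
  // async_args (unavailable arguments and possibly the in-chain) becomes
  // available, with the moved dispatcher carrying all required state.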
RunWhenReady(async_args,
[op_dispatcher = std::move(op_dispatcher)]() mutable {
op_dispatcher.RunDispatchFunction();
});
}