void ExecuteWithResultMetadataResolved()

in include/tfrt/core_runtime/dispatch_utils.h [375:467]


void ExecuteWithResultMetadataResolved(
    const ExecutionContext& exec_ctx, MutableArrayRef<TensorHandle> arguments,
    const OpAttrsRef& attrs, size_t num_results,
    const llvm::SmallVector<TensorMetadata, 4>& result_mds,
    llvm::SmallVectorImpl<AsyncValueRef<TensorMetadata>>* result_md_avs,
    llvm::SmallVectorImpl<AsyncValueRef<Tensor>>* result_tensor_avs,
    AsyncValueRef<Chain>* chain, bool update_chain,
    typename OpHandlerTraits::OpEntryTy op_entry,
    typename OpHandlerTraits::OpHandlerInfoTy op_handler_info) {
  // If we have no input metadatas (from a metadata function) then we need to
  // resolve the TensorHandle metadata's from the op results.
  if (result_md_avs) {
    result_md_avs->reserve(num_results);
    for (size_t i = 0; i != num_results; ++i) {
      result_md_avs->push_back(
          MakeUnconstructedAsyncValueRef<TensorMetadata>(exec_ctx.host()));
    }
  }

  // Keep track of all the non-resolved values to see if we can dispatch the
  // kernel immediately. If not we will "and then" on these non-resolved values.
  llvm::SmallVector<AsyncValue*, 4> async_args;
  async_args.reserve(arguments.size() + 1);
  llvm::SmallVector<RCReference<AsyncValue>, 4> arg_tensors;
  arg_tensors.reserve(arguments.size());

  assert((!update_chain || (chain && *chain)) &&
         "the op requires an in chain.");
  if (chain && *chain) {
    if (!chain->IsAvailable()) async_args.push_back(chain->GetAsyncValue());
    if (update_chain) {
      // TODO(fishx): Avoid this heap allocation.
      *chain = MakeUnconstructedAsyncValueRef<Chain>(exec_ctx.host());
    }
  }

  for (auto& argument : arguments) {
    AsyncValue* async_tensor = argument.GetAsyncTensor();

    // Keep track of unavailable arguments so we can "and then" them.  We handle
    // errors through the slow path as well.
    if (!async_tensor->IsConcrete()) {
      async_args.push_back(async_tensor);
    }

    arg_tensors.push_back(argument.ReleaseTensorRef());
  }

  if (async_args.empty()) {
    // All input tensor and input chain are available. We can immediately
    // dispatch the kernel synchronously.
    llvm::SmallVector<RCReference<AsyncValue>, 4> result_tensors;
    llvm::SmallVector<AsyncValueRef<TensorMetadata>, 0> empty_md_avs;
    internal::AsyncOpDispatcher<OpHandlerTraits>::RunDispatchFunctionSync(
        op_entry, op_handler_info, arg_tensors, attrs, num_results, result_mds,
        result_md_avs ? *result_md_avs : empty_md_avs, &result_tensors,
        update_chain ? chain : nullptr, exec_ctx);
    result_tensor_avs->reserve(num_results);
    // Fulfill the result async values with the results of the op.
    for (size_t i = 0; i != num_results; ++i) {
      result_tensor_avs->push_back(
          AsyncValueRef<Tensor>(std::move(result_tensors[i])));
    }
    return;
  }

  // We have at least one async tensor input, so we need to run the
  // kernel when it resolves.
  internal::AsyncOpDispatcher<OpHandlerTraits> op_dispatcher(
      exec_ctx, attrs.freeze(), std::move(arg_tensors),
      update_chain ? chain->CopyRef() : AsyncValueRef<Chain>(), result_mds,
      std::move(op_entry), std::move(op_handler_info));

  // The results have to be immediately available, but we don't know what
  // concrete Tensor type they will be fulfilled with.  Create
  // IndirectAsyncValue's to handle this.
  result_tensor_avs->reserve(num_results);
  op_dispatcher.result_ind_avs_ref().reserve(num_results);
  for (size_t i = 0; i != num_results; ++i) {
    auto tensor = MakeIndirectAsyncValue(exec_ctx.host());
    op_dispatcher.result_ind_avs_ref().push_back(tensor);
    result_tensor_avs->push_back(AsyncValueRef<Tensor>(std::move(tensor)));
    if (result_md_avs) {
      op_dispatcher.result_missing_md_avs_ref().push_back(
          (*result_md_avs)[i].CopyRef());
    }
  }

  RunWhenReady(async_args,
               [op_dispatcher = std::move(op_dispatcher)]() mutable {
                 op_dispatcher.RunDispatchFunction();
               });
}