// RemoteOpHandler::Execute
// From lib/distributed_runtime/remote_op_handler.cc (original lines 128-248).

// Dispatches `op_name` to the remote task that owns `remote_device_`.
//
// The caller receives, via `invocation`:
//   - `*invocation.chain`: an async Chain that becomes concrete once the
//     RPC response has been processed, or carries the error on failure.
//   - `invocation.results[i]`: TensorHandles whose RemoteTensor and
//     TensorMetadata async values are fulfilled from the response.
//
// Unavailable argument tensors are awaited before the request is sent,
// since their remote object ids must be read off the available values.
void RemoteOpHandler::Execute(const std::string& op_name,
                              const OpInvocation& invocation) {
  // Async inputs we must wait on before dispatching the remote op.
  llvm::SmallVector<AsyncValue*, 4> to_wait;

  auto chain = MakeConstructedAsyncValueRef<Chain>(dist_ctx_->GetHostContext());
  *invocation.chain = chain.CopyRef();

  // Keep a reference to each argument's async tensor. Input object ids are
  // appended to the request only after the unavailable ones resolve.
  auto arguments =
      std::make_unique<llvm::SmallVector<RCReference<AsyncValue>, 8>>();
  arguments->resize(invocation.arguments.size());
  for (size_t i = 0; i < invocation.arguments.size(); ++i) {
    auto* tensor_av = invocation.arguments[i].GetAsyncTensor();
    (*arguments)[i] = FormRef(tensor_av);
    if (!tensor_av->IsAvailable()) {
      to_wait.push_back(tensor_av);
    }
  }

  auto request = std::make_unique<RemoteExecuteOpRequest>();
  request->set_context_id(dist_ctx_->GetContextId());
  request->set_op_handler_name(remote_device_->name().str());
  request->set_op_name(op_name);

  // Allocate a remote object id per result, hand unconstructed
  // tensor/metadata async values back to the caller through TensorHandles,
  // and record the output object ids in the request.
  // TODO(ayushd): optimize so that metadata is not always asynchronous.
  auto results = std::make_unique<llvm::SmallVector<TensorAndMetadata, 8>>();
  results->resize(invocation.results.size());
  for (size_t i = 0; i < invocation.results.size(); ++i) {
    auto tensor = MakeUnconstructedAsyncValueRef<RemoteTensor>(
        dist_ctx_->GetHostContext());
    auto metadata = MakeUnconstructedAsyncValueRef<TensorMetadata>(
        dist_ctx_->GetHostContext());
    invocation.results[i] =
        TensorHandle(remote_device_, metadata.CopyRef(), tensor.CopyRef());
    (*results)[i].tensor = tensor.ReleaseRCRef();
    (*results)[i].metadata = std::move(metadata);
    (*results)[i].remote_object_id = std::make_unique<RemoteObjectId>(
        dist_ctx_->GetRemoteObjectManager()->AllocateRemoteObject(
            remote_device_));

    PopulateRemoteExecuteOutputProto(request->add_output(),
                                     *(*results)[i].remote_object_id);
  }

  // Thread the per-task chain through this op: the task's current remote
  // chain becomes the op's input chain, and a freshly allocated id is
  // installed as the task's new chain (and the op's output chain).
  TaskHandle remote_task = remote_device_->GetTaskHandle();
  // Fetch the input chain once and reuse it (the original fetched it twice
  // and left the first result unused).
  auto in_chain_id = remote_chain_manager_->GetRemoteChain(remote_task);
  PopulateRemoteObjectIdProto(request->mutable_in_chain(), in_chain_id);
  auto out_chain_id =
      dist_ctx_->GetRemoteObjectManager()->AllocateRemoteObject(remote_device_);
  PopulateRemoteObjectIdProto(request->mutable_out_chain(), out_chain_id);
  remote_chain_manager_->SetRemoteChain(remote_task, out_chain_id);

  // Add op attributes to the request.
  if (Error attr_error =
          PopulateRequestAttrsProto(request.get(), invocation.attrs.freeze())) {
    // Fail everything the caller can observe: the chain plus each result's
    // tensor AND metadata async values. (The original resolved only the
    // tensors here, leaving metadata waiters hanging forever.)
    chain.SetError(attr_error);
    for (size_t i = 0; i < invocation.results.size(); ++i) {
      invocation.results[i].GetAsyncTensor()->SetError(
          DecodedDiagnostic(attr_error));
      (*results)[i].metadata.SetError(DecodedDiagnostic(attr_error));
    }
    return;
  }

  // NOTE(review): `attrs` below is captured but never referenced inside the
  // lambda. It may be an intentional lifetime extension of the frozen
  // attribute storage until the request is sent — confirm before removing.
  RunWhenReady(
      to_wait,
      [dist_ctx = dist_ctx_, arguments = std::move(arguments),
       results = std::move(results), request = std::move(request),
       attrs = invocation.attrs.freeze(), remote_task = std::move(remote_task),
       chain = std::move(chain)]() mutable {
        // All inputs are now available; add their object ids to the request.
        for (auto& input : *arguments) {
          // Each TensorHandle should contain an available RemoteTensor.
          // The corresponding tensor on the remote side may be
          // unavailable.
          assert(input->IsAvailable());
          auto* request_input = request->add_input();
          PopulateRemoteObjectIdProto(
              request_input, input->get<RemoteTensor>().remote_object_id());
          TFRT_DLOG(INFO) << "RemoteOpHandler input "
                          << request_input->DebugString();
        }

        auto response = std::make_unique<RemoteExecuteOpResponse>();
        RemoteClientInterface* remote_client =
            dist_ctx->GetRemoteClient(remote_task);
        remote_client->RemoteExecuteOpAsync(
            RemoteCallContext::GetDefault(), request.get(), response.get(),
            [results = std::move(results), request = std::move(request),
             response = std::move(response),
             chain = chain.CopyRef()](Error e) mutable {
              // Resolve every result async value with an error so that no
              // waiter hangs when the RPC fails. (The original set only the
              // chain on these paths, leaving the results unresolved.)
              auto fail_results = [&results](const char* msg) {
                for (auto& result : *results) {
                  result.metadata.SetError(DecodedDiagnostic(msg));
                  result.tensor->SetError(DecodedDiagnostic(msg));
                }
              };
              if (e) {
                fail_results("remote op execution failed");
                chain.SetError(std::move(e));
                return;
              }
              if (response->metadata_size() !=
                  static_cast<int>(results->size())) {
                fail_results("unexpected number of remote results");
                chain.SetError("unexpected number of remote results");
                return;
              }
              // Fulfill each result from the metadata in the response; the
              // remote tensor itself stays on the remote side, identified by
              // its remote object id.
              for (int i = 0; i < response->metadata_size(); ++i) {
                Expected<TensorMetadata> metadata =
                    DeserializeTensorMetadata(response->metadata(i));
                if (metadata) {
                  (*results)[i].metadata.emplace(metadata.get());
                  (*results)[i].tensor->emplace<RemoteTensor>(
                      metadata.get(), *(*results)[i].remote_object_id);
                } else {
                  // Resolve both async values (the original left metadata
                  // unresolved on this branch as well).
                  (*results)[i].metadata.SetError(DecodedDiagnostic(
                      "could not deserialize metadata in response"));
                  (*results)[i].tensor->SetError(DecodedDiagnostic(
                      "could not deserialize metadata in response"));
                }
              }
              chain.SetStateConcrete();
            });
      });
}