in lib/distributed_runtime/remote_op_handler.cc [128:248]
void RemoteOpHandler::Execute(const std::string& op_name,
const OpInvocation& invocation) {
// Wait for async input tensors before starting to dispatch the
// remote op.
llvm::SmallVector<AsyncValue*, 4> to_wait;
auto chain = MakeConstructedAsyncValueRef<Chain>(dist_ctx_->GetHostContext());
*invocation.chain = chain.CopyRef();
// We may need to wait for async inputs, so we add input object ids
// after waiting.
auto arguments =
std::make_unique<llvm::SmallVector<RCReference<AsyncValue>, 8>>();
arguments->resize(invocation.arguments.size());
for (auto i = 0; i < invocation.arguments.size(); ++i) {
auto* tensor_av = invocation.arguments[i].GetAsyncTensor();
(*arguments)[i] = FormRef(tensor_av);
if (!tensor_av->IsAvailable()) {
to_wait.push_back(tensor_av);
}
}
auto request = std::make_unique<RemoteExecuteOpRequest>();
request->set_context_id(dist_ctx_->GetContextId());
request->set_op_handler_name(remote_device_->name().str());
request->set_op_name(op_name);
// Add output object ids to the request.
// TODO(ayushd): optimize so that metadata is not always asynchronous.
auto results = std::make_unique<llvm::SmallVector<TensorAndMetadata, 8>>();
results->resize(invocation.results.size());
for (auto i = 0; i < invocation.results.size(); ++i) {
auto tensor = MakeUnconstructedAsyncValueRef<RemoteTensor>(
dist_ctx_->GetHostContext());
auto metadata = MakeUnconstructedAsyncValueRef<TensorMetadata>(
dist_ctx_->GetHostContext());
invocation.results[i] =
TensorHandle(remote_device_, metadata.CopyRef(), tensor.CopyRef());
(*results)[i].tensor = tensor.ReleaseRCRef();
(*results)[i].metadata = std::move(metadata);
(*results)[i].remote_object_id = std::make_unique<RemoteObjectId>(
dist_ctx_->GetRemoteObjectManager()->AllocateRemoteObject(
remote_device_));
PopulateRemoteExecuteOutputProto(request->add_output(),
*(*results)[i].remote_object_id);
}
// Add information about the chains to the request.
TaskHandle remote_task = remote_device_->GetTaskHandle();
auto in_chain_id =
remote_chain_manager_->GetRemoteChain(remote_device_->GetTaskHandle());
PopulateRemoteObjectIdProto(
request->mutable_in_chain(),
remote_chain_manager_->GetRemoteChain(remote_task));
auto out_chain_id =
dist_ctx_->GetRemoteObjectManager()->AllocateRemoteObject(remote_device_);
PopulateRemoteObjectIdProto(request->mutable_out_chain(), out_chain_id);
remote_chain_manager_->SetRemoteChain(remote_task, out_chain_id);
// Add op attributes to the request.
if (Error attr_error =
PopulateRequestAttrsProto(request.get(), invocation.attrs.freeze())) {
chain.SetError(attr_error);
for (auto& output_th : invocation.results) {
output_th.GetAsyncTensor()->SetError(DecodedDiagnostic(attr_error));
}
return;
}
RunWhenReady(
to_wait,
[dist_ctx = dist_ctx_, arguments = std::move(arguments),
results = std::move(results), request = std::move(request),
attrs = invocation.attrs.freeze(), remote_task = std::move(remote_task),
chain = std::move(chain)]() mutable {
// Add input object ids to the request.
for (auto& input : *arguments) {
// Each TensorHandle should contain an available RemoteTensor.
// The corresponding tensor on the remote side may be
// unavailable.
assert(input->IsAvailable());
auto* request_input = request->add_input();
PopulateRemoteObjectIdProto(
request_input, input->get<RemoteTensor>().remote_object_id());
TFRT_DLOG(INFO) << "RemoteOpHandler input "
<< request_input->DebugString();
}
auto response = std::make_unique<RemoteExecuteOpResponse>();
RemoteClientInterface* remote_client =
dist_ctx->GetRemoteClient(remote_task);
remote_client->RemoteExecuteOpAsync(
RemoteCallContext::GetDefault(), request.get(), response.get(),
[results = std::move(results), request = std::move(request),
response = std::move(response),
chain = chain.CopyRef()](Error e) mutable {
if (e) {
chain.SetError(std::move(e));
return;
}
if (response->metadata_size() != results->size()) {
chain.SetError("unexpected number of remote results");
return;
}
for (auto i = 0; i < response->metadata_size(); ++i) {
Expected<TensorMetadata> metadata =
DeserializeTensorMetadata(response->metadata(i));
if (metadata) {
(*results)[i].metadata.emplace(metadata.get());
(*results)[i].tensor->emplace<RemoteTensor>(
metadata.get(), *(*results)[i].remote_object_id);
} else {
(*results)[i].tensor->SetError(DecodedDiagnostic(
"could not deserialize metadata in response"));
}
}
chain.SetStateConcrete();
});
});
}