in lib/distributed_runtime/request_handler_impl.cc [495:642]
void RequestHandler::HandleRemoteExecuteOp(
const RemoteExecuteOpRequest* request, RemoteExecuteOpResponse* response,
CallbackFn done) {
TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name();
auto expected_dist_ctx =
server_context_->GetDistributedContext(request->context_id());
if (!expected_dist_ctx) {
done(expected_dist_ctx.takeError());
TFRT_LOG(ERROR) << "Did not find DistributedContext with id "
<< request->context_id();
return;
}
DistributedContext* dist_ctx = expected_dist_ctx.get();
HostContext* host_ctx = dist_ctx->GetHostContext();
CoreRuntime* corert = CoreRuntime::GetFromHostContext(host_ctx);
DeviceManager* device_manager = dist_ctx->GetRemoteDeviceManager();
RemoteObjectManager* object_manager = dist_ctx->GetRemoteObjectManager();
OpHandler* op_handler = corert->GetOpHandler(request->op_handler_name());
Expected<CoreRuntimeOp> expected_op =
corert->MakeOp(request->op_name(), op_handler);
if (!expected_op) {
done(expected_op.takeError());
TFRT_LOG(ERROR) << "Could not MakeOp in RemoteOpHandler "
<< request->op_name();
return;
}
auto op = std::move(expected_op.get());
auto device =
device_manager->GetDeviceRef<Device>(request->in_chain().device());
auto async_args =
std::make_unique<llvm::SmallVector<RCReference<AsyncValue>, 4>>();
async_args->reserve(request->input_size() + 1); // TH inputs + in chain
// Get the potentially async input chain.
if (auto e = GetRemoteObjectFromId(device_manager, object_manager,
request->in_chain(), async_args.get())) {
TFRT_LOG(ERROR) << "Error while getting remote object "
<< request->in_chain().DebugString() << " error: " << e;
done(std::move(e));
return;
}
// Get all other potentially async input tensors.
for (auto i = 0; i < request->input_size(); ++i) {
if (auto e = GetRemoteObjectFromId(device_manager, object_manager,
request->input(i), async_args.get())) {
done(std::move(e));
return;
}
TFRT_DLOG(INFO) << "HandleRemoteExecuteOp wait for input "
<< request->input(i).DebugString() << " av "
<< async_args->back().get();
}
TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name()
<< " wait for " << async_args->size() << " inputs";
auto async_args_ref = async_args.get();
RunWhenReady(*async_args_ref, [host_ctx, dist_ctx, request, response,
done = std::move(done), op = std::move(op),
device = std::move(device),
async_args = std::move(async_args)]() mutable {
TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name();
AsyncValueRef<Chain> chain(FormRef((*async_args)[0].get()));
// Get the actual Tensor inputs which should now be available.
llvm::SmallVector<TensorHandle, 4> args;
args.reserve(request->input_size());
for (auto i = 1; i < async_args->size(); ++i) {
AsyncValueRef<Tensor> tensor((*async_args)[i]);
args.emplace_back(device, tensor->metadata(), tensor.CopyRef());
}
llvm::SmallVector<TensorHandle, 4> results;
results.resize(request->output_size());
// TODO(bramandia): Propagate RequestContext from the request.
ResourceContext resource_context;
Expected<RCReference<tfrt::RequestContext>> req_ctx =
RequestContextBuilder(host_ctx, &resource_context).build();
if (!req_ctx) {
done(llvm::make_error<UnknownErrorInfo>(
StrCat("Failed to build RequestContext ", req_ctx.takeError())));
return;
}
tfrt::ExecutionContext exec_ctx{std::move(*req_ctx)};
// Setup op attributes.
OpAttrs op_attrs;
ParseOpAttrs(*request, &op_attrs);
op(exec_ctx, args, OpAttrsRef(op_attrs), results, &chain);
// Set the output chain mapping in the remote object manager.
RemoteObjectId out_chain_id(request->out_chain().prefix_id(),
request->out_chain().local_id(), device);
dist_ctx->GetRemoteObjectManager()->SetRemoteObject(out_chain_id,
chain.CopyRCRef());
TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name()
<< " out chain " << request->out_chain().DebugString()
<< " av " << chain.GetAsyncValue();
// Wait for results to be ready before sending the response to client.
// Also add a mapping in the remote object manager for each output.
llvm::SmallVector<RCReference<AsyncValue>, 4> async_results;
async_results.reserve(results.size() + 1); // TH results + out chain
async_results.push_back(chain.CopyRCRef());
for (auto i = 0; i < results.size(); ++i) {
async_results.push_back(FormRef(results[i].GetAsyncTensor()));
auto& output_id = request->output(i).id();
auto output_device =
dist_ctx->GetRemoteDeviceManager()->GetDeviceRef<Device>(
output_id.device());
if (device.get() == nullptr) {
done(llvm::make_error<DeviceNotFoundErrorInfo>(
StrCat("Can't find device: ", output_id.device())));
return;
}
RemoteObjectId object_id(output_id.prefix_id(), output_id.local_id(),
std::move(output_device));
dist_ctx->GetRemoteObjectManager()->SetRemoteObject(object_id,
async_results.back());
TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name()
<< " output " << output_id.DebugString() << " av "
<< async_results.back().get();
}
RunWhenReady(async_results, [request, response, device = std::move(device),
chain = chain.CopyRef(),
results = std::move(results),
done = std::move(done)]() mutable {
TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name();
if (chain.IsError()) {
// TODO(ayushd): choose a more informative error code.
done(llvm::make_error<UnknownErrorInfo>(chain.GetError().message));
}
for (auto& result : results) {
auto serialized =
SerializeTensorMetadata(result.GetAvailableMetadata());
TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name()
<< " result metadata " << result.GetAvailableMetadata();
response->add_metadata(serialized);
}
done(Error::success());
});
});
}