in lib/distributed_runtime/request_handler_impl.cc [311:461]
void RequestHandler::HandleRemoteExecute(const RemoteExecuteRequest* request,
RemoteExecuteResponse* response,
CallbackFn done) {
auto expected = server_context_->GetDistributedContext(request->context_id());
if (!expected) {
done(expected.takeError());
return;
}
DistributedContext* dist_context = expected.get();
FunctionCache* function_cache = dist_context->GetFunctionCache();
FunctionCache::CachedBEF* cached_bef =
function_cache->Prepare(request->program_name());
if (cached_bef == nullptr) {
done(llvm::make_error<InvalidArgumentErrorInfo>(
StrCat("Can't find program: [", request->program_name(), "]")));
return;
}
RCReference<BEFFile>& bef_file = cached_bef->bef_file;
if (bef_file.get() == nullptr) {
done(llvm::make_error<InvalidArgumentErrorInfo>(
StrCat("Can't find function: [", request->program_name(), "]")));
return;
}
const Function* fn = bef_file->GetFunction(request->program_name());
if (fn == nullptr) {
done(llvm::make_error<InvalidArgumentErrorInfo>(
StrCat("Failed to get program from BEFFile with name ",
request->program_name(), ".")));
return;
}
if (fn->result_types().size() != request->output_size()) {
done(llvm::make_error<InvalidArgumentErrorInfo>(
StrCat("Result size mismatch: fn #result: ", fn->result_types().size(),
" Received #outputs: ", request->output_size())));
return;
}
// TODO(bramandia): Propagate RequestContext from the request.
ResourceContext resource_context;
Expected<RCReference<RequestContext>> req_ctx =
RequestContextBuilder(host_ctx(), &resource_context).build();
if (!req_ctx) {
done(llvm::make_error<UnknownErrorInfo>(
StrCat("Failed to build RequestContext ", req_ctx.takeError())));
return;
}
tfrt::ExecutionContext exec_ctx{std::move(*req_ctx)};
RemoteObjectManager* manager = dist_context->GetRemoteObjectManager();
llvm::SmallVector<AsyncValue*, 4> arguments;
llvm::SmallVector<RCReference<AsyncValue>, 4> arguments_ref;
arguments.reserve(fn->argument_types().size());
arguments_ref.reserve(fn->argument_types().size());
// Allow the first argument to be `DistributedContext`.
if (cached_bef->require_distributed_context) {
AsyncValue* dist_context_arg =
server_context_->GetDistributedContextAsyncValue(request->context_id())
.GetAsyncValue();
arguments.push_back(dist_context_arg);
}
if (cached_bef->require_preallocated_outputs) {
for (int i = 0; i < request->output_size(); ++i) {
auto& id = request->output(i).id();
RCReference<Device> device =
dist_context->GetRemoteDeviceManager()->GetDeviceRef<Device>(
id.device());
if (device.get() == nullptr) {
done(llvm::make_error<DeviceNotFoundErrorInfo>(
StrCat("Can't find device: ", id.device())));
return;
}
RCReference<AsyncValue> remote_object_id =
MakeAvailableAsyncValueRef<RemoteObjectId>(host_ctx(), id.prefix_id(),
id.local_id(), device);
arguments_ref.push_back(remote_object_id);
arguments.push_back(remote_object_id.get());
}
}
if (fn->argument_types().size() != arguments.size() + request->input_size()) {
done(llvm::make_error<InvalidArgumentErrorInfo>(
StrCat("Argument size mismatch: fn #arg: ", fn->argument_types().size(),
" Received #inputs: ", request->input_size())));
return;
}
for (int i = 0; i < request->input_size(); ++i) {
auto& id = request->input(i);
RCReference<Device> device =
dist_context->GetRemoteDeviceManager()->GetDeviceRef<Device>(
id.device());
if (device.get() == nullptr) {
done(llvm::make_error<DeviceNotFoundErrorInfo>(
StrCat("Can't find device: ", id.device())));
return;
}
RemoteObjectId input_id(id.prefix_id(), id.local_id(), device);
RCReference<AsyncValue> val = manager->GetRemoteObject(input_id);
arguments_ref.push_back(val);
arguments.push_back(val.get());
}
auto results =
std::make_unique<llvm::SmallVector<RCReference<AsyncValue>, 4>>();
results->resize(fn->result_types().size());
fn->Execute(exec_ctx, arguments, *results);
for (int i = 0; i < request->output_size(); ++i) {
auto& id = request->output(i).id();
RCReference<Device> device =
dist_context->GetRemoteDeviceManager()->GetDeviceRef<Device>(
id.device());
if (device.get() == nullptr) {
done(llvm::make_error<DeviceNotFoundErrorInfo>(
StrCat("Can't find device: ", id.device())));
return;
}
// TODO(bramandia): Do not store the output in the map if the device is not
// a local device.
RemoteObjectId output_id(id.prefix_id(), id.local_id(), device);
manager->SetRemoteObject(output_id, (*results)[i]);
}
// get the pointer of results before being moved on the lambda capture.
auto result_ref = results.get();
// Request will live as long as done is not called yet.
RunWhenReady(*result_ref, [fn, done = std::move(done), request, response,
results = std::move(results),
arguments = std::move(arguments),
arguments_ref =
std::move(arguments_ref)]() mutable {
for (int i = 0; i < request->output_size(); ++i) {
if (request->output(i).need_metadata()) {
if (fn->result_types()[i].GetName() == "!t.tensor") {
std::string serialized =
SerializeTensorMetadata((*results)[i]->get<Tensor>().metadata());
response->add_metadata(serialized);
} else if (fn->result_types()[i].GetName() == "!corert.tensorhandle") {
std::string serialized = SerializeTensorMetadata(
(*results)[i]->get<TensorHandle>().GetAvailableMetadata());
response->add_metadata(serialized);
} else {
done(llvm::make_error<InvalidArgumentErrorInfo>(
StrCat("Invalid type ", fn->result_types()[i].GetName())));
return;
}
}
}
done(Error::success());
});
}