in lib/distributed_runtime/kernels.cc [1017:1150]
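// Kernel that runs a remote program on the `receiver` task. The first
// num_fn_inputs entries of `inputs` are the RemoteObjectIds consumed by the
// remote function; any extra entries are caller-pre-allocated ids for its
// outputs. Results are: an output chain, then (only when ids are allocated
// here) one RemoteObjectId per remote output, then
// num_output_with_tensorhandle TensorHandles whose metadata and tensor are
// resolved asynchronously from the RPC response.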
void RemoteExecute(Chain ch, DistributedContext* dist_context,
const TaskHandle receiver, Argument<RemoteExecuteSpec> spec,
RemainingArguments inputs, RemainingResults results,
StringAttribute program_name, int num_fn_inputs,
int32_t num_output_with_tensorhandle,
const ExecutionContext& exec_ctx) {
// If some output IDs are present in the inputs, we assume all output IDs are
// pre-allocated.
const bool output_id_allocated = num_fn_inputs != inputs.size();
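  // When pre-allocated, the output ids arrive appended to `inputs`, after the
  // first num_fn_inputs function inputs.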
auto request = std::make_unique<RemoteExecuteRequest>();
  // program_name remains alive at least until out_chain is populated, so it
  // is safe to read here.
request->set_context_id(dist_context->GetContextId());
request->set_program_name(program_name.str());
request->mutable_input()->Reserve(num_fn_inputs);
for (int i = 0; i < num_fn_inputs; ++i) {
const RemoteObjectId& input = inputs[i]->get<RemoteObjectId>();
auto* add_input = request->add_input();
add_input->set_prefix_id(input.prefix_id);
add_input->set_local_id(input.local_id);
add_input->set_device(input.device->name().str());
}
// First output: chain
AsyncValueRef<Chain> out_chain =
MakeConstructedAsyncValueRef<Chain>(exec_ctx.host());
results[0] = out_chain.CopyRef();
  // If the output ids are pre-allocated, we only return TensorHandles.
  // Otherwise, we return the output ids followed by the TensorHandles.
if (results.size() !=
1 /*chain*/ + (output_id_allocated ? 0 : spec->output_devices.size()) +
num_output_with_tensorhandle) {
    out_chain.SetError(llvm::make_error<InvalidArgumentErrorInfo>(StrCat(
        "Mismatched output count for RemoteExecuteSpec: ",
        spec->output_devices.size(), " output devices, but ",
        results.size() - 1, " non-chain results")));
return;
}
// Actual number of outputs of the remote function.
int num_fn_output;
if (output_id_allocated) {
    // Each of the output IDs must be passed as an input.
num_fn_output = inputs.size() - num_fn_inputs;
} else {
    // Otherwise, infer it from the kernel results: the total minus the chain
    // minus the TensorHandle outputs.
num_fn_output = results.size() - num_output_with_tensorhandle - 1;
}
// Start output index of TensorHandle outputs.
// If output id is allocated, we only return TensorHandles.
const int th_output_idx = output_id_allocated ? 0 : num_fn_output;
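  // Result layout, with n = num_fn_output and
  // t = num_output_with_tensorhandle:
  //   ids allocated here: [chain, id_1 .. id_n, th_1 .. th_t]
  //   ids pre-allocated:  [chain, th_1 .. th_t]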
request->mutable_output()->Reserve(num_fn_output);
RemoteObjectManager* manager = dist_context->GetRemoteObjectManager();
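  // Bundles the async values behind one TensorHandle result; they are
  // resolved (or set to errors) when the remote response arrives.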
struct RemoteObjectAndMetadata {
AsyncValueRef<RemoteObjectId> id;
AsyncValueRef<RemoteTensor> tensor;
AsyncValueRef<TensorMetadata> metadata;
};
llvm::SmallVector<RemoteObjectAndMetadata, 4> remote_objs;
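  // The loop index starts at 1 so that results[i] lands just past the output
  // chain at results[0].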
for (int i = 1; i <= num_fn_output; ++i) {
    // RCReference is move-only, so take an explicit copy of the device ref.
    RCReference<Device> output_device = spec->output_devices[i - 1].CopyRef();
AsyncValueRef<RemoteObjectId> out_id;
if (output_id_allocated) {
      // Reuse the pre-allocated output id passed as an extra input.
out_id =
AsyncValueRef<RemoteObjectId>(FormRef(inputs[num_fn_inputs + i - 1]));
} else {
// Allocate output id
out_id = MakeAvailableAsyncValueRef<RemoteObjectId>(
exec_ctx.host(),
manager->AllocateRemoteObject(std::move(output_device)));
      // The next num_fn_output results are the RemoteObjectIds.
results[i] = out_id.CopyRef();
}
    // The last num_output_with_tensorhandle outputs also need their
    // TensorMetadata returned.
const bool need_metadata =
i > (num_fn_output - num_output_with_tensorhandle);
if (need_metadata) {
auto tensor =
MakeUnconstructedAsyncValueRef<RemoteTensor>(exec_ctx.host());
auto metadata =
MakeUnconstructedAsyncValueRef<TensorMetadata>(exec_ctx.host());
      AsyncValueRef<TensorHandle> th = MakeAvailableAsyncValueRef<TensorHandle>(
          exec_ctx.host(), out_id->device.CopyRef(), metadata.CopyRef(),
          tensor.CopyRef());
remote_objs.emplace_back(RemoteObjectAndMetadata{
out_id.CopyRef(), std::move(tensor), std::move(metadata)});
      // The remaining results are the TensorHandles.
results[th_output_idx + remote_objs.size()] = th.CopyRef();
}
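    // Record this output id in the request and tell the server whether to
    // send back its TensorMetadata.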
auto* add_output = request->add_output();
add_output->set_need_metadata(need_metadata);
auto* add_output_id = add_output->mutable_id();
add_output_id->set_prefix_id(out_id->prefix_id);
add_output_id->set_local_id(out_id->local_id);
add_output_id->set_device(out_id->device->name().str());
}
RemoteClientInterface* remote_client =
dist_context->GetRemoteClient(receiver);
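  // Issue the RPC from the work queue; the moved captures keep the request,
  // response, and unresolved async values alive until the callback runs.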
EnqueueWork(exec_ctx, [remote_client, request = std::move(request),
dist_context, out_chain = out_chain.CopyRef(),
remote_objs = std::move(remote_objs)]() mutable {
auto response = std::make_unique<RemoteExecuteResponse>();
remote_client->RemoteExecuteAsync(
RemoteCallContext::GetDefault(), request.get(), response.get(),
[request = std::move(request), response = std::move(response),
out_chain = out_chain.CopyRef(), remote_objs = std::move(remote_objs),
host_context = dist_context->GetHostContext()](Error e) mutable {
          // Propagate the returned metadata, then resolve the output chain.
const int num_metadata = response->metadata_size();
for (int i = 0; i < remote_objs.size(); ++i) {
auto& obj = remote_objs[i];
            if (i >= num_metadata) {
              // The server returned fewer metadata entries than expected;
              // set errors on both async values so no consumer hangs.
              obj.metadata.SetError(DecodedDiagnostic("Metadata not returned"));
              obj.tensor.SetError(DecodedDiagnostic("Metadata not returned"));
              continue;
            }
            auto metadata = DeserializeTensorMetadata(response->metadata(i));
            if (metadata) {
              // Resolve the metadata, then build the RemoteTensor that pairs
              // it with the remote object id.
              obj.metadata.emplace(metadata.get());
              obj.tensor.emplace(std::move(metadata.get()), obj.id.get());
            } else {
              // Deserialization failed; again surface the error on both
              // async values so no consumer hangs.
              auto diag = DecodedDiagnostic(metadata.takeError());
              obj.metadata.SetError(diag);
              obj.tensor.SetError(std::move(diag));
            }
}
if (e) {
out_chain.SetError(std::move(e));
} else {
out_chain.SetStateConcrete();
}
});
});
}