// RemoteExecute()
//
// From lib/distributed_runtime/kernels.cc [1017:1150]


// Kernel: asynchronously executes a remote program on task `receiver`.
//
// Inputs layout (RemainingArguments):
//   inputs[0 .. num_fn_inputs)      : RemoteObjectId inputs of the function.
//   inputs[num_fn_inputs .. end)    : pre-allocated output RemoteObjectIds,
//                                     present only when the caller allocated
//                                     output ids up front.
//
// Results layout (RemainingResults):
//   results[0]                      : output Chain, made available (or set to
//                                     an error) when the RPC completes.
//   results[1 .. num_fn_output]     : freshly allocated RemoteObjectIds — only
//                                     when output ids were NOT pre-allocated.
//   trailing entries                : TensorHandles for the last
//                                     `num_output_with_tensorhandle` function
//                                     outputs; their metadata/tensor async
//                                     values are fulfilled from the RPC
//                                     response.
void RemoteExecute(Chain ch, DistributedContext* dist_context,
                   const TaskHandle receiver, Argument<RemoteExecuteSpec> spec,
                   RemainingArguments inputs, RemainingResults results,
                   StringAttribute program_name, int num_fn_inputs,
                   int32_t num_output_with_tensorhandle,
                   const ExecutionContext& exec_ctx) {
  // If some output IDs are present in the inputs, we assume all output IDs are
  // pre-allocated.
  const bool output_id_allocated = num_fn_inputs != inputs.size();
  auto request = std::make_unique<RemoteExecuteRequest>();
  // program_name will live as long as out_chain is not populated.
  request->set_context_id(dist_context->GetContextId());
  request->set_program_name(program_name.str());
  request->mutable_input()->Reserve(num_fn_inputs);

  // Serialize the function's RemoteObjectId inputs into the request.
  for (int i = 0; i < num_fn_inputs; ++i) {
    const RemoteObjectId& input = inputs[i]->get<RemoteObjectId>();
    auto* add_input = request->add_input();
    add_input->set_prefix_id(input.prefix_id);
    add_input->set_local_id(input.local_id);
    add_input->set_device(input.device->name().str());
  }
  // First output: chain
  AsyncValueRef<Chain> out_chain =
      MakeConstructedAsyncValueRef<Chain>(exec_ctx.host());
  results[0] = out_chain.CopyRef();

  // If output_id is preallocated, we only return TensorHandles. Otherwise, we
  // return output ids followed by TensorHandles.
  if (results.size() !=
      1 /*chain*/ + (output_id_allocated ? 0 : spec->output_devices.size()) +
          num_output_with_tensorhandle) {
    out_chain.SetError(llvm::make_error<InvalidArgumentErrorInfo>(StrCat(
        "Mismatch output devices size in RemoteExecuteSpec: ",
        spec->output_devices.size(), " expected: ", results.size() - 1)));
    return;
  }

  // Actual number of outputs of the remote function.
  int num_fn_output;
  if (output_id_allocated) {
    // Each of the output IDs must be passed as inputs.
    num_fn_output = inputs.size() - num_fn_inputs;
  } else {
    // Otherwise, we can infer this from the kernel outputs minus chain minus
    // TensorHandle output
    num_fn_output = results.size() - num_output_with_tensorhandle - 1;
  }
  // Start output index of TensorHandle outputs.
  // If output id is allocated, we only return TensorHandles.
  const int th_output_idx = output_id_allocated ? 0 : num_fn_output;
  request->mutable_output()->Reserve(num_fn_output);
  RemoteObjectManager* manager = dist_context->GetRemoteObjectManager();
  // Bundle of async values for one output that needs a TensorHandle: the
  // remote id plus the (initially unconstructed) tensor and metadata that the
  // RPC completion callback must fulfill.
  struct RemoteObjectAndMetadata {
    AsyncValueRef<RemoteObjectId> id;
    AsyncValueRef<RemoteTensor> tensor;
    AsyncValueRef<TensorMetadata> metadata;
  };
  llvm::SmallVector<RemoteObjectAndMetadata, 4> remote_objs;
  for (int i = 1; i <= num_fn_output; ++i) {
    RCReference<Device> output_device = spec->output_devices[i - 1];
    AsyncValueRef<RemoteObjectId> out_id;
    if (output_id_allocated) {
      // Reuse output id
      out_id =
          AsyncValueRef<RemoteObjectId>(FormRef(inputs[num_fn_inputs + i - 1]));
    } else {
      // Allocate output id
      out_id = MakeAvailableAsyncValueRef<RemoteObjectId>(
          exec_ctx.host(),
          manager->AllocateRemoteObject(std::move(output_device)));
      // The next num_id_outputs are RemoteObjectId
      results[i] = out_id.CopyRef();
    }
    // The last num_output_with_metadata RemoteObjectIds needs to have
    // TensorMetadata returned.
    const bool need_metadata =
        i > (num_fn_output - num_output_with_tensorhandle);
    if (need_metadata) {
      auto tensor =
          MakeUnconstructedAsyncValueRef<RemoteTensor>(exec_ctx.host());
      auto metadata =
          MakeUnconstructedAsyncValueRef<TensorMetadata>(exec_ctx.host());
      AsyncValueRef<TensorHandle> th = MakeAvailableAsyncValueRef<TensorHandle>(
          exec_ctx.host(), out_id->device, metadata.CopyRef(),
          tensor.CopyRef());
      remote_objs.emplace_back(RemoteObjectAndMetadata{
          out_id.CopyRef(), std::move(tensor), std::move(metadata)});
      // The remaining outputs are TensorHandle
      results[th_output_idx + remote_objs.size()] = th.CopyRef();
    }
    auto* add_output = request->add_output();
    add_output->set_need_metadata(need_metadata);
    auto* add_output_id = add_output->mutable_id();
    add_output_id->set_prefix_id(out_id->prefix_id);
    add_output_id->set_local_id(out_id->local_id);
    add_output_id->set_device(out_id->device->name().str());
  }

  RemoteClientInterface* remote_client =
      dist_context->GetRemoteClient(receiver);
  // The RPC is issued from a work queue; `request` must stay alive until the
  // callback runs, hence it is moved through both lambdas.
  EnqueueWork(exec_ctx, [remote_client, request = std::move(request),
                         dist_context, out_chain = out_chain.CopyRef(),
                         remote_objs = std::move(remote_objs)]() mutable {
    auto response = std::make_unique<RemoteExecuteResponse>();
    remote_client->RemoteExecuteAsync(
        RemoteCallContext::GetDefault(), request.get(), response.get(),
        [request = std::move(request), response = std::move(response),
         out_chain = out_chain.CopyRef(), remote_objs = std::move(remote_objs),
         host_context = dist_context->GetHostContext()](Error e) mutable {
          // Propagate metadata and output chain. Every async value in
          // remote_objs must be fulfilled (value or error) on every path;
          // leaving one unconstructed would hang any consumer awaiting it.
          const int num_metadata = response->metadata_size();
          for (int i = 0; i < static_cast<int>(remote_objs.size()); ++i) {
            auto& obj = remote_objs[i];
            if (i >= num_metadata) {
              // Fix: previously only `metadata` was set to an error here,
              // leaving `tensor` unconstructed forever.
              DecodedDiagnostic err("Metadata not returned");
              obj.metadata.SetError(err);
              obj.tensor.SetError(std::move(err));
              continue;
            }
            auto metadata = DeserializeTensorMetadata(response->metadata(i));
            if (metadata) {
              obj.metadata.emplace(metadata.get());
              obj.tensor.emplace(std::move(metadata.get()), obj.id.get());
            } else {
              // Fix: previously only `tensor` was set to an error here,
              // leaving `metadata` unconstructed forever.
              DecodedDiagnostic err(metadata.takeError());
              obj.metadata.SetError(err);
              obj.tensor.SetError(std::move(err));
            }
          }
          if (e) {
            out_chain.SetError(std::move(e));
          } else {
            out_chain.SetStateConcrete();
          }
        });
  });
}