void RequestHandler::HandleRemoteExecute()

in lib/distributed_runtime/request_handler_impl.cc [311:461]


void RequestHandler::HandleRemoteExecute(const RemoteExecuteRequest* request,
                                         RemoteExecuteResponse* response,
                                         CallbackFn done) {
  auto expected = server_context_->GetDistributedContext(request->context_id());
  if (!expected) {
    done(expected.takeError());
    return;
  }
  DistributedContext* dist_context = expected.get();

  FunctionCache* function_cache = dist_context->GetFunctionCache();
  FunctionCache::CachedBEF* cached_bef =
      function_cache->Prepare(request->program_name());
  if (cached_bef == nullptr) {
    done(llvm::make_error<InvalidArgumentErrorInfo>(
        StrCat("Can't find program: [", request->program_name(), "]")));
    return;
  }
  RCReference<BEFFile>& bef_file = cached_bef->bef_file;
  if (bef_file.get() == nullptr) {
    done(llvm::make_error<InvalidArgumentErrorInfo>(
        StrCat("Can't find function: [", request->program_name(), "]")));
    return;
  }

  const Function* fn = bef_file->GetFunction(request->program_name());
  if (fn == nullptr) {
    done(llvm::make_error<InvalidArgumentErrorInfo>(
        StrCat("Failed to get program from BEFFile with name ",
               request->program_name(), ".")));
    return;
  }
  if (fn->result_types().size() != request->output_size()) {
    done(llvm::make_error<InvalidArgumentErrorInfo>(
        StrCat("Result size mismatch: fn #result: ", fn->result_types().size(),
               " Received #outputs: ", request->output_size())));
    return;
  }

  // TODO(bramandia): Propagate RequestContext from the request.
  ResourceContext resource_context;
  Expected<RCReference<RequestContext>> req_ctx =
      RequestContextBuilder(host_ctx(), &resource_context).build();
  if (!req_ctx) {
    done(llvm::make_error<UnknownErrorInfo>(
        StrCat("Failed to build RequestContext ", req_ctx.takeError())));
    return;
  }
  tfrt::ExecutionContext exec_ctx{std::move(*req_ctx)};

  RemoteObjectManager* manager = dist_context->GetRemoteObjectManager();
  llvm::SmallVector<AsyncValue*, 4> arguments;
  llvm::SmallVector<RCReference<AsyncValue>, 4> arguments_ref;
  arguments.reserve(fn->argument_types().size());
  arguments_ref.reserve(fn->argument_types().size());
  // Allow the first argument to be `DistributedContext`.
  if (cached_bef->require_distributed_context) {
    AsyncValue* dist_context_arg =
        server_context_->GetDistributedContextAsyncValue(request->context_id())
            .GetAsyncValue();
    arguments.push_back(dist_context_arg);
  }
  if (cached_bef->require_preallocated_outputs) {
    for (int i = 0; i < request->output_size(); ++i) {
      auto& id = request->output(i).id();
      RCReference<Device> device =
          dist_context->GetRemoteDeviceManager()->GetDeviceRef<Device>(
              id.device());
      if (device.get() == nullptr) {
        done(llvm::make_error<DeviceNotFoundErrorInfo>(
            StrCat("Can't find device: ", id.device())));
        return;
      }
      RCReference<AsyncValue> remote_object_id =
          MakeAvailableAsyncValueRef<RemoteObjectId>(host_ctx(), id.prefix_id(),
                                                     id.local_id(), device);
      arguments_ref.push_back(remote_object_id);
      arguments.push_back(remote_object_id.get());
    }
  }
  if (fn->argument_types().size() != arguments.size() + request->input_size()) {
    done(llvm::make_error<InvalidArgumentErrorInfo>(
        StrCat("Argument size mismatch: fn #arg: ", fn->argument_types().size(),
               " Received #inputs: ", request->input_size())));
    return;
  }
  for (int i = 0; i < request->input_size(); ++i) {
    auto& id = request->input(i);

    RCReference<Device> device =
        dist_context->GetRemoteDeviceManager()->GetDeviceRef<Device>(
            id.device());
    if (device.get() == nullptr) {
      done(llvm::make_error<DeviceNotFoundErrorInfo>(
          StrCat("Can't find device: ", id.device())));
      return;
    }
    RemoteObjectId input_id(id.prefix_id(), id.local_id(), device);
    RCReference<AsyncValue> val = manager->GetRemoteObject(input_id);
    arguments_ref.push_back(val);
    arguments.push_back(val.get());
  }
  auto results =
      std::make_unique<llvm::SmallVector<RCReference<AsyncValue>, 4>>();
  results->resize(fn->result_types().size());

  fn->Execute(exec_ctx, arguments, *results);
  for (int i = 0; i < request->output_size(); ++i) {
    auto& id = request->output(i).id();
    RCReference<Device> device =
        dist_context->GetRemoteDeviceManager()->GetDeviceRef<Device>(
            id.device());
    if (device.get() == nullptr) {
      done(llvm::make_error<DeviceNotFoundErrorInfo>(
          StrCat("Can't find device: ", id.device())));
      return;
    }
    // TODO(bramandia): Do not store the output in the map if the device is not
    // a local device.
    RemoteObjectId output_id(id.prefix_id(), id.local_id(), device);
    manager->SetRemoteObject(output_id, (*results)[i]);
  }

  // get the pointer of results before being moved on the lambda capture.
  auto result_ref = results.get();
  // Request will live as long as done is not called yet.
  RunWhenReady(*result_ref, [fn, done = std::move(done), request, response,
                             results = std::move(results),
                             arguments = std::move(arguments),
                             arguments_ref =
                                 std::move(arguments_ref)]() mutable {
    for (int i = 0; i < request->output_size(); ++i) {
      if (request->output(i).need_metadata()) {
        if (fn->result_types()[i].GetName() == "!t.tensor") {
          std::string serialized =
              SerializeTensorMetadata((*results)[i]->get<Tensor>().metadata());
          response->add_metadata(serialized);
        } else if (fn->result_types()[i].GetName() == "!corert.tensorhandle") {
          std::string serialized = SerializeTensorMetadata(
              (*results)[i]->get<TensorHandle>().GetAvailableMetadata());
          response->add_metadata(serialized);
        } else {
          done(llvm::make_error<InvalidArgumentErrorInfo>(
              StrCat("Invalid type ", fn->result_types()[i].GetName())));
          return;
        }
      }
    }
    done(Error::success());
  });
}