void RequestHandler::HandleRemoteExecuteOp()

in lib/distributed_runtime/request_handler_impl.cc [495:642]


void RequestHandler::HandleRemoteExecuteOp(
    const RemoteExecuteOpRequest* request, RemoteExecuteOpResponse* response,
    CallbackFn done) {
  TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name();
  auto expected_dist_ctx =
      server_context_->GetDistributedContext(request->context_id());
  if (!expected_dist_ctx) {
    done(expected_dist_ctx.takeError());
    TFRT_LOG(ERROR) << "Did not find DistributedContext with id "
                    << request->context_id();
    return;
  }
  DistributedContext* dist_ctx = expected_dist_ctx.get();
  HostContext* host_ctx = dist_ctx->GetHostContext();
  CoreRuntime* corert = CoreRuntime::GetFromHostContext(host_ctx);
  DeviceManager* device_manager = dist_ctx->GetRemoteDeviceManager();
  RemoteObjectManager* object_manager = dist_ctx->GetRemoteObjectManager();

  OpHandler* op_handler = corert->GetOpHandler(request->op_handler_name());
  Expected<CoreRuntimeOp> expected_op =
      corert->MakeOp(request->op_name(), op_handler);
  if (!expected_op) {
    done(expected_op.takeError());
    TFRT_LOG(ERROR) << "Could not MakeOp in RemoteOpHandler "
                    << request->op_name();
    return;
  }
  auto op = std::move(expected_op.get());
  auto device =
      device_manager->GetDeviceRef<Device>(request->in_chain().device());

  auto async_args =
      std::make_unique<llvm::SmallVector<RCReference<AsyncValue>, 4>>();
  async_args->reserve(request->input_size() + 1);  // TH inputs + in chain

  // Get the potentially async input chain.
  if (auto e = GetRemoteObjectFromId(device_manager, object_manager,
                                     request->in_chain(), async_args.get())) {
    TFRT_LOG(ERROR) << "Error while getting remote object "
                    << request->in_chain().DebugString() << " error: " << e;
    done(std::move(e));
    return;
  }

  // Get all other potentially async input tensors.
  for (auto i = 0; i < request->input_size(); ++i) {
    if (auto e = GetRemoteObjectFromId(device_manager, object_manager,
                                       request->input(i), async_args.get())) {
      done(std::move(e));
      return;
    }
    TFRT_DLOG(INFO) << "HandleRemoteExecuteOp wait for input "
                    << request->input(i).DebugString() << " av "
                    << async_args->back().get();
  }

  TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name()
                  << " wait for " << async_args->size() << " inputs";
  auto async_args_ref = async_args.get();
  RunWhenReady(*async_args_ref, [host_ctx, dist_ctx, request, response,
                                 done = std::move(done), op = std::move(op),
                                 device = std::move(device),
                                 async_args = std::move(async_args)]() mutable {
    TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name();
    AsyncValueRef<Chain> chain(FormRef((*async_args)[0].get()));

    // Get the actual Tensor inputs which should now be available.
    llvm::SmallVector<TensorHandle, 4> args;
    args.reserve(request->input_size());
    for (auto i = 1; i < async_args->size(); ++i) {
      AsyncValueRef<Tensor> tensor((*async_args)[i]);
      args.emplace_back(device, tensor->metadata(), tensor.CopyRef());
    }
    llvm::SmallVector<TensorHandle, 4> results;
    results.resize(request->output_size());

    // TODO(bramandia): Propagate RequestContext from the request.
    ResourceContext resource_context;
    Expected<RCReference<tfrt::RequestContext>> req_ctx =
        RequestContextBuilder(host_ctx, &resource_context).build();
    if (!req_ctx) {
      done(llvm::make_error<UnknownErrorInfo>(
          StrCat("Failed to build RequestContext ", req_ctx.takeError())));
      return;
    }
    tfrt::ExecutionContext exec_ctx{std::move(*req_ctx)};

    // Setup op attributes.
    OpAttrs op_attrs;
    ParseOpAttrs(*request, &op_attrs);

    op(exec_ctx, args, OpAttrsRef(op_attrs), results, &chain);

    // Set the output chain mapping in the remote object manager.
    RemoteObjectId out_chain_id(request->out_chain().prefix_id(),
                                request->out_chain().local_id(), device);
    dist_ctx->GetRemoteObjectManager()->SetRemoteObject(out_chain_id,
                                                        chain.CopyRCRef());
    TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name()
                    << " out chain " << request->out_chain().DebugString()
                    << " av " << chain.GetAsyncValue();

    // Wait for results to be ready before sending the response to client.
    // Also add a mapping in the remote object manager for each output.
    llvm::SmallVector<RCReference<AsyncValue>, 4> async_results;
    async_results.reserve(results.size() + 1);  // TH results + out chain
    async_results.push_back(chain.CopyRCRef());
    for (auto i = 0; i < results.size(); ++i) {
      async_results.push_back(FormRef(results[i].GetAsyncTensor()));

      auto& output_id = request->output(i).id();
      auto output_device =
          dist_ctx->GetRemoteDeviceManager()->GetDeviceRef<Device>(
              output_id.device());
      if (device.get() == nullptr) {
        done(llvm::make_error<DeviceNotFoundErrorInfo>(
            StrCat("Can't find device: ", output_id.device())));
        return;
      }
      RemoteObjectId object_id(output_id.prefix_id(), output_id.local_id(),
                               std::move(output_device));
      dist_ctx->GetRemoteObjectManager()->SetRemoteObject(object_id,
                                                          async_results.back());
      TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name()
                      << " output " << output_id.DebugString() << " av "
                      << async_results.back().get();
    }

    RunWhenReady(async_results, [request, response, device = std::move(device),
                                 chain = chain.CopyRef(),
                                 results = std::move(results),
                                 done = std::move(done)]() mutable {
      TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name();
      if (chain.IsError()) {
        // TODO(ayushd): choose a more informative error code.
        done(llvm::make_error<UnknownErrorInfo>(chain.GetError().message));
      }
      for (auto& result : results) {
        auto serialized =
            SerializeTensorMetadata(result.GetAvailableMetadata());
        TFRT_DLOG(INFO) << "HandleRemoteExecuteOp " << request->op_name()
                        << " result metadata " << result.GetAvailableMetadata();
        response->add_metadata(serialized);
      }
      done(Error::success());
    });
  });
}