ucs_status_t Connection::RecvAM()

in dissociated-ipc/ucx_conn.cc [167:234]


// Handles one incoming UCX active message and fulfils `p` with its payload.
//
// The delivery mode is selected from `param->recv_attr`:
//   * UCP_AM_RECV_ATTR_FLAG_DATA: the UCX-provided data is wrapped in a
//     UcxDataBuffer without copying, and UCS_INPROGRESS is returned.
//   * UCP_AM_RECV_ATTR_FLAG_RNDV: a host buffer is allocated and the payload
//     is fetched into it with ucp_am_recv_data_nbx; the promise is fulfilled
//     from the completion callback.
//   * otherwise: the payload is copied into a newly allocated buffer.
//
// `header` and `header_length` are not used by this function.
//
// Returns UCS_INPROGRESS for the zero-copy case, UCS_OK when the payload was
// copied or a rendezvous receive was started, and an error status otherwise.
// When the size check or a buffer allocation fails, `p` is destroyed without a
// value. When the rendezvous receive fails (either immediately or in its
// completion callback), `p` is fulfilled with nullptr.
ucs_status_t Connection::RecvAM(std::promise<std::unique_ptr<arrow::Buffer>> p,
                                const void* header, const size_t header_length,
                                void* data, const size_t data_length,
                                const ucp_am_recv_param_t* param) {
  if (data_length > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
    ARROW_LOG(ERROR) << "cannot allocate buffer greater than 2 GiB, requested: "
                     << data_length;
    return UCS_ERR_IO_ERROR;
  }

  if (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) {
    // data provided can be held by us. return UCS_INPROGRESS to make the data persist
    // and we will eventually use ucp_am_data_release to release it.
    auto buffer = std::make_unique<UcxDataBuffer>(ucp_worker_, data, data_length);
    p.set_value(std::move(buffer));
    return UCS_INPROGRESS;
  }

  // rendezvous protocol
  if (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV) {
    auto maybe_buffer = arrow::default_cpu_memory_manager()->AllocateBuffer(data_length);
    if (!maybe_buffer.ok()) {
      ARROW_LOG(ERROR) << "could not allocate buffer for message: "
                       << maybe_buffer.status().ToString();
      return UCS_ERR_NO_MEMORY;
    }

    auto buffer = maybe_buffer.MoveValueUnsafe();
    // Take the destination address before the buffer is moved into the
    // callback state below.
    void* dest = reinterpret_cast<void*>(buffer->mutable_address());

    // Keep the callback state owned here until UCX has accepted the request,
    // so that a synchronous failure neither leaks it nor leaves the promise
    // permanently unfulfilled.
    std::unique_ptr<RndvPromiseBuffer> pending(
        new RndvPromiseBuffer{std::move(p), std::move(buffer)});

    ucp_request_param_t recv_param;
    recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_MEMORY_TYPE |
                              UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FLAG_NO_IMM_CMPL;
    recv_param.memory_type = UCS_MEMORY_TYPE_HOST;
    recv_param.user_data = pending.get();
    recv_param.cb.recv_am = [](void* request, ucs_status_t status, size_t length,
                               void* user_data) {
      // The callback takes ownership of the state and frees it on exit.
      auto owned = std::unique_ptr<RndvPromiseBuffer>(
          reinterpret_cast<RndvPromiseBuffer*>(user_data));
      if (request) {
        ucp_request_free(request);
      }
      if (status == UCS_OK) {
        owned->p.set_value(std::move(owned->buf));
      } else {
        ARROW_LOG(ERROR) << FromUcsStatus("ucp_am_recv_data_nbx cb", status).ToString();
        owned->p.set_value(nullptr);
      }
    };

    void* request =
        ucp_am_recv_data_nbx(ucp_worker_->get(), data, dest, data_length, &recv_param);
    if (UCS_PTR_IS_ERR(request)) {
      // NOTE(review): assumes UCX does not invoke the callback when it returns
      // an error pointer, so the state is still ours. Resolve the promise the
      // same way the callback does on failure; `pending` frees the state.
      const ucs_status_t status = UCS_PTR_STATUS(request);
      ARROW_LOG(ERROR) << FromUcsStatus("ucp_am_recv_data_nbx", status).ToString();
      pending->p.set_value(nullptr);
      return status;
    }

    // The request was accepted: the completion callback owns (or has already
    // consumed) the state, so relinquish it without touching it.
    pending.release();
    return UCS_OK;
  }

  // Neither flag is set: copy the payload into a buffer we own.
  auto maybe_buffer = arrow::default_cpu_memory_manager()->AllocateBuffer(data_length);
  if (!maybe_buffer.ok()) {
    ARROW_LOG(ERROR) << "could not allocate buffer for message: "
                     << maybe_buffer.status().ToString();
    return UCS_ERR_NO_MEMORY;
  }
  auto buffer = maybe_buffer.MoveValueUnsafe();
  // memcpy requires valid pointers even for a zero-byte copy; skip it for
  // empty messages.
  if (data_length > 0) {
    std::memcpy(buffer->mutable_data(), data, data_length);
  }
  p.set_value(std::move(buffer));
  return UCS_OK;
}