src/client/http/spawn.rs (78 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. use crate::client::{ HttpError, HttpErrorKind, HttpRequest, HttpResponse, HttpResponseBody, HttpService, }; use async_trait::async_trait; use bytes::Bytes; use http::Response; use http_body_util::BodyExt; use hyper::body::{Body, Frame}; use std::pin::Pin; use std::task::{Context, Poll}; use thiserror::Error; use tokio::runtime::Handle; use tokio::task::JoinHandle; /// Spawn error #[derive(Debug, Error)] #[error("SpawnError")] struct SpawnError {} impl From<SpawnError> for HttpError { fn from(value: SpawnError) -> Self { Self::new(HttpErrorKind::Interrupted, value) } } /// Wraps a provided [`HttpService`] and runs it on a separate tokio runtime /// /// See example on [`SpawnedReqwestConnector`] /// /// [`SpawnedReqwestConnector`]: crate::client::http::SpawnedReqwestConnector #[derive(Debug)] pub struct SpawnService<T: HttpService + Clone> { inner: T, runtime: Handle, } impl<T: HttpService + Clone> SpawnService<T> { /// Creates a new [`SpawnService`] from the provided pub fn new(inner: T, runtime: Handle) -> Self { Self { inner, runtime } } } #[async_trait] impl<T: HttpService + Clone> HttpService for SpawnService<T> { async fn call(&self, req: HttpRequest) -> Result<HttpResponse, HttpError> { let inner = self.inner.clone(); let (send, recv) = tokio::sync::oneshot::channel(); // We use an unbounded channel to prevent backpressure across the runtime boundary // which could in turn starve the underlying IO operations let (sender, receiver) = tokio::sync::mpsc::unbounded_channel(); let handle = SpawnHandle(self.runtime.spawn(async move { let r = match HttpService::call(&inner, req).await { Ok(resp) => resp, Err(e) => { let _ = send.send(Err(e)); return; } }; let (parts, mut body) = r.into_parts(); if send.send(Ok(parts)).is_err() { return; } while let Some(x) = body.frame().await { sender.send(x).unwrap(); } })); let parts = recv.await.map_err(|_| SpawnError {})??; Ok(Response::from_parts( parts, HttpResponseBody::new(SpawnBody { stream: receiver, _worker: handle, }), )) } } /// A wrapper around a [`JoinHandle`] that aborts on drop struct SpawnHandle(JoinHandle<()>); impl Drop for SpawnHandle { fn drop(&mut self) { self.0.abort(); } } type StreamItem = Result<Frame<Bytes>, HttpError>; struct SpawnBody { stream: tokio::sync::mpsc::UnboundedReceiver<StreamItem>, _worker: SpawnHandle, } impl Body for SpawnBody { type Data = Bytes; type Error = HttpError; fn poll_frame(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<StreamItem>> { self.stream.poll_recv(cx) } } #[cfg(not(target_arch = "wasm32"))] #[cfg(test)] mod tests { use super::*; use crate::client::mock_server::MockServer; use crate::client::retry::RetryExt; use crate::client::HttpClient; use crate::RetryConfig; async fn test_client(client: HttpClient) { let (send, recv) = tokio::sync::oneshot::channel(); let mock = MockServer::new().await; mock.push(Response::new("BANANAS".to_string())); let url = mock.url().to_string(); let thread = std::thread::spawn(|| { futures::executor::block_on(async move { let retry = RetryConfig::default(); let ret = client.get(url).send_retry(&retry).await.unwrap(); let payload = ret.into_body().bytes().await.unwrap(); assert_eq!(payload.as_ref(), b"BANANAS"); let _ = send.send(()); }) }); recv.await.unwrap(); thread.join().unwrap(); } #[tokio::test] async fn test_spawn() { let client = HttpClient::new(SpawnService::new(reqwest::Client::new(), Handle::current())); test_client(client).await; } #[tokio::test] #[should_panic] async fn test_no_spawn() { let client = HttpClient::new(reqwest::Client::new()); test_client(client).await; } }