shed/tokio_shim/src/lib.rs (321 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under both the MIT license found in the * LICENSE-MIT file in the root directory of this source tree and the Apache * License, Version 2.0 found in the LICENSE-APACHE file in the root directory * of this source tree. */ use futures::{future::Future, ready, stream::Stream}; use pin_project::pin_project; use std::any::Any; use std::pin::Pin; use std::sync::Once; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use thiserror::Error; pub mod task { use super::*; #[derive(Debug, Error)] pub enum JoinHandleError { #[error("Tokio 0.2 JoinError")] Tokio02(#[from] tokio_02::task::JoinError), #[error("Tokio 1.x JoinError")] Tokio1x(#[from] tokio_1x::task::JoinError), } impl JoinHandleError { // For now just implement the required apis // See https://docs.rs/tokio/1/tokio/task/struct.JoinError.html#method.into_panic pub fn into_panic(self) -> Box<dyn Any + Send + 'static> { match self { JoinHandleError::Tokio02(inner) => inner.into_panic(), JoinHandleError::Tokio1x(inner) => inner.into_panic(), } } pub fn is_panic(&self) -> bool { match self { JoinHandleError::Tokio02(inner) => inner.is_panic(), JoinHandleError::Tokio1x(inner) => inner.is_panic(), } } } #[pin_project(project = JoinHandleProj)] pub enum JoinHandle<T> { Tokio02(#[pin] tokio_02::task::JoinHandle<T>), Tokio1x(#[pin] tokio_1x::task::JoinHandle<T>), Fallback(Option<T>), } impl<T> Future for JoinHandle<T> where T: Send + 'static, { type Output = Result<T, JoinHandleError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let ret = match self.project() { JoinHandleProj::Tokio02(f) => ready!(f.poll(cx)).map_err(JoinHandleError::from), JoinHandleProj::Tokio1x(f) => ready!(f.poll(cx)).map_err(JoinHandleError::from), JoinHandleProj::Fallback(value) => return Poll::Ready(Ok(value.take().unwrap())), }; Poll::Ready(ret) } } pub fn spawn<F>(fut: F) -> JoinHandle<<F as Future>::Output> where F: Future + Send + 'static, F::Output: Send + 'static, { if let Ok(handle) = tokio_02::runtime::Handle::try_current() { return JoinHandle::Tokio02(handle.spawn(fut)); } if let Ok(handle) = tokio_1x::runtime::Handle::try_current() { return JoinHandle::Tokio1x(handle.spawn(fut)); } // This is what tokio::spawn would give you, so we don't try to do better here. panic!("A Tokio 0.2 or 1.x runtime is required, but neither was running"); } pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { if let Ok(handle) = tokio_02::runtime::Handle::try_current() { return JoinHandle::Tokio02(handle.spawn_blocking(f)); } if let Ok(handle) = tokio_1x::runtime::Handle::try_current() { return JoinHandle::Tokio1x(handle.spawn_blocking(f)); } // This is what tokio::spawn_blocking would give you, so we don't try to do better here. panic!("A Tokio 0.2 or 1.x runtime is required, but neither was running"); } /// Like `spawn_blocking`, but if there is no tokio runtime, just runs the code inline. /// This prints a warning, as this is NOT desireable and can cause performance problems pub fn spawn_blocking_fallback_inline<F, R>(f: F) -> JoinHandle<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { if let Ok(handle) = tokio_02::runtime::Handle::try_current() { return JoinHandle::Tokio02(handle.spawn_blocking(f)); } if let Ok(handle) = tokio_1x::runtime::Handle::try_current() { return JoinHandle::Tokio1x(handle.spawn_blocking(f)); } static WARN: Once = Once::new(); WARN.call_once(|| { use std::io::Write; let _ = writeln!( std::io::stderr(), "Falling back to running blocking code inline. Please use a tokio runtime instead!!" ); }); JoinHandle::Fallback(Some(f())) } } pub mod time { use super::*; #[pin_project(project = SleepProj)] pub enum Sleep { Tokio02(#[pin] tokio_02::time::Delay), Tokio1x(#[pin] tokio_1x::time::Sleep), } impl Future for Sleep { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { match self.project() { SleepProj::Tokio02(f) => f.poll(cx), SleepProj::Tokio1x(f) => f.poll(cx), } } } pub fn sleep(duration: Duration) -> Sleep { if tokio_02::runtime::Handle::try_current().is_ok() { return Sleep::Tokio02(tokio_02::time::delay_for(duration)); } if tokio_1x::runtime::Handle::try_current().is_ok() { return Sleep::Tokio1x(tokio_1x::time::sleep(duration)); } // This is what tokio::time::sleep would do. panic!("A Tokio 0.2 or 1.x runtime is required, but neither was running"); } pub fn sleep_until(instant: Instant) -> Sleep { if tokio_02::runtime::Handle::try_current().is_ok() { return Sleep::Tokio02(tokio_02::time::delay_until( tokio_02::time::Instant::from_std(instant), )); } if tokio_1x::runtime::Handle::try_current().is_ok() { return Sleep::Tokio1x(tokio_1x::time::sleep_until( tokio_1x::time::Instant::from_std(instant), )); } // This is what tokio::time::sleep would do. panic!("A Tokio 0.2 or 1.x runtime is required, but neither was running"); } #[derive(Debug, Error)] #[error("deadline has elapsed")] pub struct Elapsed; #[pin_project(project = TimeoutProj)] pub enum Timeout<F> { Tokio02(#[pin] tokio_02::time::Timeout<F>), Tokio1x(#[pin] tokio_1x::time::Timeout<F>), } impl<F> Future for Timeout<F> where F: Future, { type Output = Result<<F as Future>::Output, Elapsed>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let res = match self.project() { TimeoutProj::Tokio02(f) => { ready!(f.poll(cx)).map_err(|_: tokio_02::time::Elapsed| Elapsed) } TimeoutProj::Tokio1x(f) => { ready!(f.poll(cx)).map_err(|_: tokio_1x::time::error::Elapsed| Elapsed) } }; Poll::Ready(res) } } pub fn timeout<F: Future>(duration: Duration, fut: F) -> Timeout<F> { if tokio_02::runtime::Handle::try_current().is_ok() { return Timeout::Tokio02(tokio_02::time::timeout(duration, fut)); } if tokio_1x::runtime::Handle::try_current().is_ok() { return Timeout::Tokio1x(tokio_1x::time::timeout(duration, fut)); } // This is what tokio::time::timeout would do. panic!("A Tokio 0.2 or 1.x runtime is required, but neither was running"); } #[pin_project(project = IntervalStreamProj)] pub enum IntervalStream { Tokio02(#[pin] tokio_02::time::Interval), Tokio1x(#[pin] tokio_1x_stream::wrappers::IntervalStream), } impl Stream for IntervalStream { type Item = Instant; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { let ret = match self.project() { IntervalStreamProj::Tokio02(f) => ready!(f.poll_next(cx)).map(|i| i.into_std()), IntervalStreamProj::Tokio1x(f) => ready!(f.poll_next(cx)).map(|i| i.into_std()), }; Poll::Ready(ret) } } pub fn interval_stream(period: Duration) -> IntervalStream { if tokio_02::runtime::Handle::try_current().is_ok() { return IntervalStream::Tokio02(tokio_02::time::interval(period)); } if tokio_1x::runtime::Handle::try_current().is_ok() { let interval = tokio_1x::time::interval(period); let stream = tokio_1x_stream::wrappers::IntervalStream::new(interval); return IntervalStream::Tokio1x(stream); } // This is what tokio::time::interval_at would do. panic!("A Tokio 0.2 or 1.x runtime is required, but neither was running"); } } pub mod runtime { use super::*; use task::JoinHandle; #[derive(Debug, Clone)] pub enum Handle { Tokio02(tokio_02::runtime::Handle), Tokio1x(tokio_1x::runtime::Handle), } impl Handle { pub fn current() -> Self { if let Ok(hdl) = tokio_02::runtime::Handle::try_current() { return Self::Tokio02(hdl); } if let Ok(hdl) = tokio_1x::runtime::Handle::try_current() { return Self::Tokio1x(hdl); } panic!("A Tokio 0.2 or 1.x runtime is required, but neither was running"); } pub fn spawn<F>(&self, fut: F) -> JoinHandle<<F as Future>::Output> where F: Future + Send + 'static, F::Output: Send + 'static, { match self { Self::Tokio02(hdl) => JoinHandle::Tokio02(hdl.spawn(fut)), Self::Tokio1x(hdl) => JoinHandle::Tokio1x(hdl.spawn(fut)), } } } } #[cfg(test)] mod test { use super::*; use futures::{future, stream::StreamExt}; async fn test() { task::spawn(future::ready(())).await.unwrap(); task::spawn_blocking(|| ()).await.unwrap(); time::sleep(Duration::from_millis(1)).await; time::sleep_until(Instant::now() + Duration::from_millis(1)).await; time::interval_stream(Duration::from_millis(1)).next().await; assert!( time::timeout(Duration::from_millis(1), future::pending::<()>()) .await .is_err() ); runtime::Handle::current() .spawn(future::ready(())) .await .unwrap(); } #[test] fn test_02() -> Result<(), anyhow::Error> { let mut rt = tokio_02::runtime::Builder::new() .enable_all() .basic_scheduler() .build()?; rt.block_on(test()); Ok(()) } #[test] fn test_1x() -> Result<(), anyhow::Error> { let rt = tokio_1x::runtime::Builder::new_current_thread() .enable_all() .build()?; rt.block_on(test()); Ok(()) } #[test] #[should_panic] fn test_panic_forwarding_02() { let mut rt = tokio_02::runtime::Builder::new() .enable_all() .build() .unwrap(); rt.block_on(async { let je = task::spawn(async { panic!("gus"); }) .await .unwrap_err(); std::panic::resume_unwind(je.into_panic()) }); } #[test] #[should_panic] fn test_panic_forwarding_1x() { let rt = tokio_1x::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); rt.block_on(async { let je = task::spawn(async { panic!("gus"); }) .await .unwrap_err(); std::panic::resume_unwind(je.into_panic()) }); } #[test] fn test_fallback() { // No tokio running assert!( futures::executor::block_on(task::spawn_blocking_fallback_inline(|| true)).unwrap() ); // Second time still works, even though it doesn't write to stderr. assert!( futures::executor::block_on(task::spawn_blocking_fallback_inline(|| true)).unwrap() ); } }