shed/futures_01_ext/src/streamfork.rs (260 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under both the MIT license found in the * LICENSE-MIT file in the root directory of this source tree and the Apache * License, Version 2.0 found in the LICENSE-APACHE file in the root directory * of this source tree. */ use futures::sink::Sink; use futures::stream::{Fuse, Stream}; use futures::{try_ready, Async, AsyncSink, Future, Poll}; /// Fork a Stream into two /// /// Returns a Future for a process that consumes items from a Stream and /// forwards them to two sinks depending on a predicate. If the predicate /// returns false, send the value to out1, otherwise out2. pub fn streamfork<In, Out1, Out2, F, E>( inp: In, out1: Out1, out2: Out2, pred: F, ) -> Forker<In, Out1, Out2, F, E> where In: Stream, Out1: Sink<SinkItem = In::Item>, Out2: Sink<SinkItem = In::Item, SinkError = Out1::SinkError>, F: FnMut(&In::Item) -> Result<bool, E>, E: From<In::Error> + From<Out1::SinkError> + From<Out2::SinkError>, { Forker { inp: Some(inp.fuse()), out1: Out::new(out1), out2: Out::new(out2), pred, finished: None, } } #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Forker<In, Out1, Out2, F, E> where In: Stream, Out1: Sink, Out2: Sink, { inp: Option<Fuse<In>>, out1: Out<Out1>, out2: Out<Out2>, pred: F, finished: Option<Result<(), E>>, } struct Out<O> where O: Sink, { out: Option<O>, buf: Option<O::SinkItem>, } impl<S: Sink> Out<S> { fn new(s: S) -> Self { Out { out: Some(s), buf: None, } } fn out_mut(&mut self) -> &mut S { self.out.as_mut().take().expect("Out after completion") } fn take_result(&mut self) -> S { self.out.take().expect("Out missing") } fn try_start_send(&mut self, item: S::SinkItem) -> Poll<(), S::SinkError> { debug_assert!(self.buf.is_none()); if let AsyncSink::NotReady(item) = self.out_mut().start_send(item)? { self.buf = Some(item); return Ok(Async::NotReady); } Ok(Async::Ready(())) } fn push(&mut self) -> Poll<(), S::SinkError> { if let Some(item) = self.buf.take() { self.try_start_send(item) } else { Ok(Async::Ready(())) } } fn poll_complete(&mut self) -> Poll<(), S::SinkError> { self.out_mut().poll_complete() } } impl<In, Out1, Out2, F, E> Forker<In, Out1, Out2, F, E> where In: Stream, Out1: Sink, Out2: Sink, E: From<In::Error> + From<Out1::SinkError> + From<Out2::SinkError>, { fn inp_mut(&mut self) -> &mut Fuse<In> { self.inp.as_mut().take().expect("Input after completion") } fn take_result(&mut self) -> (In, Out1, Out2) { let inp = self.inp.take().expect("Input missing in result"); let out1 = self.out1.take_result(); let out2 = self.out2.take_result(); (inp.into_inner(), out1, out2) } fn poll_complete_both(&mut self) -> Poll<(), E> { let r1 = self.out1.poll_complete()?.is_ready(); let r2 = self.out2.poll_complete()?.is_ready(); if !(r1 && r2) { return Ok(Async::NotReady); } Ok(Async::Ready(())) } #[cfg(test)] pub(crate) fn out1(&mut self) -> &Out1 { self.out1.out_mut() } #[cfg(test)] pub(crate) fn out2(&mut self) -> &Out2 { self.out2.out_mut() } } impl<In, Out1, Out2, F, E> Future for Forker<In, Out1, Out2, F, E> where In: Stream, Out1: Sink<SinkItem = In::Item>, Out2: Sink<SinkItem = In::Item>, F: FnMut(&In::Item) -> Result<bool, E>, E: From<In::Error> + From<Out1::SinkError> + From<Out2::SinkError>, { type Item = (In, Out1, Out2); type Error = E; fn poll(&mut self) -> Poll<Self::Item, Self::Error> { if self.finished.is_some() { // Polling input stream ended, possibly with an error. // Let's make sure we send all already fetched data to the outputs try_ready!(self.poll_complete_both()); let finished_res = self.finished.take().expect("is_some() returned false"); return finished_res.map(|()| Async::Ready(self.take_result())); } // Make sure both outputs are clear to accept new data { let r1 = self.out1.push()?.is_ready(); let r2 = self.out2.push()?.is_ready(); if !(r1 && r2) { return Ok(Async::NotReady); } } // Read input and send to outputs until either input dries up or outputs are full loop { match self.inp_mut().poll() { Ok(Async::Ready(Some(item))) => { if (self.pred)(&item)? { try_ready!(self.out2.try_start_send(item)) } else { try_ready!(self.out1.try_start_send(item)) } } Ok(Async::NotReady) => { self.out1.poll_complete()?; self.out2.poll_complete()?; return Ok(Async::NotReady); } Ok(Async::Ready(None)) => { if !self.poll_complete_both()?.is_ready() { self.finished = Some(Ok(())); return Ok(Async::NotReady); } return Ok(Async::Ready(self.take_result())); } Err(err) => { if !self.poll_complete_both()?.is_ready() { self.finished = Some(Err(err.into())); return Ok(Async::NotReady); } return Err(err.into()); } } } } } #[cfg(test)] mod test { use super::*; use futures::stream::{iter_ok, once}; use futures::{sink::Sink, Async, AsyncSink, Future, StartSend}; #[test] fn simple() { let even = Vec::new(); let odd = Vec::new(); let nums = iter_ok(0i32..10); let (_, even, odd) = streamfork(nums, even, odd, |n| Ok::<_, ()>(*n % 2 == 1)) .wait() .unwrap(); assert_eq!(even, vec![0, 2, 4, 6, 8]); assert_eq!(odd, vec![1, 3, 5, 7, 9]); } struct DelayedSink { inner: Vec<u32>, buffer: Vec<u32>, poll_complete_left: u32, } impl DelayedSink { fn new(poll_complete_left: u32) -> Self { Self { inner: vec![], buffer: vec![], poll_complete_left, } } } impl Sink for DelayedSink { type SinkItem = u32; type SinkError = (); fn start_send( &mut self, item: Self::SinkItem, ) -> StartSend<Self::SinkItem, Self::SinkError> { self.buffer.push(item); Ok(AsyncSink::Ready) } fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { if self.buffer.is_empty() { return Ok(Async::Ready(())); } if self.poll_complete_left == 0 { Ok(Async::Ready(())) } else { self.poll_complete_left -= 1; let val = self.buffer.remove(0); self.inner.push(val); Ok(Async::NotReady) } } } #[test] fn delayed_poll() { let even = DelayedSink::new(5); let odd = DelayedSink::new(5); let nums = iter_ok(0u32..2); let mut fork = streamfork(nums, even, odd, |n| Ok::<_, ()>(*n % 2 == 1)); loop { let res = fork.poll().expect("no error expected"); if let Async::Ready((_, even, odd)) = res { assert_eq!(even.inner, vec![0]); assert_eq!(odd.inner, vec![1]); break; } } } #[test] fn delayed_poll_with_err() { let even = DelayedSink::new(5); let odd = DelayedSink::new(5); let nums = iter_ok(0u32..2).chain(once(Err(()))); let mut fork = streamfork(nums, even, odd, |n| Ok::<_, ()>(*n % 2 == 1)); loop { let res = fork.poll(); if res.is_err() { assert_eq!(fork.out1().inner, vec![0]); assert_eq!(fork.out2().inner, vec![1]); break; } if res.unwrap().is_ready() { panic!("expected an error"); } } } }