diff --git a/tokio-stream/src/lib.rs b/tokio-stream/src/lib.rs index a5b3a05e03a..c47e4deee5f 100644 --- a/tokio-stream/src/lib.rs +++ b/tokio-stream/src/lib.rs @@ -80,8 +80,8 @@ pub use stream_ext::{collect::FromStream, StreamExt}; /// Adapters for [`Stream`]s created by methods in [`StreamExt`]. pub mod adapters { pub use crate::stream_ext::{ - Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, SkipWhile, Take, - TakeWhile, Then, + Chain, Filter, FilterMap, FilterMapAsync, Fuse, Map, MapWhile, Merge, Peekable, Skip, + SkipWhile, Take, TakeWhile, Then, }; cfg_time! { pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating}; diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index fe589869215..9ba934a9e68 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -16,6 +16,9 @@ use collect::{Collect, FromStream}; mod filter; pub use filter::Filter; +mod filter_map_async; +pub use filter_map_async::FilterMapAsync; + mod filter_map; pub use filter_map::FilterMap; @@ -479,6 +482,56 @@ pub trait StreamExt: Stream { FilterMap::new(self, f) } + /// Filters the values produced by this stream asynchronously while + /// simultaneously mapping them to a different type according to the + /// provided async closure. + /// + /// The provided closure is executed over all elements of this stream as + /// they are made available, and the returned future is executed. Only one + /// future is executed at the time. If the returned future resolves to + /// [`Some(item)`](Some) then the stream will yield the value `item`, but if + /// it resolves to [`None`], then the value will be skipped. + /// + /// Note that this function consumes the stream passed into it and returns a + /// wrapped version of it, similar to [`Iterator::filter_map`] method in the + /// standard library. + /// + /// Be aware that if the future is not `Unpin`, then neither is the `Stream` + /// returned by this method. To handle this, you can use `tokio::pin!` as in + /// the example below or put the stream in a `Box` with `Box::pin(stream)`. + /// + /// # Examples + /// ``` + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// use std::time::Duration; + /// use tokio::time; + /// use tokio_stream::{self as stream, StreamExt}; + /// + /// let stream = stream::iter(0..=7); + /// let odds = stream.filter_map_async(async |x| { + /// time::sleep(Duration::from_millis(100)).await; + /// if x % 2 == 0 { Some(x + 1) } else { None } + /// }); + /// + /// tokio::pin!(odds); + /// + /// assert_eq!(Some(1), odds.next().await); + /// assert_eq!(Some(3), odds.next().await); + /// assert_eq!(Some(5), odds.next().await); + /// assert_eq!(Some(7), odds.next().await); + /// assert_eq!(None, odds.next().await); + /// # } + /// ``` + fn filter_map_async(self, f: F) -> FilterMapAsync + where + F: FnMut(Self::Item) -> Fut, + Fut: Future>, + Self: Sized, + { + FilterMapAsync::new(self, f) + } + /// Creates a stream which ends after the first `None`. /// /// After a stream returns `None`, behavior is undefined. Future calls to diff --git a/tokio-stream/src/stream_ext/filter_map_async.rs b/tokio-stream/src/stream_ext/filter_map_async.rs new file mode 100644 index 00000000000..d40dfd5f9dd --- /dev/null +++ b/tokio-stream/src/stream_ext/filter_map_async.rs @@ -0,0 +1,80 @@ +use crate::Stream; + +use core::fmt; +use core::future::Future; +use core::pin::Pin; +use core::task::{ready, Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream for the [`filter_map_async`](super::StreamExt::filter_map_async) method. + #[must_use = "streams do nothing unless polled"] + pub struct FilterMapAsync { + #[pin] + stream: St, + #[pin] + future: Option, + f: F, + } +} + +impl fmt::Debug for FilterMapAsync +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FilterMapAsync") + .field("stream", &self.stream) + .finish() + } +} + +impl FilterMapAsync { + pub(super) fn new(stream: St, f: F) -> Self { + FilterMapAsync { + stream, + future: None, + f, + } + } +} + +impl Stream for FilterMapAsync +where + St: Stream, + Fut: Future>, + F: FnMut(St::Item) -> Fut, +{ + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut me = self.project(); + + loop { + if let Some(future) = me.future.as_mut().as_pin_mut() { + let item = ready!(future.poll(cx)); + me.future.set(None); + if let Some(item) = item { + return Poll::Ready(Some(item)); + } + } + + match ready!(me.stream.as_mut().poll_next(cx)) { + Some(item) => { + me.future.set(Some((me.f)(item))); + } + None => return Poll::Ready(None), + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let future_len = usize::from(self.future.is_some()); + let upper = self + .stream + .size_hint() + .1 + .and_then(|upper| upper.checked_add(future_len)); + (0, upper) + } +} diff --git a/tokio-stream/tests/stream_filter_map_async.rs b/tokio-stream/tests/stream_filter_map_async.rs new file mode 100644 index 00000000000..e2718361855 --- /dev/null +++ b/tokio-stream/tests/stream_filter_map_async.rs @@ -0,0 +1,117 @@ +use futures::Stream; +use tokio::sync::Notify; +use tokio_stream::{self as stream, StreamExt}; +use tokio_test::{assert_pending, assert_ready_eq, task}; + +mod support { + pub(crate) mod mpsc; +} + +use support::mpsc; + +#[tokio::test] +async fn basic() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + + let mut st = + task::spawn(rx.filter_map_async(async |x| if x % 2 == 0 { Some(x + 1) } else { None })); + assert_pending!(st.poll_next()); + + tx.send(1).unwrap(); + assert!(st.is_woken()); + assert_pending!(st.poll_next()); + + tx.send(2).unwrap(); + assert!(st.is_woken()); + assert_ready_eq!(st.poll_next(), Some(3)); + + assert_pending!(st.poll_next()); + + tx.send(3).unwrap(); + assert!(st.is_woken()); + assert_pending!(st.poll_next()); + + drop(tx); + assert!(st.is_woken()); + assert_ready_eq!(st.poll_next(), None); +} + +#[tokio::test] +async fn notify_unbounded() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + let notify = Notify::new(); + + let mut st = task::spawn(rx.filter_map_async(async |x| { + notify.notified().await; + if x % 2 == 0 { + Some(x + 1) + } else { + None + } + })); + assert_pending!(st.poll_next()); + + tx.send(0).unwrap(); + assert!(st.is_woken()); + assert_pending!(st.poll_next()); + + notify.notify_one(); + assert!(st.is_woken()); + assert_ready_eq!(st.poll_next(), Some(1)); + + tx.send(1).unwrap(); + assert!(!st.is_woken()); + assert_pending!(st.poll_next()); + + notify.notify_one(); + assert!(st.is_woken()); + assert_pending!(st.poll_next()); + + tx.send(2).unwrap(); + assert!(st.is_woken()); + assert_pending!(st.poll_next()); + + notify.notify_one(); + assert!(st.is_woken()); + assert_ready_eq!(st.poll_next(), Some(3)); + + drop(tx); + assert!(!st.is_woken()); + assert_ready_eq!(st.poll_next(), None); +} + +#[tokio::test] +async fn notify_bounded() { + let notify = Notify::new(); + let mut st = task::spawn(stream::iter(0..3).filter_map_async(async |x| { + notify.notified().await; + if x % 2 == 0 { + Some(x + 1) + } else { + None + } + })); + assert_eq!(st.size_hint(), (0, Some(3))); + assert_pending!(st.poll_next()); + + notify.notify_one(); + assert!(st.is_woken()); + assert_eq!(st.size_hint(), (0, Some(3))); + assert_ready_eq!(st.poll_next(), Some(1)); + assert_eq!(st.size_hint(), (0, Some(2))); + + notify.notify_one(); + assert!(!st.is_woken()); + assert_eq!(st.size_hint(), (0, Some(2))); + assert_pending!(st.poll_next()); + assert_eq!(st.size_hint(), (0, Some(1))); + + notify.notify_one(); + assert!(st.is_woken()); + assert_eq!(st.size_hint(), (0, Some(1))); + assert_ready_eq!(st.poll_next(), Some(3)); + assert_eq!(st.size_hint(), (0, Some(0))); + + assert!(!st.is_woken()); + assert_ready_eq!(st.poll_next(), None); +}