diff --git a/src/lib.rs b/src/lib.rs index 0e58ad7..5e80076 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -190,7 +190,10 @@ impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> { /// enough to hold F and any required padding. fn as_mut_ptr(&mut self) -> *mut F { assert!(Self::has_space_for::()); - self.data.as_mut_ptr().cast() + // SAFETY: Self is laid out so that the space for the future comes at offset 0. + // This is checked by an assertion in Self::from. Thus it's safe to cast a pointer + // to Self into a pointer to the wrapped future. + unsafe { mem::transmute(self) } } /// Returns a pinned mutable reference to a type F stored in self.data @@ -226,8 +229,17 @@ impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> { impl<'a, T, const STACK_SIZE: usize> Future for StackFuture<'a, T, { STACK_SIZE }> { type Output = T; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - (self.as_mut().poll_fn)(self, cx) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: This is doing pin projection. We unpin self so we can + // access self.poll_fn, and then re-pin self to pass it into poll_in. + // The part of the struct that needs to be pinned is data, since it + // contains a potentially self-referential future object, but since we + // do not touch that while self is unpinned and we do not move self + // while unpinned we are okay. + unsafe { + let this = self.get_unchecked_mut(); + (this.poll_fn)(Pin::new_unchecked(this), cx) + } } } @@ -241,12 +253,17 @@ impl<'a, T, const STACK_SIZE: usize> Drop for StackFuture<'a, T, { STACK_SIZE }> mod tests { use crate::StackFuture; use core::task::Poll; + use futures::channel::mpsc; use futures::executor::block_on; use futures::pin_mut; use futures::Future; + use futures::SinkExt; + use futures::Stream; + use futures::StreamExt; use std::sync::Arc; use std::task::Context; use std::task::Wake; + use std::thread; #[test] fn create_and_run() { @@ -340,4 +357,32 @@ mod tests { fn is_aligned(ptr: *mut T, alignment: usize) -> bool { (ptr as usize) & (alignment - 1) == 0 } + + #[test] + fn stress_drop_sender() { + // Regression test for #9 + + const ITER: usize = if cfg!(miri) { 10 } else { 10000 }; + + fn list() -> impl Stream { + let (tx, rx) = mpsc::channel(1); + thread::spawn(move || { + block_on(send_one_two_three(tx)); + }); + rx + } + + for _ in 0..ITER { + let v: Vec<_> = block_on(list().collect()); + assert_eq!(v, vec![1, 2, 3]); + } + } + + fn send_one_two_three(mut tx: mpsc::Sender) -> StackFuture<'static, (), 512> { + StackFuture::from(async move { + for i in 1..=3 { + tx.send(i).await.unwrap(); + } + }) + } }