Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> {
/// enough to hold F and any required padding.
fn as_mut_ptr<F>(&mut self) -> *mut F {
assert!(Self::has_space_for::<F>());
self.data.as_mut_ptr().cast()
// SAFETY: Self is laid out so that the space for the future comes at offset 0.
// This is checked by an assertion in Self::from. Thus it's safe to cast a pointer
// to Self into a pointer to the wrapped future.
unsafe { mem::transmute(self) }
}

/// Returns a pinned mutable reference to a type F stored in self.data
Expand Down Expand Up @@ -226,8 +229,17 @@ impl<'a, T, const STACK_SIZE: usize> StackFuture<'a, T, { STACK_SIZE }> {
impl<'a, T, const STACK_SIZE: usize> Future for StackFuture<'a, T, { STACK_SIZE }> {
type Output = T;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
(self.as_mut().poll_fn)(self, cx)
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// SAFETY: This is doing pin projection. We unpin self so we can
// access self.poll_fn, and then re-pin self to pass it into poll_in.
// The part of the struct that needs to be pinned is data, since it
// contains a potentially self-referential future object, but since we
// do not touch that while self is unpinned and we do not move self
// while unpinned we are okay.
unsafe {
let this = self.get_unchecked_mut();
(this.poll_fn)(Pin::new_unchecked(this), cx)
}
}
}

Expand All @@ -241,12 +253,17 @@ impl<'a, T, const STACK_SIZE: usize> Drop for StackFuture<'a, T, { STACK_SIZE }>
mod tests {
use crate::StackFuture;
use core::task::Poll;
use futures::channel::mpsc;
use futures::executor::block_on;
use futures::pin_mut;
use futures::Future;
use futures::SinkExt;
use futures::Stream;
use futures::StreamExt;
use std::sync::Arc;
use std::task::Context;
use std::task::Wake;
use std::thread;

#[test]
fn create_and_run() {
Expand Down Expand Up @@ -340,4 +357,32 @@ mod tests {
fn is_aligned<T>(ptr: *mut T, alignment: usize) -> bool {
(ptr as usize) & (alignment - 1) == 0
}

#[test]
fn stress_drop_sender() {
// Regression test for #9

const ITER: usize = if cfg!(miri) { 10 } else { 10000 };

fn list() -> impl Stream<Item = i32> {
let (tx, rx) = mpsc::channel(1);
thread::spawn(move || {
block_on(send_one_two_three(tx));
});
rx
}

for _ in 0..ITER {
let v: Vec<_> = block_on(list().collect());
assert_eq!(v, vec![1, 2, 3]);
}
}

fn send_one_two_three(mut tx: mpsc::Sender<i32>) -> StackFuture<'static, (), 512> {
StackFuture::from(async move {
for i in 1..=3 {
tx.send(i).await.unwrap();
}
})
}
}