
I'm trying to implement an async read wrapper that adds read timeout functionality. The objective is that the API stays plain AsyncRead. In other words, I don't want to add io.read(buf).timeout(t) everywhere in the code. Instead, the read instance itself should return io::ErrorKind::TimedOut once the given timeout expires.

I can't get the delay to poll as Ready, though: it's always Pending. I've tried async-std, futures and smol-timeout, with the same result. The timeout does trigger when awaited, but not when polled. I know timeouts aren't easy; something needs to wake it up. What am I doing wrong, and how can I make this work?

use async_std::{
    future::Future,
    io,
    pin::Pin,
    task::{sleep, Context, Poll},
};
use std::time::Duration;

pub struct PrudentIo<IO> {
    expired: Option<Pin<Box<dyn Future<Output = ()> + Sync + Send>>>,
    timeout: Duration,
    io: IO,
}

impl<IO> PrudentIo<IO> {
    pub fn new(timeout: Duration, io: IO) -> Self {
        PrudentIo {
            expired: None,
            timeout,
            io,
        }
    }
}

fn delay(t: Duration) -> Option<Pin<Box<dyn Future<Output = ()> + Sync + Send + 'static>>> {
    if t.is_zero() {
        return None;
    }
    Some(Box::pin(sleep(t)))
}

impl<IO: io::Read + Unpin> io::Read for PrudentIo<IO> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        if let Some(ref mut expired) = self.expired {
            match expired.as_mut().poll(cx) {
                Poll::Ready(_) => {
                    println!("expired ready");
                    // too much time passed since last read/write
                    return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
                }
                Poll::Pending => {
                    println!("expired pending");
                    // in good time
                }
            }
        }

        let res = Pin::new(&mut self.io).poll_read(cx, buf);
        println!("read {:?}", res);

        match res {
            Poll::Pending => {
                if self.expired.is_none() {
                    // No data, start checking for a timeout
                    self.expired = delay(self.timeout);
                }
            }
            Poll::Ready(_) => self.expired = None,
        }

        res
    }
}
impl<IO: io::Write + Unpin> io::Write for PrudentIo<IO> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.io).poll_write(cx, buf)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.io).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.io).poll_close(cx)
    }
}

#[cfg(test)]
mod io_tests {
    use super::*;
    use async_std::io::ReadExt;
    use async_std::prelude::FutureExt;
    use async_std::{
        io::{copy, Cursor},
        net::TcpStream,
    };
    use std::time::Duration;

    #[async_std::test]
    async fn fail_read_after_timeout() -> io::Result<()> {
        let mut output = b"______".to_vec();
        let io = PendIo;
        let mut io = PrudentIo::new(Duration::from_millis(5), io);
        let mut io = Pin::new(&mut io);
        insta::assert_debug_snapshot!(io.read(&mut output[..]).timeout(Duration::from_secs(1)).await,@"Ok(io::Err(timeou))");
        Ok(())
    }
    #[async_std::test]
    async fn timeout_expires() {
        let later = delay(Duration::from_millis(1)).expect("some").await;
        insta::assert_debug_snapshot!(later,@r"()");
    }
    /// Mock IO always pending
    struct PendIo;
    impl io::Read for PendIo {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<futures_io::Result<usize>> {
            Poll::Pending
        }
    }
    impl io::Write for PendIo {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<futures_io::Result<usize>> {
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
            Poll::Pending
        }
    }
}

Robert Cutajar

3 Answers


Async timeouts work as follows:

  1. You create the timeout future.
  2. The runtime calls poll into the timeout, it checks whether the timeout has expired.
  3. If it is expired, it returns Ready and done.
  4. If it is not expired, it registers a callback (somehow) so that, when the right time has passed, it calls cx.waker().wake() or similar.
  5. When the time has passed, the callback from #4 is invoked and calls wake() on the proper waker, which instructs the runtime to call poll again.
  6. This time poll will return Ready. Done!
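The six steps above can be sketched with only the standard library. This is a hedged illustration, not how async-std or smol actually implement their timers: `Sleep`, `sleep` and `block_on` are hypothetical names, and a helper thread stands in for the runtime's timer machinery.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};
use std::thread;
use std::time::{Duration, Instant};

/// Hypothetical timer future; a helper thread plays the runtime's timer.
pub struct Sleep {
    deadline: Instant,
    // Filled in by `poll` (step 4), emptied by the helper thread (step 5).
    slot: Arc<Mutex<Option<Waker>>>,
}

/// Step 1: create the timeout future.
pub fn sleep(t: Duration) -> Sleep {
    let deadline = Instant::now() + t;
    let slot = Arc::new(Mutex::new(None::<Waker>));
    let shared = Arc::clone(&slot);
    thread::spawn(move || {
        while Instant::now() < deadline {
            thread::sleep(deadline.saturating_duration_since(Instant::now()));
        }
        // Step 5: call the registered waker. If nobody polled, nobody is woken.
        if let Some(waker) = shared.lock().unwrap().take() {
            waker.wake();
        }
    });
    Sleep { deadline, slot }
}

impl Future for Sleep {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Step 4: register the waker before checking, so a wake-up cannot be missed.
        *self.slot.lock().unwrap() = Some(cx.waker().clone());
        // Steps 2, 3 and 6: Ready once the deadline has passed, Pending before that.
        if Instant::now() >= self.deadline {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

/// Hypothetical minimal "runtime": wake() unparks the thread blocked in block_on.
struct Unparker(thread::Thread);

impl Wake for Unparker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

pub fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = std::pin::pin!(fut);
    let waker = Waker::from(Arc::new(Unparker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            // Nothing to do until some waker unparks us; then poll again.
            Poll::Pending => thread::park(),
        }
    }
}
```

The helper thread is the "somehow" of step 4. A real runtime keeps a timer queue instead of spawning one thread per timer.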

The problem with your code is that you create the delay inside the poll() implementation (self.expired = delay(self.timeout);) but then return Pending without polling the timeout even once. That way, no callback is registered anywhere that would call the Waker. No waker, no timeout.
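The failure mode can be reproduced without any clock. In this sketch, `ManualTimer` is a hypothetical stand-in for a runtime timer that the caller fires by hand, and it only learns about a waker when it is polled. `CountingWaker` (also hypothetical) counts wake-ups. Firing a timer that was created but never polled wakes nobody, which is what happens to the freshly created `expired` delay in the question.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

/// State shared between the timer future and whoever fires it.
#[derive(Default)]
struct TimerState {
    fired: bool,
    waker: Option<Waker>,
}

/// Hypothetical timer that is fired by hand; it only learns about a waker when polled.
#[derive(Clone, Default)]
pub struct ManualTimer(Arc<Mutex<TimerState>>);

impl ManualTimer {
    /// Simulates the deadline passing: mark as fired and wake the registered waker, if any.
    pub fn fire(&self) {
        let mut state = self.0.lock().unwrap();
        state.fired = true;
        if let Some(waker) = state.waker.take() {
            waker.wake();
        }
    }
}

impl Future for ManualTimer {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut state = self.0.lock().unwrap();
        if state.fired {
            return Poll::Ready(());
        }
        // Registering the waker is the step the question's code never reaches.
        state.waker = Some(cx.waker().clone());
        Poll::Pending
    }
}

/// Waker that only counts how often it is woken.
pub struct CountingWaker(pub AtomicUsize);

impl Wake for CountingWaker {
    fn wake(self: Arc<Self>) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
}

/// Creates a timer, optionally polls it once, fires it, and returns the number of wake-ups.
pub fn wakes_after_fire(poll_once_first: bool) -> usize {
    let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
    let waker = Waker::from(Arc::clone(&counter));
    let mut cx = Context::from_waker(&waker);
    let mut timer = ManualTimer::default();
    if poll_once_first {
        assert!(Pin::new(&mut timer).poll(&mut cx).is_pending());
    }
    timer.fire();
    counter.0.load(Ordering::SeqCst)
}
```

`wakes_after_fire(false)` yields 0 and `wakes_after_fire(true)` yields 1. A single poll before the deadline is the difference.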

I see several solutions:

A. Do not initialize PrudentIo::expired to None; create the timeout directly in the constructor instead. That way the timeout will always be polled at least once before the io, so it registers the waker. But you will always create a timeout, even when it is not actually needed.

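B. When creating the timeout, do a recursive poll:

Poll::Pending => {
    if self.expired.is_none() {
        // No data, start checking for a timeout
        self.expired = delay(self.timeout);
        // Re-enter poll_read so the new delay is polled and registers the waker.
        // Guarded: with a zero timeout delay() returns None, and this would recurse forever.
        if self.expired.is_some() {
            return self.poll_read(cx, buf);
        }
    }
}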

This polls the io twice, which is unnecessary, so it may not be optimal.

C. Add a call to poll after creating the timeout:

Poll::Pending => {
    if self.expired.is_none() {
        // No data, start checking for a timeout
        self.expired = delay(self.timeout);
        // Poll the new delay once so it registers the waker
        if let Some(expired) = self.expired.as_mut() {
            let _ = expired.as_mut().poll(cx);
        }
    }
}

Maybe you should match the output of poll in case it returns Ready, but hey, it's a new timeout, so it's probably still pending, and it seems to work nicely.
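Option C can be checked end to end without a runtime. Everything here is a hypothetical stand-in. `PollRead` mimics the `poll_read` signature of `futures_io::AsyncRead`. `Sleep`/`sleep` is a thread-backed replacement for async-std's `sleep`. `PendIo` mirrors the question's mock. `expired` holds the `Unpin` `Sleep` directly instead of a boxed future. The read logic follows the question's code plus one poll of the new delay, and the result of that poll is matched as suggested above.

```rust
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};
use std::thread;
use std::time::{Duration, Instant};

/// Hypothetical stand-in for `futures_io::AsyncRead` (same `poll_read` shape).
pub trait PollRead {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>>;
}

/// Hypothetical thread-backed `sleep`: the thread only wakes a waker stored by `poll`.
pub struct Sleep {
    deadline: Instant,
    slot: Arc<Mutex<Option<Waker>>>,
}

pub fn sleep(t: Duration) -> Sleep {
    let deadline = Instant::now() + t;
    let slot = Arc::new(Mutex::new(None::<Waker>));
    let shared = Arc::clone(&slot);
    thread::spawn(move || {
        while Instant::now() < deadline {
            thread::sleep(deadline.saturating_duration_since(Instant::now()));
        }
        if let Some(waker) = shared.lock().unwrap().take() {
            waker.wake();
        }
    });
    Sleep { deadline, slot }
}

impl Future for Sleep {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        *self.slot.lock().unwrap() = Some(cx.waker().clone());
        if Instant::now() >= self.deadline { Poll::Ready(()) } else { Poll::Pending }
    }
}

pub struct PrudentIo<IO> {
    expired: Option<Sleep>,
    timeout: Duration,
    io: IO,
}

impl<IO: PollRead + Unpin> PollRead for PrudentIo<IO> {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        if let Some(expired) = this.expired.as_mut() {
            if Pin::new(expired).poll(cx).is_ready() {
                this.expired = None;
                return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
            }
        }
        let res = Pin::new(&mut this.io).poll_read(cx, buf);
        match res {
            Poll::Pending if this.expired.is_none() => {
                // Option C: poll the brand-new delay once so it stores cx's waker.
                let mut delay = sleep(this.timeout);
                if Pin::new(&mut delay).poll(cx).is_ready() {
                    return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
                }
                this.expired = Some(delay);
            }
            Poll::Pending => {}
            Poll::Ready(_) => this.expired = None,
        }
        res
    }
}

/// Mirrors the question's `PendIo` mock: never has any data.
pub struct PendIo;
impl PollRead for PendIo {
    fn poll_read(self: Pin<&mut Self>, _: &mut Context<'_>, _: &mut [u8]) -> Poll<io::Result<usize>> {
        Poll::Pending
    }
}

/// Hypothetical minimal executor: parks until some waker unparks this thread.
struct Unparker(thread::Thread);
impl Wake for Unparker {
    fn wake(self: Arc<Self>) { self.0.unpark() }
}

pub fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = std::pin::pin!(fut);
    let waker = Waker::from(Arc::new(Unparker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}
```

If the extra poll is removed, `block_on` in this sketch parks forever. The helper thread fires, finds no waker, and nothing ever polls the wrapper again.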

rodrigo
  • Nice catch. I thought it would be a waker problem, but couldn't work it out. Since the timeout is never polled, it will not register... I've fixed C up a bit to be on the safe side: `/* No data, start checking for a timeout Poll once to register waker! */ self.expired = delay(self.timeout).and_then(|mut fut| match fut.as_mut().poll(cx) { Poll::Ready(_) => None, Poll::Pending => Some(fut), });` Many thanks! – Robert Cutajar Feb 08 '22 at 20:05
// This is another solution. I think it is better.

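use smol::{future::FutureExt, io, ready, Timer};
use std::{
    pin::Pin,
    task::{Context, Poll},
};

// Assumes PrudentIo is declared with `expired: Option<Timer>`
// (not the boxed future from the question).
impl<IO: io::AsyncRead + Unpin> io::AsyncRead for PrudentIo<IO> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();

        // Try the underlying io first.
        let io = Pin::new(&mut this.io);
        if let Poll::Ready(res) = io.poll_read(cx, buf) {
            return Poll::Ready(res);
        }

        loop {
            if let Some(expired) = this.expired.as_mut() {
                // FutureExt::poll works on &mut Timer because Timer is Unpin.
                ready!(expired.poll(cx));
                this.expired.take();
                return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
            }

            // Create the timer, then loop so it is polled (and registers the waker) right away.
            let timeout = Timer::after(this.timeout);
            this.expired = Some(timeout);
        }
    }
}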
moto
  • Hi @moto, I like the trick with the loop that polls the timer once it is created (the solution to my problem), and your code seems simpler/shorter. However, it changes the behavior slightly, since the timeout is only enforced after the read: if a read succeeds right at the timeout, it will not be a timeout. This could be very rare and perhaps even desirable. It would be great if you could describe why your solution is better rather than just posting code. – Robert Cutajar Mar 07 '22 at 07:44
  • @moto, don't you need to reset `this.expired` right before `return Poll::Ready(res);`? – C.M. Sep 30 '22 at 22:37
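C.M.'s point can be checked with a std-only sketch of moto's loop structure. The names are hypothetical stand-ins: `PollRead` for `futures_io::AsyncRead`, a thread-backed `Timer::after` for smol's `Timer::after`, and `FlagIo` as a controllable mock. Relative to moto's answer, the only change to the control flow is resetting `expired` before returning a successful read.

```rust
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{ready, Context, Poll, Wake, Waker};
use std::thread;
use std::time::{Duration, Instant};

/// Hypothetical stand-in for `futures_io::AsyncRead` (same `poll_read` shape).
pub trait PollRead {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>>;
}

/// Hypothetical thread-backed timer; it only wakes a waker stored by `poll`.
pub struct Timer {
    deadline: Instant,
    slot: Arc<Mutex<Option<Waker>>>,
}

impl Timer {
    pub fn after(t: Duration) -> Timer {
        let deadline = Instant::now() + t;
        let slot = Arc::new(Mutex::new(None::<Waker>));
        let shared = Arc::clone(&slot);
        thread::spawn(move || {
            while Instant::now() < deadline {
                thread::sleep(deadline.saturating_duration_since(Instant::now()));
            }
            if let Some(waker) = shared.lock().unwrap().take() {
                waker.wake();
            }
        });
        Timer { deadline, slot }
    }
}

impl Future for Timer {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        *self.slot.lock().unwrap() = Some(cx.waker().clone());
        if Instant::now() >= self.deadline { Poll::Ready(()) } else { Poll::Pending }
    }
}

pub struct PrudentIo<IO> {
    expired: Option<Timer>,
    timeout: Duration,
    io: IO,
}

impl<IO: PollRead + Unpin> PollRead for PrudentIo<IO> {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        if let Poll::Ready(res) = Pin::new(&mut this.io).poll_read(cx, buf) {
            // Reset on success; otherwise the old timer could fail a later read.
            this.expired = None;
            return Poll::Ready(res);
        }
        loop {
            if let Some(expired) = this.expired.as_mut() {
                ready!(Pin::new(expired).poll(cx));
                this.expired = None;
                return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
            }
            // No timer yet: create one and loop so it is polled (registering the waker) at once.
            this.expired = Some(Timer::after(this.timeout));
        }
    }
}

/// Mock IO that pretends to fill the buffer while the flag is set, else stays pending.
pub struct FlagIo(pub Arc<AtomicBool>);
impl PollRead for FlagIo {
    fn poll_read(self: Pin<&mut Self>, _: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        if self.0.load(Ordering::SeqCst) { Poll::Ready(Ok(buf.len())) } else { Poll::Pending }
    }
}

/// Waker that ignores wake-ups; enough for polling step by step by hand.
struct NoopWaker;
impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

pub fn poll_once<R: PollRead + Unpin>(reader: &mut R, buf: &mut [u8]) -> Poll<io::Result<usize>> {
    let waker = Waker::from(Arc::new(NoopWaker));
    Pin::new(reader).poll_read(&mut Context::from_waker(&waker), buf)
}
```

Without `this.expired = None` on success, consider a read that becomes pending again after the old deadline. It would poll the stale timer and fail immediately with `TimedOut` instead of starting a fresh timeout.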
// 1. smol is used, not async_std.
// 2. IO must be 'static.
// 3. On timeout, poll_read returns Poll::Ready(Err(io::ErrorKind::TimedOut.into())).

use {
    smol::{future::FutureExt, io, ready, Timer},
    std::{
        future::Future,
        pin::Pin,
        task::{Context, Poll},
        time::Duration,
    },
};

// --

pub struct PrudentIo<IO> {
    expired: Option<Pin<Box<dyn Future<Output = io::Result<usize>>>>>,
    timeout: Duration,
    io: IO,
}

impl<IO> PrudentIo<IO> {
    pub fn new(timeout: Duration, io: IO) -> Self {
        PrudentIo {
            expired: None,
            timeout,
            io,
        }
    }
}

impl<IO: io::AsyncRead + Unpin + 'static> io::AsyncRead for PrudentIo<IO> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        loop {
            if let Some(expired) = this.expired.as_mut() {
                let res = ready!(expired.poll(cx))?;
                this.expired.take();
                return Ok(res).into();
            }
            let timeout = this.timeout.clone();
            let (io, read_buf) = unsafe {
                // Safety: ONLY used in poll_read method.
                (&mut *(&mut this.io as *mut IO), &mut *(buf as *mut [u8]))
            };
            let fut = async move {
                let timeout_fut = async {
                    Timer::after(timeout).await;
                    io::Result::<usize>::Err(io::ErrorKind::TimedOut.into())
                };
                let read_fut = io::AsyncReadExt::read(io, read_buf);
                let res = read_fut.or(timeout_fut).await;
                res
            }
            .boxed_local();
            this.expired = Some(fut);
        }
    }
}
impl<IO: io::AsyncWrite + Unpin> io::AsyncWrite for PrudentIo<IO> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.io).poll_write(cx, buf)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.io).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.io).poll_close(cx)
    }
}
moto
  • Thanks @moto, I assume this would work, but I'm staying away from unsafe as it is hard for me to reason about. To be safe, I've solved such situations in the past with an enum reflecting the state, Reading(fut->(len,buf,io)) / Ready(io). It is difficult to handle the buffer then; a copy is necessary. Do I understand it right that you keep a mutable reference to the buf in the future beyond the call to `poll_read`? I guess this might work most of the time, but then fail in the most unpredictable way. So I find it unsound. – Robert Cutajar Mar 07 '22 at 08:00