I'm trying to implement an async read wrapper that will add read timeout functionality. The objective is that the API is plain AsyncRead
. In other words, I don't want to add io.read(buf).timeout(t) everywehere in the code. Instead, the read instance itself should return the appropriate io::ErrorKind::TimedOut
after the given timeout expires.
I can't poll the delay
to Ready though. It's always Pending. I've tried with async-std
, futures
, smol-timeout
- the same result. While the timeout does trigger when awaited, it just doesn't when polled. I know timeouts aren't easy. Something needs to wake it up. What am I doing wrong? How to pull this through?
use async_std::{
future::Future,
io,
pin::Pin,
task::{sleep, Context, Poll},
};
use std::time::Duration;
pub struct PrudentIo<IO> {
expired: Option<Pin<Box<dyn Future<Output = ()> + Sync + Send>>>,
timeout: Duration,
io: IO,
}
impl<IO> PrudentIo<IO> {
pub fn new(timeout: Duration, io: IO) -> Self {
PrudentIo {
expired: None,
timeout,
io,
}
}
}
fn delay(t: Duration) -> Option<Pin<Box<dyn Future<Output = ()> + Sync + Send + 'static>>> {
if t.is_zero() {
return None;
}
Some(Box::pin(sleep(t)))
}
impl<IO: io::Read + Unpin> io::Read for PrudentIo<IO> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
if let Some(ref mut expired) = self.expired {
match expired.as_mut().poll(cx) {
Poll::Ready(_) => {
println!("expired ready");
// too much time passed since last read/write
return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
}
Poll::Pending => {
println!("expired pending");
// in good time
}
}
}
let res = Pin::new(&mut self.io).poll_read(cx, buf);
println!("read {:?}", res);
match res {
Poll::Pending => {
if self.expired.is_none() {
// No data, start checking for a timeout
self.expired = delay(self.timeout);
}
}
Poll::Ready(_) => self.expired = None,
}
res
}
}
impl<IO: io::Write + Unpin> io::Write for PrudentIo<IO> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.io).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.io).poll_close(cx)
}
}
#[cfg(test)]
mod io_tests {
use super::*;
use async_std::io::ReadExt;
use async_std::prelude::FutureExt;
use async_std::{
io::{copy, Cursor},
net::TcpStream,
};
use std::time::Duration;
#[async_std::test]
async fn fail_read_after_timeout() -> io::Result<()> {
let mut output = b"______".to_vec();
let io = PendIo;
let mut io = PrudentIo::new(Duration::from_millis(5), io);
let mut io = Pin::new(&mut io);
insta::assert_debug_snapshot!(io.read(&mut output[..]).timeout(Duration::from_secs(1)).await,@"Ok(io::Err(timeou))");
Ok(())
}
#[async_std::test]
async fn timeout_expires() {
let later = delay(Duration::from_millis(1)).expect("some").await;
insta::assert_debug_snapshot!(later,@r"()");
}
/// Mock IO always pending
struct PendIo;
impl io::Read for PendIo {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<futures_io::Result<usize>> {
Poll::Pending
}
}
impl io::Write for PendIo {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<futures_io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
Poll::Pending
}
}
}