tokio_test/
io.rs

1#![cfg(not(loom))]
2
3//! A mock type implementing [`AsyncRead`] and [`AsyncWrite`].
4//!
5//!
6//! # Overview
7//!
//! Provides a type that implements [`AsyncRead`] + [`AsyncWrite`] that can be configured
//! to handle an arbitrary sequence of read and write operations. This is useful
//! for writing unit tests for networking services, as using an actual network
//! type is fairly non-deterministic.
12//!
13//! # Usage
14//!
15//! Attempting to write data that the mock isn't expecting will result in a
16//! panic.
17//!
18//! [`AsyncRead`]: tokio::io::AsyncRead
19//! [`AsyncWrite`]: tokio::io::AsyncWrite
20
21use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22use tokio::sync::mpsc;
23use tokio::time::{self, Duration, Instant, Sleep};
24use tokio_stream::wrappers::UnboundedReceiverStream;
25
26use futures_core::Stream;
27use std::collections::VecDeque;
28use std::fmt;
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{self, ready, Poll, Waker};
33use std::{cmp, io};
34
35/// An I/O object that follows a predefined script.
36///
37/// This value is created by `Builder` and implements `AsyncRead` + `AsyncWrite`. It
38/// follows the scenario described by the builder and panics otherwise.
39#[derive(Debug)]
40pub struct Mock {
41    inner: Inner,
42}
43
44/// A handle to send additional actions to the related `Mock`.
45#[derive(Debug)]
46pub struct Handle {
47    tx: mpsc::UnboundedSender<Action>,
48}
49
50/// Builds `Mock` instances.
51#[derive(Debug, Clone, Default)]
52pub struct Builder {
53    // Sequence of actions for the Mock to take
54    actions: VecDeque<Action>,
55    name: String,
56}
57
/// One step of a mock's script.
#[derive(Debug, Clone)]
enum Action {
    /// Bytes handed out to the next read(s).
    Read(Vec<u8>),
    /// Bytes the next write(s) must match exactly.
    Write(Vec<u8>),
    /// Stall reads and writes for this long once the step is reached.
    Wait(Duration),
    // Errors are stored behind `Arc` because `io::Error` is not `Clone`, and the
    // `Builder` holding the script must be `Clone` (and `Send`). The `Option` lets the
    // error be taken out exactly once when the mock reaches this step.
    /// Fail the next read with this error.
    ReadError(Option<Arc<io::Error>>),
    /// Fail the next write with this error.
    WriteError(Option<Arc<io::Error>>),
}
68
69struct Inner {
70    actions: VecDeque<Action>,
71    waiting: Option<Instant>,
72    sleep: Option<Pin<Box<Sleep>>>,
73    read_wait: Option<Waker>,
74    rx: UnboundedReceiverStream<Action>,
75    name: String,
76}
77
78impl Builder {
79    /// Return a new, empty `Builder`.
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    /// Sequence a `read` operation.
85    ///
86    /// The next operation in the mock's script will be to expect a `read` call
87    /// and return `buf`.
88    pub fn read(&mut self, buf: &[u8]) -> &mut Self {
89        self.actions.push_back(Action::Read(buf.into()));
90        self
91    }
92
93    /// Sequence a `read` operation that produces an error.
94    ///
95    /// The next operation in the mock's script will be to expect a `read` call
96    /// and return `error`.
97    pub fn read_error(&mut self, error: io::Error) -> &mut Self {
98        let error = Some(error.into());
99        self.actions.push_back(Action::ReadError(error));
100        self
101    }
102
103    /// Sequence a `write` operation.
104    ///
105    /// The next operation in the mock's script will be to expect a `write`
106    /// call.
107    pub fn write(&mut self, buf: &[u8]) -> &mut Self {
108        self.actions.push_back(Action::Write(buf.into()));
109        self
110    }
111
112    /// Sequence a `write` operation that produces an error.
113    ///
114    /// The next operation in the mock's script will be to expect a `write`
115    /// call that provides `error`.
116    pub fn write_error(&mut self, error: io::Error) -> &mut Self {
117        let error = Some(error.into());
118        self.actions.push_back(Action::WriteError(error));
119        self
120    }
121
122    /// Sequence a wait.
123    ///
124    /// The next operation in the mock's script will be to wait without doing so
125    /// for `duration` amount of time.
126    pub fn wait(&mut self, duration: Duration) -> &mut Self {
127        let duration = cmp::max(duration, Duration::from_millis(1));
128        self.actions.push_back(Action::Wait(duration));
129        self
130    }
131
132    /// Set name of the mock IO object to include in panic messages and debug output
133    pub fn name(&mut self, name: impl Into<String>) -> &mut Self {
134        self.name = name.into();
135        self
136    }
137
138    /// Build a `Mock` value according to the defined script.
139    pub fn build(&mut self) -> Mock {
140        let (mock, _) = self.build_with_handle();
141        mock
142    }
143
144    /// Build a `Mock` value paired with a handle
145    pub fn build_with_handle(&mut self) -> (Mock, Handle) {
146        let (inner, handle) = Inner::new(self.actions.clone(), self.name.clone());
147
148        let mock = Mock { inner };
149
150        (mock, handle)
151    }
152}
153
154impl Handle {
155    /// Sequence a `read` operation.
156    ///
157    /// The next operation in the mock's script will be to expect a `read` call
158    /// and return `buf`.
159    pub fn read(&mut self, buf: &[u8]) -> &mut Self {
160        self.tx.send(Action::Read(buf.into())).unwrap();
161        self
162    }
163
164    /// Sequence a `read` operation error.
165    ///
166    /// The next operation in the mock's script will be to expect a `read` call
167    /// and return `error`.
168    pub fn read_error(&mut self, error: io::Error) -> &mut Self {
169        let error = Some(error.into());
170        self.tx.send(Action::ReadError(error)).unwrap();
171        self
172    }
173
174    /// Sequence a `write` operation.
175    ///
176    /// The next operation in the mock's script will be to expect a `write`
177    /// call.
178    pub fn write(&mut self, buf: &[u8]) -> &mut Self {
179        self.tx.send(Action::Write(buf.into())).unwrap();
180        self
181    }
182
183    /// Sequence a `write` operation error.
184    ///
185    /// The next operation in the mock's script will be to expect a `write`
186    /// call error.
187    pub fn write_error(&mut self, error: io::Error) -> &mut Self {
188        let error = Some(error.into());
189        self.tx.send(Action::WriteError(error)).unwrap();
190        self
191    }
192}
193
194impl Inner {
195    fn new(actions: VecDeque<Action>, name: String) -> (Inner, Handle) {
196        let (tx, rx) = mpsc::unbounded_channel();
197
198        let rx = UnboundedReceiverStream::new(rx);
199
200        let inner = Inner {
201            actions,
202            sleep: None,
203            read_wait: None,
204            rx,
205            waiting: None,
206            name,
207        };
208
209        let handle = Handle { tx };
210
211        (inner, handle)
212    }
213
214    fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Action>> {
215        Pin::new(&mut self.rx).poll_next(cx)
216    }
217
218    fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
219        match self.action() {
220            Some(&mut Action::Read(ref mut data)) => {
221                // Figure out how much to copy
222                let n = cmp::min(dst.remaining(), data.len());
223
224                // Copy the data into the `dst` slice
225                dst.put_slice(&data[..n]);
226
227                // Drain the data from the source
228                data.drain(..n);
229
230                Ok(())
231            }
232            Some(&mut Action::ReadError(ref mut err)) => {
233                // As the
234                let err = err.take().expect("Should have been removed from actions.");
235                let err = Arc::try_unwrap(err).expect("There are no other references.");
236                Err(err)
237            }
238            Some(_) => {
239                // Either waiting or expecting a write
240                Err(io::ErrorKind::WouldBlock.into())
241            }
242            None => Ok(()),
243        }
244    }
245
246    fn write(&mut self, mut src: &[u8]) -> io::Result<usize> {
247        let mut ret = 0;
248
249        if self.actions.is_empty() {
250            return Err(io::ErrorKind::BrokenPipe.into());
251        }
252
253        if let Some(&mut Action::Wait(..)) = self.action() {
254            return Err(io::ErrorKind::WouldBlock.into());
255        }
256
257        if let Some(&mut Action::WriteError(ref mut err)) = self.action() {
258            let err = err.take().expect("Should have been removed from actions.");
259            let err = Arc::try_unwrap(err).expect("There are no other references.");
260            return Err(err);
261        }
262
263        for i in 0..self.actions.len() {
264            match self.actions[i] {
265                Action::Write(ref mut expect) => {
266                    let n = cmp::min(src.len(), expect.len());
267
268                    assert_eq!(&src[..n], &expect[..n], "name={} i={}", self.name, i);
269
270                    // Drop data that was matched
271                    expect.drain(..n);
272                    src = &src[n..];
273
274                    ret += n;
275
276                    if src.is_empty() {
277                        return Ok(ret);
278                    }
279                }
280                Action::Wait(..) | Action::WriteError(..) => {
281                    break;
282                }
283                _ => {}
284            }
285
286            // TODO: remove write
287        }
288
289        Ok(ret)
290    }
291
292    fn remaining_wait(&mut self) -> Option<Duration> {
293        match self.action() {
294            Some(&mut Action::Wait(dur)) => Some(dur),
295            _ => None,
296        }
297    }
298
299    fn action(&mut self) -> Option<&mut Action> {
300        loop {
301            if self.actions.is_empty() {
302                return None;
303            }
304
305            match self.actions[0] {
306                Action::Read(ref mut data) => {
307                    if !data.is_empty() {
308                        break;
309                    }
310                }
311                Action::Write(ref mut data) => {
312                    if !data.is_empty() {
313                        break;
314                    }
315                }
316                Action::Wait(ref mut dur) => {
317                    if let Some(until) = self.waiting {
318                        let now = Instant::now();
319
320                        if now < until {
321                            break;
322                        } else {
323                            self.waiting = None;
324                        }
325                    } else {
326                        self.waiting = Some(Instant::now() + *dur);
327                        break;
328                    }
329                }
330                Action::ReadError(ref mut error) | Action::WriteError(ref mut error) => {
331                    if error.is_some() {
332                        break;
333                    }
334                }
335            }
336
337            let _action = self.actions.pop_front();
338        }
339
340        self.actions.front_mut()
341    }
342}
343
// ===== impl Mock =====
345
346impl Mock {
347    fn maybe_wakeup_reader(&mut self) {
348        match self.inner.action() {
349            Some(&mut Action::Read(_)) | Some(&mut Action::ReadError(_)) | None => {
350                if let Some(waker) = self.inner.read_wait.take() {
351                    waker.wake();
352                }
353            }
354            _ => {}
355        }
356    }
357}
358
359impl AsyncRead for Mock {
360    fn poll_read(
361        mut self: Pin<&mut Self>,
362        cx: &mut task::Context<'_>,
363        buf: &mut ReadBuf<'_>,
364    ) -> Poll<io::Result<()>> {
365        loop {
366            if let Some(ref mut sleep) = self.inner.sleep {
367                ready!(Pin::new(sleep).poll(cx));
368            }
369
370            // If a sleep is set, it has already fired
371            self.inner.sleep = None;
372
373            // Capture 'filled' to monitor if it changed
374            let filled = buf.filled().len();
375
376            match self.inner.read(buf) {
377                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
378                    if let Some(rem) = self.inner.remaining_wait() {
379                        let until = Instant::now() + rem;
380                        self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
381                    } else {
382                        self.inner.read_wait = Some(cx.waker().clone());
383                        return Poll::Pending;
384                    }
385                }
386                Ok(()) => {
387                    if buf.filled().len() == filled {
388                        match ready!(self.inner.poll_action(cx)) {
389                            Some(action) => {
390                                self.inner.actions.push_back(action);
391                                continue;
392                            }
393                            None => {
394                                return Poll::Ready(Ok(()));
395                            }
396                        }
397                    } else {
398                        return Poll::Ready(Ok(()));
399                    }
400                }
401                Err(e) => return Poll::Ready(Err(e)),
402            }
403        }
404    }
405}
406
407impl AsyncWrite for Mock {
408    fn poll_write(
409        mut self: Pin<&mut Self>,
410        cx: &mut task::Context<'_>,
411        buf: &[u8],
412    ) -> Poll<io::Result<usize>> {
413        loop {
414            if let Some(ref mut sleep) = self.inner.sleep {
415                ready!(Pin::new(sleep).poll(cx));
416            }
417
418            // If a sleep is set, it has already fired
419            self.inner.sleep = None;
420
421            if self.inner.actions.is_empty() {
422                match self.inner.poll_action(cx) {
423                    Poll::Pending => {
424                        // do not propagate pending
425                    }
426                    Poll::Ready(Some(action)) => {
427                        self.inner.actions.push_back(action);
428                    }
429                    Poll::Ready(None) => {
430                        panic!("unexpected write {}", self.pmsg());
431                    }
432                }
433            }
434
435            match self.inner.write(buf) {
436                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
437                    if let Some(rem) = self.inner.remaining_wait() {
438                        let until = Instant::now() + rem;
439                        self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
440                    } else {
441                        panic!("unexpected WouldBlock {}", self.pmsg());
442                    }
443                }
444                Ok(0) => {
445                    // TODO: Is this correct?
446                    if !self.inner.actions.is_empty() {
447                        return Poll::Pending;
448                    }
449
450                    // TODO: Extract
451                    match ready!(self.inner.poll_action(cx)) {
452                        Some(action) => {
453                            self.inner.actions.push_back(action);
454                            continue;
455                        }
456                        None => {
457                            panic!("unexpected write {}", self.pmsg());
458                        }
459                    }
460                }
461                ret => {
462                    self.maybe_wakeup_reader();
463                    return Poll::Ready(ret);
464                }
465            }
466        }
467    }
468
469    fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
470        Poll::Ready(Ok(()))
471    }
472
473    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
474        Poll::Ready(Ok(()))
475    }
476}
477
478/// Ensures that Mock isn't dropped with data "inside".
479impl Drop for Mock {
480    fn drop(&mut self) {
481        // Avoid double panicking, since makes debugging much harder.
482        if std::thread::panicking() {
483            return;
484        }
485
486        self.inner.actions.iter().for_each(|a| match a {
487            Action::Read(data) => assert!(
488                data.is_empty(),
489                "There is still data left to read. {}",
490                self.pmsg()
491            ),
492            Action::Write(data) => assert!(
493                data.is_empty(),
494                "There is still data left to write. {}",
495                self.pmsg()
496            ),
497            _ => (),
498        });
499    }
500}
501/*
502/// Returns `true` if called from the context of a futures-rs Task
503fn is_task_ctx() -> bool {
504    use std::panic;
505
506    // Save the existing panic hook
507    let h = panic::take_hook();
508
509    // Install a new one that does nothing
510    panic::set_hook(Box::new(|_| {}));
511
512    // Attempt to call the fn
513    let r = panic::catch_unwind(|| task::current()).is_ok();
514
515    // Re-install the old one
516    panic::set_hook(h);
517
518    // Return the result
519    r
520}
521*/
522
523impl fmt::Debug for Inner {
524    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
525        if self.name.is_empty() {
526            write!(f, "Inner {{...}}")
527        } else {
528            write!(f, "Inner {{name={}, ...}}", self.name)
529        }
530    }
531}
532
533struct PanicMsgSnippet<'a>(&'a Inner);
534
535impl<'a> fmt::Display for PanicMsgSnippet<'a> {
536    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
537        if self.0.name.is_empty() {
538            write!(f, "({} actions remain)", self.0.actions.len())
539        } else {
540            write!(
541                f,
542                "(name {}, {} actions remain)",
543                self.0.name,
544                self.0.actions.len()
545            )
546        }
547    }
548}
549
550impl Mock {
551    fn pmsg(&self) -> PanicMsgSnippet<'_> {
552        PanicMsgSnippet(&self.inner)
553    }
554}