Skip to content

Commit

Permalink
io: add write_all_buf to AsyncWriteExt (#3737)
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh authored May 1, 2021
1 parent a08ce0d commit 14bb2f6
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 2 deletions.
58 changes: 56 additions & 2 deletions tokio/src/io/util/async_write_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::io::util::flush::{flush, Flush};
use crate::io::util::shutdown::{shutdown, Shutdown};
use crate::io::util::write::{write, Write};
use crate::io::util::write_all::{write_all, WriteAll};
use crate::io::util::write_all_buf::{write_all_buf, WriteAllBuf};
use crate::io::util::write_buf::{write_buf, WriteBuf};
use crate::io::util::write_int::{
WriteI128, WriteI128Le, WriteI16, WriteI16Le, WriteI32, WriteI32Le, WriteI64, WriteI64Le,
Expand Down Expand Up @@ -159,7 +160,6 @@ cfg_io_util! {
write_vectored(self, bufs)
}


/// Writes a buffer into this writer, advancing the buffer's internal
/// cursor.
///
Expand Down Expand Up @@ -197,10 +197,11 @@ cfg_io_util! {
///
/// # Examples
///
/// [`File`] implements `Read` and [`Cursor<&[u8]>`] implements [`Buf`]:
/// [`File`] implements [`AsyncWrite`] and [`Cursor`]`<&[u8]>` implements [`Buf`]:
///
/// [`File`]: crate::fs::File
/// [`Buf`]: bytes::Buf
/// [`Cursor`]: std::io::Cursor
///
/// ```no_run
/// use tokio::io::{self, AsyncWriteExt};
Expand Down Expand Up @@ -233,6 +234,59 @@ cfg_io_util! {
write_buf(self, src)
}

/// Attempts to write an entire buffer into this writer
///
/// Equivalent to:
///
/// ```ignore
/// async fn write_all_buf(&mut self, buf: impl Buf) -> Result<(), io::Error> {
/// while buf.has_remaining() {
/// self.write_buf(&mut buf).await?;
/// }
/// }
/// ```
///
/// This method will continuously call [`write`] until
/// [`buf.has_remaining()`](bytes::Buf::has_remaining) returns false. This method will not
/// return until the entire buffer has been successfully written or an error occurs. The
/// first error generated will be returned.
///
/// The buffer is advanced after each chunk is successfully written. After failure,
/// `src.chunk()` will return the chunk that failed to write.
///
/// # Examples
///
/// [`File`] implements [`AsyncWrite`] and [`Cursor`]`<&[u8]>` implements [`Buf`]:
///
/// [`File`]: crate::fs::File
/// [`Buf`]: bytes::Buf
/// [`Cursor`]: std::io::Cursor
///
/// ```no_run
/// use tokio::io::{self, AsyncWriteExt};
/// use tokio::fs::File;
///
/// use std::io::Cursor;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut file = File::create("foo.txt").await?;
/// let mut buffer = Cursor::new(b"data to write");
///
/// file.write_all_buf(&mut buffer).await?;
/// Ok(())
/// }
/// ```
///
/// [`write`]: AsyncWriteExt::write
fn write_all_buf<'a, B>(&'a mut self, src: &'a mut B) -> WriteAllBuf<'a, Self, B>
where
Self: Sized + Unpin,
B: Buf,
{
write_all_buf(self, src)
}

/// Attempts to write an entire buffer into this writer.
///
/// Equivalent to:
Expand Down
1 change: 1 addition & 0 deletions tokio/src/io/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ cfg_io_util! {
mod write_vectored;
mod write_all;
mod write_buf;
mod write_all_buf;
mod write_int;


Expand Down
56 changes: 56 additions & 0 deletions tokio/src/io/util/write_all_buf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use crate::io::AsyncWrite;

use bytes::Buf;
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// A future to write some of the buffer to an `AsyncWrite`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteAllBuf<'a, W, B> {
writer: &'a mut W,
buf: &'a mut B,
#[pin]
_pin: PhantomPinned,
}
}

/// Tries to write some bytes from the given `buf` to the writer in an
/// asynchronous manner, returning a future.
pub(crate) fn write_all_buf<'a, W, B>(writer: &'a mut W, buf: &'a mut B) -> WriteAllBuf<'a, W, B>
where
W: AsyncWrite + Unpin,
B: Buf,
{
WriteAllBuf {
writer,
buf,
_pin: PhantomPinned,
}
}

impl<W, B> Future for WriteAllBuf<'_, W, B>
where
W: AsyncWrite + Unpin,
B: Buf,
{
type Output = io::Result<()>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let me = self.project();
while me.buf.has_remaining() {
let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk())?);
me.buf.advance(n);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
}
}

Poll::Ready(Ok(()))
}
}
96 changes: 96 additions & 0 deletions tokio/tests/io_write_all_buf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio_test::{assert_err, assert_ok};

use bytes::{Buf, Bytes, BytesMut};
use std::cmp;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

#[tokio::test]
async fn write_all_buf() {
struct Wr {
buf: BytesMut,
cnt: usize,
}

impl AsyncWrite for Wr {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let n = cmp::min(4, buf.len());
dbg!(buf);
let buf = &buf[0..n];

self.cnt += 1;
self.buf.extend(buf);
Ok(buf.len()).into()
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
}

let mut wr = Wr {
buf: BytesMut::with_capacity(64),
cnt: 0,
};

let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));

assert_ok!(wr.write_all_buf(&mut buf).await);
assert_eq!(wr.buf, b"helloworld"[..]);
// expect 4 writes, [hell],[o],[worl],[d]
assert_eq!(wr.cnt, 4);
assert_eq!(buf.has_remaining(), false);
}

#[tokio::test]
async fn write_buf_err() {
/// Error out after writing the first 4 bytes
struct Wr {
cnt: usize,
}

impl AsyncWrite for Wr {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
self.cnt += 1;
if self.cnt == 2 {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "whoops")));
}
Poll::Ready(Ok(4))
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
}

let mut wr = Wr { cnt: 0 };

let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));

assert_err!(wr.write_all_buf(&mut buf).await);
assert_eq!(
buf.copy_to_bytes(buf.remaining()),
Bytes::from_static(b"oworld")
);
}

0 comments on commit 14bb2f6

Please sign in to comment.