Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

io: Add write_all_buf to AsyncWriteExt #3737

Merged
merged 6 commits into from
May 1, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion tokio/src/io/util/async_write_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::io::util::flush::{flush, Flush};
use crate::io::util::shutdown::{shutdown, Shutdown};
use crate::io::util::write::{write, Write};
use crate::io::util::write_all::{write_all, WriteAll};
use crate::io::util::write_all_buf::{write_all_buf, WriteAllBuf};
use crate::io::util::write_buf::{write_buf, WriteBuf};
use crate::io::util::write_int::{
WriteI128, WriteI128Le, WriteI16, WriteI16Le, WriteI32, WriteI32Le, WriteI64, WriteI64Le,
Expand Down Expand Up @@ -159,7 +160,6 @@ cfg_io_util! {
write_vectored(self, bufs)
}


/// Writes a buffer into this writer, advancing the buffer's internal
/// cursor.
///
Expand Down Expand Up @@ -233,6 +233,58 @@ cfg_io_util! {
write_buf(self, src)
}

/// Attempts to write an entire buffer into this writer
///
/// Equivalent to:
///
/// ```ignore
/// async fn write_all_buf(&mut self, buf: impl Buf) -> Result<(), io::Error> {
/// while buf.has_remaining() {
/// self.write_buf(&mut buf).await?;
/// }
/// }
/// ```
///
/// This method will continuously call [`write`] until
/// [`buf.has_remaining()`](bytes::Buf::has_remaining) returns false. This method will not
/// return until the entire buffer has been successfully written or an error occurs. The
/// first error generated will be returned.
///
/// The buffer is advanced after each chunk is successfully written. After failure,
/// `src.chunk()` will return the chunk that failed to write.
///
/// # Examples
///
/// [`File`] implements `Read` and [`Cursor<&[u8]>`] implements [`Buf`]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the cursor link here is broken.

Copy link
Contributor Author

@rcoh rcoh May 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch. It was broken for write_buf as well (& it was just wrong, I think? I believe it meant to say file implements AsyncWrite`).

I fixed both.

///
/// [`File`]: crate::fs::File
/// [`Buf`]: bytes::Buf
///
/// ```no_run
/// use tokio::io::{self, AsyncWriteExt};
/// use tokio::fs::File;
///
/// use std::io::Cursor;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut file = File::create("foo.txt").await?;
/// let mut buffer = Cursor::new(b"data to write");
///
/// file.write_all_buf(&mut buffer).await?;
/// Ok(())
/// }
/// ```
///
/// [`write`]: AsyncWriteExt::write
fn write_all_buf<'a, B>(&'a mut self, src: &'a mut B) -> WriteAllBuf<'a, Self, B>
where
Self: Sized + Unpin,
B: Buf,
{
write_all_buf(self, src)
}

/// Attempts to write an entire buffer into this writer.
///
/// Equivalent to:
Expand Down
1 change: 1 addition & 0 deletions tokio/src/io/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ cfg_io_util! {
mod write_vectored;
mod write_all;
mod write_buf;
mod write_all_buf;
mod write_int;


Expand Down
56 changes: 56 additions & 0 deletions tokio/src/io/util/write_all_buf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use crate::io::AsyncWrite;

use bytes::Buf;
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// A future to write some of the buffer to an `AsyncWrite`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteAllBuf<'a, W, B> {
writer: &'a mut W,
buf: &'a mut B,
#[pin]
_pin: PhantomPinned,
}
}

/// Tries to write some bytes from the given `buf` to the writer in an
/// asynchronous manner, returning a future.
pub(crate) fn write_all_buf<'a, W, B>(writer: &'a mut W, buf: &'a mut B) -> WriteAllBuf<'a, W, B>
where
W: AsyncWrite + Unpin,
B: Buf,
{
WriteAllBuf {
writer,
buf,
_pin: PhantomPinned,
}
}

impl<W, B> Future for WriteAllBuf<'_, W, B>
where
W: AsyncWrite + Unpin,
B: Buf,
{
type Output = io::Result<()>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let me = self.project();
while me.buf.has_remaining() {
let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk())?);
me.buf.advance(n);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
}
}

Poll::Ready(Ok(()))
}
}
96 changes: 96 additions & 0 deletions tokio/tests/io_write_all_buf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio_test::{assert_err, assert_ok};

use bytes::{Buf, Bytes, BytesMut};
use std::cmp;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

#[tokio::test]
async fn write_all_buf() {
struct Wr {
buf: BytesMut,
cnt: usize,
}

impl AsyncWrite for Wr {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let n = cmp::min(4, buf.len());
dbg!(buf);
let buf = &buf[0..n];

self.cnt += 1;
self.buf.extend(buf);
Ok(buf.len()).into()
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
}

let mut wr = Wr {
buf: BytesMut::with_capacity(64),
cnt: 0,
};

let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));

assert_ok!(wr.write_all_buf(&mut buf).await);
assert_eq!(wr.buf, b"helloworld"[..]);
// expect 4 writes, [hell],[o],[worl],[d]
assert_eq!(wr.cnt, 4);
assert_eq!(buf.has_remaining(), false);
}

#[tokio::test]
async fn write_buf_err() {
/// Error out after writing the first 4 bytes
struct Wr {
cnt: usize,
}

impl AsyncWrite for Wr {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
self.cnt += 1;
if self.cnt == 2 {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "whoops")));
}
Poll::Ready(Ok(4))
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
}

let mut wr = Wr { cnt: 0 };

let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));

assert_err!(wr.write_all_buf(&mut buf).await);
assert_eq!(
buf.copy_to_bytes(buf.remaining()),
Bytes::from_static(b"oworld")
);
}