-
-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
io: add write_all_buf to AsyncWriteExt (#3737)
- Loading branch information
Showing
4 changed files
with
209 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -77,6 +77,7 @@ cfg_io_util! { | |
mod write_vectored; | ||
mod write_all; | ||
mod write_buf; | ||
mod write_all_buf; | ||
mod write_int; | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
use crate::io::AsyncWrite; | ||
|
||
use bytes::Buf; | ||
use pin_project_lite::pin_project; | ||
use std::future::Future; | ||
use std::io; | ||
use std::marker::PhantomPinned; | ||
use std::pin::Pin; | ||
use std::task::{Context, Poll}; | ||
|
||
pin_project! { | ||
/// A future to write some of the buffer to an `AsyncWrite`. | ||
#[derive(Debug)] | ||
#[must_use = "futures do nothing unless you `.await` or poll them"] | ||
pub struct WriteAllBuf<'a, W, B> { | ||
writer: &'a mut W, | ||
buf: &'a mut B, | ||
#[pin] | ||
_pin: PhantomPinned, | ||
} | ||
} | ||
|
||
/// Tries to write some bytes from the given `buf` to the writer in an | ||
/// asynchronous manner, returning a future. | ||
pub(crate) fn write_all_buf<'a, W, B>(writer: &'a mut W, buf: &'a mut B) -> WriteAllBuf<'a, W, B> | ||
where | ||
W: AsyncWrite + Unpin, | ||
B: Buf, | ||
{ | ||
WriteAllBuf { | ||
writer, | ||
buf, | ||
_pin: PhantomPinned, | ||
} | ||
} | ||
|
||
impl<W, B> Future for WriteAllBuf<'_, W, B> | ||
where | ||
W: AsyncWrite + Unpin, | ||
B: Buf, | ||
{ | ||
type Output = io::Result<()>; | ||
|
||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { | ||
let me = self.project(); | ||
while me.buf.has_remaining() { | ||
let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk())?); | ||
me.buf.advance(n); | ||
if n == 0 { | ||
return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); | ||
} | ||
} | ||
|
||
Poll::Ready(Ok(())) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#![warn(rust_2018_idioms)] | ||
#![cfg(feature = "full")] | ||
|
||
use tokio::io::{AsyncWrite, AsyncWriteExt}; | ||
use tokio_test::{assert_err, assert_ok}; | ||
|
||
use bytes::{Buf, Bytes, BytesMut}; | ||
use std::cmp; | ||
use std::io; | ||
use std::pin::Pin; | ||
use std::task::{Context, Poll}; | ||
|
||
#[tokio::test] | ||
async fn write_all_buf() { | ||
struct Wr { | ||
buf: BytesMut, | ||
cnt: usize, | ||
} | ||
|
||
impl AsyncWrite for Wr { | ||
fn poll_write( | ||
mut self: Pin<&mut Self>, | ||
_cx: &mut Context<'_>, | ||
buf: &[u8], | ||
) -> Poll<io::Result<usize>> { | ||
let n = cmp::min(4, buf.len()); | ||
dbg!(buf); | ||
let buf = &buf[0..n]; | ||
|
||
self.cnt += 1; | ||
self.buf.extend(buf); | ||
Ok(buf.len()).into() | ||
} | ||
|
||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { | ||
Ok(()).into() | ||
} | ||
|
||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { | ||
Ok(()).into() | ||
} | ||
} | ||
|
||
let mut wr = Wr { | ||
buf: BytesMut::with_capacity(64), | ||
cnt: 0, | ||
}; | ||
|
||
let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world")); | ||
|
||
assert_ok!(wr.write_all_buf(&mut buf).await); | ||
assert_eq!(wr.buf, b"helloworld"[..]); | ||
// expect 4 writes, [hell],[o],[worl],[d] | ||
assert_eq!(wr.cnt, 4); | ||
assert_eq!(buf.has_remaining(), false); | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_buf_err() { | ||
/// Error out after writing the first 4 bytes | ||
struct Wr { | ||
cnt: usize, | ||
} | ||
|
||
impl AsyncWrite for Wr { | ||
fn poll_write( | ||
mut self: Pin<&mut Self>, | ||
_cx: &mut Context<'_>, | ||
_buf: &[u8], | ||
) -> Poll<io::Result<usize>> { | ||
self.cnt += 1; | ||
if self.cnt == 2 { | ||
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "whoops"))); | ||
} | ||
Poll::Ready(Ok(4)) | ||
} | ||
|
||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { | ||
Ok(()).into() | ||
} | ||
|
||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { | ||
Ok(()).into() | ||
} | ||
} | ||
|
||
let mut wr = Wr { cnt: 0 }; | ||
|
||
let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world")); | ||
|
||
assert_err!(wr.write_all_buf(&mut buf).await); | ||
assert_eq!( | ||
buf.copy_to_bytes(buf.remaining()), | ||
Bytes::from_static(b"oworld") | ||
); | ||
} |