diff --git a/examples/cat.rs b/examples/cat.rs index d41b7ab1..33d02e38 100644 --- a/examples/cat.rs +++ b/examples/cat.rs @@ -29,7 +29,7 @@ fn main() { loop { // Read a chunk - let (res, b) = file.read_at(buf, pos).await; + let (res, b) = file.read_at(buf, pos).submit().await; let n = res.unwrap(); if n == 0 { diff --git a/examples/mix.rs b/examples/mix.rs index 4e094019..aaad7f60 100644 --- a/examples/mix.rs +++ b/examples/mix.rs @@ -34,7 +34,7 @@ fn main() { loop { // Read a chunk - let (res, b) = file.read_at(buf, pos).await; + let (res, b) = file.read_at(buf, pos).submit().await; let n = res.unwrap(); if n == 0 { diff --git a/src/fs/file.rs b/src/fs/file.rs index 9cd47f21..0582388a 100644 --- a/src/fs/file.rs +++ b/src/fs/file.rs @@ -4,7 +4,7 @@ use crate::fs::OpenOptions; use crate::io::SharedFd; use crate::runtime::driver::op::Op; -use crate::{UnsubmittedOneshot, UnsubmittedWrite}; +use crate::{UnsubmittedOneshot, UnsubmittedRead, UnsubmittedWrite}; use std::fmt; use std::io; use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; @@ -165,7 +165,7 @@ impl File { /// let buffer = vec![0; 10]; /// /// // Read up to 10 bytes - /// let (res, buffer) = f.read_at(buffer, 0).await; + /// let (res, buffer) = f.read_at(buffer, 0).submit().await; /// let n = res?; /// /// println!("The bytes: {:?}", &buffer[..n]); @@ -176,10 +176,8 @@ impl File { /// }) /// } /// ``` - pub async fn read_at(&self, buf: T, pos: u64) -> crate::BufResult { - // Submit the read operation - let op = Op::read_at(&self.fd, buf, pos).unwrap(); - op.await + pub fn read_at(&self, buf: T, pos: u64) -> UnsubmittedRead { + UnsubmittedOneshot::read_at(&self.fd, buf, pos) } /// Read some bytes at the specified offset from the file into the specified @@ -417,7 +415,7 @@ impl File { } while buf.bytes_total() != 0 { - let (res, slice) = self.read_at(buf, pos).await; + let (res, slice) = self.read_at(buf, pos).submit().await; match res { Ok(0) => { return ( diff --git a/src/io/mod.rs b/src/io/mod.rs index 6985bdd3..4b2272a7 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -15,7 +15,7 @@ pub(crate) use noop::NoOp; mod open; -mod read; +pub(crate) mod read; mod read_fixed; diff --git a/src/io/read.rs b/src/io/read.rs index c3395b40..a203ebf1 100644 --- a/src/io/read.rs +++ b/src/io/read.rs @@ -1,64 +1,69 @@ +use io_uring::cqueue::Entry; + use crate::buf::BoundedBufMut; use crate::io::SharedFd; -use crate::BufResult; +use crate::{BufResult, OneshotOutputTransform, UnsubmittedOneshot}; -use crate::runtime::driver::op::{Completable, CqeResult, Op}; -use crate::runtime::CONTEXT; use std::io; +use std::marker::PhantomData; + +/// An unsubmitted read operation. +pub type UnsubmittedRead = UnsubmittedOneshot, ReadTransform>; -pub(crate) struct Read { +#[allow(missing_docs)] +pub struct ReadData { /// Holds a strong ref to the FD, preventing the file from being closed /// while the operation is in-flight. - #[allow(dead_code)] - fd: SharedFd, + _fd: SharedFd, - /// Reference to the in-flight buffer. - pub(crate) buf: T, + buf: T, } -impl Op> { - pub(crate) fn read_at(fd: &SharedFd, buf: T, offset: u64) -> io::Result>> { - use io_uring::{opcode, types}; - - CONTEXT.with(|x| { - x.handle().expect("Not in a runtime context").submit_op( - Read { - fd: fd.clone(), - buf, - }, - |read| { - // Get raw buffer info - let ptr = read.buf.stable_mut_ptr(); - let len = read.buf.bytes_total(); - opcode::Read::new(types::Fd(fd.raw_fd()), ptr, len as _) - .offset(offset as _) - .build() - }, - ) - }) - } +#[allow(missing_docs)] +pub struct ReadTransform { + _phantom: PhantomData, } -impl Completable for Read +impl OneshotOutputTransform for ReadTransform where T: BoundedBufMut, { type Output = BufResult; + type StoredData = ReadData; - fn complete(self, cqe: CqeResult) -> Self::Output { - // Convert the operation result to `usize` - let res = cqe.result.map(|v| v as usize); - // Recover the buffer - let mut buf = self.buf; - - // If the operation was successful, advance the initialized cursor. - if let Ok(n) = res { + fn transform_oneshot_output(self, mut data: Self::StoredData, cqe: Entry) -> Self::Output { + let n = cqe.result(); + let res = if n >= 0 { // Safety: the kernel wrote `n` bytes to the buffer. - unsafe { - buf.set_init(n); - } - } + unsafe { data.buf.set_init(n as usize) }; + Ok(n as usize) + } else { + Err(io::Error::from_raw_os_error(-n)) + }; + + (res, data.buf) + } +} + +impl UnsubmittedRead { + pub(crate) fn read_at(fd: &SharedFd, mut buf: T, offset: u64) -> Self { + use io_uring::{opcode, types}; + + // Get raw buffer info + let ptr = buf.stable_mut_ptr(); + let len = buf.bytes_total(); - (res, buf) + Self::new( + ReadData { + _fd: fd.clone(), + buf, + }, + ReadTransform { + _phantom: PhantomData, + }, + opcode::Read::new(types::Fd(fd.raw_fd()), ptr, len as _) + .offset(offset as _) + .build(), + ) } } diff --git a/src/io/socket.rs b/src/io/socket.rs index dda1bb36..081eaf00 100644 --- a/src/io/socket.rs +++ b/src/io/socket.rs @@ -169,8 +169,7 @@ impl Socket { } pub(crate) async fn read(&self, buf: T) -> crate::BufResult { - let op = Op::read_at(&self.fd, buf, 0).unwrap(); - op.await + UnsubmittedOneshot::read_at(&self.fd, buf, 0).submit().await } pub(crate) async fn read_fixed(&self, buf: T) -> crate::BufResult diff --git a/src/lib.rs b/src/lib.rs index d1cc6e02..e75b5803 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,7 +20,7 @@ //! // Read some data, the buffer is passed by ownership and //! // submitted to the kernel. When the operation completes, //! // we get the buffer back. -//! let (res, buf) = file.read_at(buf, 0).await; +//! let (res, buf) = file.read_at(buf, 0).submit().await; //! let n = res?; //! //! // Display the contents @@ -78,6 +78,7 @@ pub mod buf; pub mod fs; pub mod net; +pub use io::read::*; pub use io::write::*; pub use runtime::driver::op::{InFlightOneshot, OneshotOutputTransform, UnsubmittedOneshot}; pub use runtime::spawn; @@ -115,7 +116,7 @@ use std::future::Future; /// // Read some data, the buffer is passed by ownership and /// // submitted to the kernel. When the operation completes, /// // we get the buffer back. -/// let (res, buf) = file.read_at(buf, 0).await; +/// let (res, buf) = file.read_at(buf, 0).submit().await; /// let n = res?; /// /// // Display the contents @@ -254,7 +255,7 @@ impl Builder { /// // Read some data, the buffer is passed by ownership and /// // submitted to the kernel. When the operation completes, /// // we get the buffer back. -/// let (res, buf) = file.read_at(buf, 0).await; +/// let (res, buf) = file.read_at(buf, 0).submit().await; /// let n = res?; /// /// // Display the contents diff --git a/tests/driver.rs b/tests/driver.rs index f4381dd5..a123aa27 100644 --- a/tests/driver.rs +++ b/tests/driver.rs @@ -58,6 +58,7 @@ fn complete_ops_on_drop() { }, 25 * 1024 * 1024, ) + .submit() .await .0 .unwrap(); diff --git a/tests/fs_file.rs b/tests/fs_file.rs index 6ec14d43..ab3172f2 100644 --- a/tests/fs_file.rs +++ b/tests/fs_file.rs @@ -19,7 +19,7 @@ const HELLO: &[u8] = b"hello world..."; async fn read_hello(file: &File) { let buf = Vec::with_capacity(1024); - let (res, buf) = file.read_at(buf, 0).await; + let (res, buf) = file.read_at(buf, 0).submit().await; let n = res.unwrap(); assert_eq!(n, HELLO.len()); @@ -315,6 +315,55 @@ fn basic_fallocate() { }); } +#[test] +fn write_linked() { + tokio_uring::start(async { + let tempfile = tempfile(); + let file = File::create(tempfile.path()).await.unwrap(); + + let write1 = file + .write_at(HELLO, 0) + .set_flags(io_uring::squeue::Flags::IO_LINK) + .submit(); + let write2 = file.write_at(HELLO, HELLO.len() as u64).submit(); + + let res1 = write1.await; + let res2 = write2.await; + res1.0.unwrap(); + res2.0.unwrap(); + + let file = std::fs::read(tempfile.path()).unwrap(); + assert_eq!(file, [HELLO, HELLO].concat()); + }); +} + +#[test] +fn read_linked() { + tokio_uring::start(async { + let mut tempfile = tempfile(); + let file = File::open(tempfile.path()).await.unwrap(); + + tempfile.write_all(&[HELLO, HELLO].concat()).unwrap(); + + let buf1 = Vec::with_capacity(HELLO.len()); + let buf2 = Vec::with_capacity(HELLO.len()); + + let read1 = file + .read_at(buf1, 0) + .set_flags(io_uring::squeue::Flags::IO_LINK) + .submit(); + let read2 = file.read_at(buf2, HELLO.len() as u64).submit(); + + let res1 = read1.await; + let res2 = read2.await; + + res1.0.unwrap(); + res2.0.unwrap(); + + assert_eq!([HELLO, HELLO].concat(), [res1.1, res2.1].concat()); + }); +} + fn tempfile() -> NamedTempFile { NamedTempFile::new().unwrap() }