Skip to content

Commit

Permalink
Add tests to check that reading can overflow internal counters
Browse files Browse the repository at this point in the history
  • Loading branch information
Mingun authored and dralley committed Jun 24, 2024
1 parent 4674244 commit 7f92129
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 0 deletions.
76 changes: 76 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
use std::borrow::{Borrow, Cow};
use std::fmt::{self, Debug, Formatter};
use std::io;
use std::ops::Deref;

#[cfg(feature = "async-tokio")]
use std::{
pin::Pin,
task::{Context, Poll},
};

#[cfg(feature = "serialize")]
use serde::de::{Deserialize, Deserializer, Error, Visitor};
#[cfg(feature = "serialize")]
Expand Down Expand Up @@ -197,6 +204,75 @@ impl<'de> Serialize for Bytes<'de> {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// A simple producer of infinite stream of bytes, useful in tests.
///
/// Will repeat `chunk` field indefinitely: once the end of `chunk` is reached,
/// reading starts again from its first byte.
///
/// Both the [`io::Read`] and the [`io::BufRead`] implementations advance the
/// same position (`consumed`) and the same byte counter (`overall_read`), so
/// they can be mixed freely.
pub struct Fountain<'a> {
    /// That piece of data repeated infinitely...
    pub chunk: &'a [u8],
    /// Part of `chunk` that was already consumed in the current repetition.
    /// Reset to `0` each time the end of `chunk` is reached
    pub consumed: usize,
    /// The overall count of read bytes
    pub overall_read: u64,
}

impl<'a> io::Read for Fountain<'a> {
    /// Copies the not yet consumed tail of the current repetition of `chunk`
    /// into `buf` and advances the stream by the number of bytes copied.
    ///
    /// At most `buf.len()` bytes are copied. A single call never crosses the
    /// boundary between two repetitions, so the result may be shorter than
    /// `buf` even though the stream is infinite. Returns `Ok(0)` only when
    /// `buf` or `chunk` is empty.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let available = &self.chunk[self.consumed..];
        let len = buf.len().min(available.len());

        // `copy_from_slice` panics when the lengths differ, so only the first
        // `len` bytes of the destination may be targeted
        buf[..len].copy_from_slice(&available[..len]);

        // Keep `consumed` and `overall_read` in sync with the `BufRead` view,
        // otherwise every call would return the same bytes again
        io::BufRead::consume(self, len);
        Ok(len)
    }
}

impl<'a> io::BufRead for Fountain<'a> {
    /// Returns the not yet consumed tail of the current repetition of `chunk`.
    ///
    /// The returned slice is empty only if `chunk` itself is empty.
    #[inline]
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        Ok(&self.chunk[self.consumed..])
    }

    /// Marks `amt` bytes as read and wraps around to the start of `chunk`
    /// when its end is reached.
    ///
    /// As required by the [`io::BufRead`] contract, `amt` must not exceed the
    /// length of the slice returned by the last `fill_buf` call.
    fn consume(&mut self, amt: usize) {
        self.consumed += amt;
        // End of the chunk reached: start the next repetition
        if self.consumed == self.chunk.len() {
            self.consumed = 0;
        }
        self.overall_read += amt as u64;
    }
}

#[cfg(feature = "async-tokio")]
impl<'a> tokio::io::AsyncRead for Fountain<'a> {
    /// Copies the not yet consumed tail of the current repetition of `chunk`
    /// into `buf` and advances the stream by the number of bytes copied.
    ///
    /// Always completes immediately. At most `buf.remaining()` bytes are
    /// written and a single call never crosses the boundary between two
    /// repetitions of `chunk`.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let available = &this.chunk[this.consumed..];
        let len = buf.remaining().min(available.len());

        buf.put_slice(&available[..len]);

        // Advance `consumed` and `overall_read` exactly as the buffered
        // interface does, otherwise every poll would return the same bytes
        io::BufRead::consume(this, len);
        Poll::Ready(Ok(()))
    }
}

#[cfg(feature = "async-tokio")]
impl<'a> tokio::io::AsyncBufRead for Fountain<'a> {
    /// Always ready: hands out the same slice as the synchronous
    /// [`io::BufRead::fill_buf`] implementation.
    #[inline]
    fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let this = Pin::into_inner(self);
        let filled = <Self as io::BufRead>::fill_buf(this);
        Poll::Ready(filled)
    }

    /// Forwards to the synchronous [`io::BufRead::consume`] implementation.
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = Pin::into_inner(self);
        <Self as io::BufRead>::consume(this, amt)
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

/// A function to check whether the byte is a whitespace (blank, new line, carriage return or tab).
#[inline]
pub const fn is_whitespace(b: u8) -> bool {
Expand Down
65 changes: 65 additions & 0 deletions tests/async-tokio.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::iter;

use pretty_assertions::assert_eq;
use quick_xml::events::Event::*;
use quick_xml::name::QName;
use quick_xml::reader::Reader;

// Import `small_buffers_tests!`
Expand Down Expand Up @@ -36,3 +39,65 @@ async fn test_sample() {
}
assert_eq!((count, reads), (1247, 5245));
}

/// Regression test for https://github.com/tafia/quick-xml/issues/751
///
/// Actually, that error was not found in the async reader, but we would like
/// to test it as well.
#[tokio::test]
async fn issue751() {
    // Text content of a single `<content>` element: the phrase repeated 1000 times
    let text: Vec<u8> = iter::repeat(&b"some text inside"[..])
        .take(1000)
        .flatten()
        .copied()
        .collect();
    // One complete element; the fountain replays it without end
    let chunk = [&b"<content>"[..], &text[..], &b"</content>"[..]].concat();

    let mut reader = Reader::from_reader(quick_xml::utils::Fountain {
        chunk: &chunk,
        consumed: 0,
        overall_read: 0,
    });
    let mut buf = Vec::new();
    let (mut starts, mut ends, mut texts) = (0u64, 0u64, 0u64);
    loop {
        buf.clear();
        let event = match reader.read_event_into_async(&mut buf).await {
            Ok(event) => event,
            Err(e) => panic!("Error at position {}: {:?}", reader.error_position(), e),
        };
        match event {
            Eof => break,

            Start(e) => {
                starts += 1;
                assert_eq!(
                    e.name(),
                    QName(b"content"),
                    "starts: {starts}, ends: {ends}, texts: {texts}"
                );
            }
            End(e) => {
                ends += 1;
                assert_eq!(
                    e.name(),
                    QName(b"content"),
                    "starts: {starts}, ends: {ends}, texts: {texts}"
                );
            }
            Text(e) => {
                texts += 1;
                assert_eq!(
                    e.as_ref(),
                    text,
                    "starts: {starts}, ends: {ends}, texts: {texts}"
                );
            }
            _ => (),
        }
        // The test is passed as soon as at least `u32::MAX` bytes were read successfully
        if reader.get_ref().overall_read >= u32::MAX as u64 {
            break;
        }
    }
}
61 changes: 61 additions & 0 deletions tests/issues.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//!
//! Name each module / test as `issue<GH number>` and keep sorted by issue number
use std::iter;
use std::sync::mpsc;

use quick_xml::errors::{Error, IllFormedError, SyntaxError};
Expand Down Expand Up @@ -276,3 +277,63 @@ fn issue706() {
}
}
}

/// Regression test for https://github.com/tafia/quick-xml/issues/751
#[test]
fn issue751() {
    // Text content of a single `<content>` element: the phrase repeated 1000 times
    let text: Vec<u8> = iter::repeat(&b"some text inside"[..])
        .take(1000)
        .flatten()
        .copied()
        .collect();
    // One complete element; the fountain replays it without end
    let chunk = [&b"<content>"[..], &text[..], &b"</content>"[..]].concat();

    let mut reader = Reader::from_reader(quick_xml::utils::Fountain {
        chunk: &chunk,
        consumed: 0,
        overall_read: 0,
    });
    let mut buf = Vec::new();
    let (mut starts, mut ends, mut texts) = (0u64, 0u64, 0u64);
    loop {
        buf.clear();
        let event = match reader.read_event_into(&mut buf) {
            Ok(event) => event,
            Err(e) => panic!("Error at position {}: {:?}", reader.error_position(), e),
        };
        match event {
            Event::Eof => break,

            Event::Start(e) => {
                starts += 1;
                assert_eq!(
                    e.name(),
                    QName(b"content"),
                    "starts: {starts}, ends: {ends}, texts: {texts}"
                );
            }
            Event::End(e) => {
                ends += 1;
                assert_eq!(
                    e.name(),
                    QName(b"content"),
                    "starts: {starts}, ends: {ends}, texts: {texts}"
                );
            }
            Event::Text(e) => {
                texts += 1;
                assert_eq!(
                    e.as_ref(),
                    text,
                    "starts: {starts}, ends: {ends}, texts: {texts}"
                );
            }
            _ => (),
        }
        // The test is passed as soon as at least `u32::MAX` bytes were read successfully
        if reader.get_ref().overall_read >= u32::MAX as u64 {
            break;
        }
    }
}

0 comments on commit 7f92129

Please sign in to comment.