Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CSV infinite loop and improve error messages #3470

Merged
merged 2 commits into from
Jan 6, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 63 additions & 20 deletions arrow-csv/src/reader/records.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct RecordReader<R> {

num_columns: usize,

num_rows: usize,
line_number: usize,
offsets: Vec<usize>,
data: Vec<u8>,
}
Expand All @@ -42,19 +42,21 @@ impl<R: BufRead> RecordReader<R> {
reader,
delimiter,
num_columns,
num_rows: 0,
line_number: 1,
offsets: vec![],
data: vec![],
}
}

fn fill_buf(&mut self, to_read: usize) -> Result<(), ArrowError> {
/// Clears and then fills the buffers on this [`RecordReader`]
/// returning the number of records read
fn fill_buf(&mut self, to_read: usize) -> Result<usize, ArrowError> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reworked this to return the quantity read as I thought this made for a slightly clearer API

// Reserve sufficient capacity in offsets
self.offsets.resize(to_read * self.num_columns + 1, 0);
self.num_rows = 0;

let mut read = 0;
if to_read == 0 {
return Ok(());
return Ok(0);
}

// The current offset into `self.data`
Expand All @@ -71,7 +73,7 @@ impl<R: BufRead> RecordReader<R> {

'input: loop {
// Reserve necessary space in output data based on best estimate
let remaining_rows = to_read - self.num_rows;
let remaining_rows = to_read - read;
let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE;
let estimated_data = capacity.max(MIN_CAPACITY);
self.data.resize(output_offset + estimated_data, 0);
Expand All @@ -94,24 +96,26 @@ impl<R: BufRead> RecordReader<R> {
ReadRecordResult::InputEmpty => break 'input, // Input exhausted, need to read more
ReadRecordResult::OutputFull => break, // Need to allocate more capacity
ReadRecordResult::OutputEndsFull => {
return Err(ArrowError::CsvError(format!("incorrect number of fields, expected {} got more than {}", self.num_columns, field_count)))
let line_number = self.line_number + read;
return Err(ArrowError::CsvError(format!("incorrect number of fields for line {}, expected {} got more than {}", line_number, self.num_columns, field_count)));
}
ReadRecordResult::Record => {
if field_count != self.num_columns {
return Err(ArrowError::CsvError(format!("incorrect number of fields, expected {} got {}", self.num_columns, field_count)))
let line_number = self.line_number + read;
return Err(ArrowError::CsvError(format!("incorrect number of fields for line {}, expected {} got {}", line_number, self.num_columns, field_count)));
}
self.num_rows += 1;
read += 1;
field_count = 0;

if self.num_rows == to_read {
break 'outer // Read sufficient rows
if read == to_read {
break 'outer; // Read sufficient rows
}

if input.len() == input_offset {
// Input exhausted, need to read more
// Without this read_record will interpret the empty input
// byte array as indicating the end of the file
break 'input
break 'input;
}
}
}
Expand All @@ -135,28 +139,38 @@ impl<R: BufRead> RecordReader<R> {
});
});

Ok(())
self.line_number += read;

Ok(read)
}

/// Skips forward `to_skip` rows
pub fn skip(&mut self, mut to_skip: usize) -> Result<(), ArrowError> {
pub fn skip(&mut self, to_skip: usize) -> Result<(), ArrowError> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can document the error behavior.

// TODO: This could be done by scanning for unquoted newline delimiters
while to_skip != 0 {
self.fill_buf(to_skip.min(1024))?;
to_skip -= self.num_rows;
let mut skipped = 0;
while to_skip > skipped {
let read = self.fill_buf(to_skip.min(1024))?;
if read == 0 {
return Err(ArrowError::CsvError(format!(
"Failed to skip {} rows only found {}",
to_skip, skipped
)));
}

skipped += read;
}
Ok(())
}

/// Reads up to `to_read` rows from the reader
pub fn read(&mut self, to_read: usize) -> Result<StringRecords<'_>, ArrowError> {
self.fill_buf(to_read)?;
let num_rows = self.fill_buf(to_read)?;

// Need to slice fields to the actual number of rows read
//
// We intentionally avoid using `Vec::truncate` to avoid having
// to re-initialize the data again
let num_fields = self.num_rows * self.num_columns;
let num_fields = num_rows * self.num_columns;
let last_offset = self.offsets[num_fields];

// Need to truncate data to the actual amount of data read
Expand All @@ -165,8 +179,8 @@ impl<R: BufRead> RecordReader<R> {
})?;

Ok(StringRecords {
num_rows,
num_columns: self.num_columns,
num_rows: self.num_rows,
offsets: &self.offsets[..num_fields + 1],
data,
})
Expand Down Expand Up @@ -263,4 +277,33 @@ mod tests {
})
}
}

#[test]
fn test_invalid_fields() {
let csv = "a,b\nb,c\na\n";
let cursor = Cursor::new(csv.as_bytes());
let mut reader = RecordReader::new(cursor, Reader::new(), 2);
let err = reader.read(4).unwrap_err().to_string();

let expected =
"Csv error: incorrect number of fields for line 3, expected 2 got 1";

assert_eq!(err, expected);

// Test with initial skip
let cursor = Cursor::new(csv.as_bytes());
let mut reader = RecordReader::new(cursor, Reader::new(), 2);
reader.skip(1).unwrap();
let err = reader.read(4).unwrap_err().to_string();
assert_eq!(err, expected);
}

#[test]
fn test_skip_insufficient_rows() {
let csv = "a\nv\n";
let cursor = Cursor::new(csv.as_bytes());
let mut reader = RecordReader::new(cursor, Reader::new(), 1);
let err = reader.skip(3).unwrap_err().to_string();
assert_eq!(err, "Csv error: Failed to skip 3 rows only found 2");
}
}