Skip to content

Commit

Permalink
fix(rust, python): reject multithreading on excessive ',\n' fields (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Feb 15, 2023
1 parent f55cf83 commit 2c5e079
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 34 deletions.
67 changes: 51 additions & 16 deletions polars/polars-io/src/csv/parser.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use memchr::memchr2_iter;
use num::traits::Pow;
use polars_core::prelude::*;

Expand Down Expand Up @@ -32,12 +33,36 @@ pub(crate) fn next_line_position(
quote_char: Option<u8>,
eol_char: u8,
) -> Option<usize> {
/// Decide whether `line` looks like a genuine CSV record.
///
/// A line is accepted when it splits into exactly `expected_fields`
/// fields and no single field packs an implausible number of
/// delimiter/eol bytes (which would indicate we landed inside a large
/// quoted multiline field rather than on a real record boundary).
fn accept_line(
    line: &[u8],
    expected_fields: usize,
    delimiter: u8,
    eol_char: u8,
    quote_char: Option<u8>,
) -> bool {
    // Single pass: count fields while vetoing any field whose embedded
    // separator count reaches the schema width. `try_fold` short-circuits
    // to `None` on the first offending field, mirroring an early return.
    SplitFields::new(line, delimiter, quote_char, eol_char)
        .try_fold(0usize, |n_fields, (field, _)| {
            let separators = memchr2_iter(delimiter, eol_char, field).count();
            (separators < expected_fields).then_some(n_fields + 1)
        })
        == Some(expected_fields)
}

// we check 3 subsequent lines for `accept_line` before we accept
// if 3 groups are rejected we reject completely
let mut rejected_line_groups = 0u8;

let mut total_pos = 0;
if input.is_empty() {
return None;
}
let mut lines_checked = 0u16;
loop {
if rejected_line_groups >= 3 {
return None;
}
lines_checked += 1;
// headers might have an extra value
// So if we have churned through enough lines
Expand All @@ -53,29 +78,39 @@ pub(crate) fn next_line_position(
}
debug_assert!(pos <= input.len());
let new_input = unsafe { input.get_unchecked(pos..) };
let line = SplitLines::new(new_input, quote_char.unwrap_or(b'"'), eol_char).next();

let count_fields =
|line: &[u8]| SplitFields::new(line, delimiter, quote_char, eol_char).count();
let mut lines = SplitLines::new(new_input, quote_char.unwrap_or(b'"'), eol_char);
let line = lines.next();

match (line, expected_fields) {
// count the fields, and determine if they are equal to what we expect from the schema
(Some(line), Some(expected_fields)) if { count_fields(line) == expected_fields } => {
return Some(total_pos + pos)
}
(Some(_), Some(_)) => {
debug_assert!(pos < input.len());
unsafe {
input = input.get_unchecked(pos + 1..);
(Some(line), Some(expected_fields)) => {
if accept_line(line, expected_fields, delimiter, eol_char, quote_char) {
let mut valid = true;
for line in lines.take(2) {
if !accept_line(line, expected_fields, delimiter, eol_char, quote_char) {
valid = false;
break;
}
}
if valid {
return Some(total_pos + pos);
} else {
rejected_line_groups += 1;
}
} else {
debug_assert!(pos < input.len());
unsafe {
input = input.get_unchecked(pos + 1..);
}
total_pos += pos + 1;
}
total_pos += pos + 1;
}
// don't count the fields
(Some(_), None) => return Some(total_pos + pos),
// no new line found, check latest line (without eol) for number of fields
(None, Some(expected_fields)) if { count_fields(new_input) == expected_fields } => {
return Some(total_pos + pos)
}
// // no new line found, check latest line (without eol) for number of fields
// (None, Some(expected_fields)) if { count_fields(new_input) == expected_fields } => {
// return Some(total_pos + pos)
// }
_ => return None,
}
}
Expand Down
35 changes: 17 additions & 18 deletions polars/polars-io/src/csv/read_impl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,30 +408,29 @@ impl<'a> CoreReader<'a> {

let chunk_size = std::cmp::min(self.chunk_size, total_rows);
let n_chunks = total_rows / chunk_size;
if logging {
eprintln!(
"no. of chunks: {n_chunks} processed by: {n_threads} threads at 1 chunk/thread",
);
}

let n_file_chunks = if streaming { n_chunks } else { *n_threads };

// split the file by the nearest new line characters such that every thread processes
// approximately the same number of rows.
Ok((
get_file_chunks(
bytes,
n_file_chunks,
self.schema.len(),
self.delimiter,
self.quote_char,
self.eol_char,
),
chunk_size,
total_rows,
starting_point_offset,

let chunks = get_file_chunks(
bytes,
))
n_file_chunks,
self.schema.len(),
self.delimiter,
self.quote_char,
self.eol_char,
);

if logging {
eprintln!(
"no. of chunks: {} processed by: {n_threads} threads.",
chunks.len()
);
}

Ok((chunks, chunk_size, total_rows, starting_point_offset, bytes))
}

fn get_projection(&mut self) -> Vec<usize> {
Expand Down
28 changes: 28 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import date, datetime, time, timedelta, timezone
from pathlib import Path

import numpy as np
import pytest

import polars as pl
Expand Down Expand Up @@ -1168,3 +1169,30 @@ def test_read_web_file() -> None:
url = "https://raw.githubusercontent.com/pola-rs/polars/master/examples/datasets/foods1.csv"
df = pl.read_csv(url)
assert df.shape == (27, 4)


@pytest.mark.slow()
def test_csv_multiline_splits() -> None:
    # Build a pathological CSV whose third column is one quoted field
    # spanning thousands of embedded lines. Chunk-splitting for
    # multi-threaded reads cannot find safe row boundaries in such a
    # file, so polars must fall back to sequential parsing and still
    # report the correct shape.

    np.random.seed(0)
    f = io.BytesIO()

    def some_multiline_str(n: int) -> str:
        # Each embedded line carries 0-4 separators (randint samples
        # [0, 5)), making many fake "rows" inside the quoted field look
        # like plausible records to a naive chunk splitter.
        parts = ['"']
        parts.extend("xx," * length for length in np.random.randint(0, 5, n))
        parts.append('"')
        return "\n".join(parts)

    for _ in range(4):
        row = f"field1,field2,{some_multiline_str(5000)}\n"
        f.write(row.encode())

    f.seek(0)
    assert pl.read_csv(f, has_header=False).shape == (4, 3)

0 comments on commit 2c5e079

Please sign in to comment.