diff --git a/examples/csv_read.rs b/examples/csv_read.rs index b57972e10c6..988110c45fa 100644 --- a/examples/csv_read.rs +++ b/examples/csv_read.rs @@ -12,7 +12,7 @@ fn read_path(path: &str, projection: Option<&[usize]>) -> Result Result<()> { let mut reader = AsyncReaderBuilder::new().create_reader(file); - let fields = infer_schema(&mut reader, None, true, &infer).await?; + let (fields, _) = infer_schema(&mut reader, None, true, &infer).await?; let mut rows = vec![ByteRecord::default(); 100]; let rows_read = read_rows(&mut reader, 0, &mut rows).await?; diff --git a/examples/csv_read_parallel.rs b/examples/csv_read_parallel.rs index 7a46650cab3..fd71f081cfb 100644 --- a/examples/csv_read_parallel.rs +++ b/examples/csv_read_parallel.rs @@ -17,7 +17,8 @@ fn parallel_read(path: &str) -> Result>>> { let (tx, rx) = unbounded(); let mut reader = read::ReaderBuilder::new().from_path(path)?; - let fields = read::infer_schema(&mut reader, Some(batch_size * 10), has_header, &read::infer)?; + let (fields, _) = + read::infer_schema(&mut reader, Some(batch_size * 10), has_header, &read::infer)?; let fields = Arc::new(fields); let start = SystemTime::now(); diff --git a/src/io/csv/read/infer_schema.rs b/src/io/csv/read/infer_schema.rs index 93ac5f8b1ef..83462bc1e8d 100644 --- a/src/io/csv/read/infer_schema.rs +++ b/src/io/csv/read/infer_schema.rs @@ -10,13 +10,14 @@ use super::super::utils::merge_schema; use super::{ByteRecord, Reader}; /// Infers the [`Field`]s of a CSV file by reading through the first n records up to `max_rows`. +/// Also returns the number of rows used to infer. 
/// Seeks back to the beginning of the file _after_ the header pub fn infer_schema DataType>( reader: &mut Reader, max_rows: Option, has_header: bool, infer: &F, -) -> Result> { +) -> Result<(Vec, usize)> { // get or create header names // when has_header is false, creates default column names with column_ prefix let headers: Vec = if has_header { @@ -57,5 +58,5 @@ pub fn infer_schema DataType>( // return the reader seek back to the start reader.seek(position)?; - Ok(fields) + Ok((fields, records_count)) } diff --git a/src/io/csv/read_async/infer_schema.rs b/src/io/csv/read_async/infer_schema.rs index 506f02dc87c..b22f2a613d3 100644 --- a/src/io/csv/read_async/infer_schema.rs +++ b/src/io/csv/read_async/infer_schema.rs @@ -15,7 +15,7 @@ pub async fn infer_schema( max_rows: Option, has_header: bool, infer: &F, -) -> Result> +) -> Result<(Vec, usize)> where R: AsyncRead + AsyncSeek + Unpin + Send + Sync, F: Fn(&[u8]) -> DataType, @@ -65,5 +65,5 @@ where // return the reader seek back to the start reader.seek(position).await?; - Ok(fields) + Ok((fields, records_count)) } diff --git a/tests/it/io/csv/read.rs b/tests/it/io/csv/read.rs index 90d4cc5eee3..5bf193c1496 100644 --- a/tests/it/io/csv/read.rs +++ b/tests/it/io/csv/read.rs @@ -27,7 +27,7 @@ fn read() -> Result<()> { "Aberdeen, Aberdeen City, UK",57.149651,-2.099075"#; let mut reader = ReaderBuilder::new().from_reader(Cursor::new(data)); - let fields = infer_schema(&mut reader, None, true, &infer)?; + let (fields, _) = infer_schema(&mut reader, None, true, &infer)?; let mut rows = vec![ByteRecord::default(); 100]; let rows_read = read_rows(&mut reader, 0, &mut rows)?; @@ -58,7 +58,7 @@ fn infer_basics() -> Result<()> { let file = Cursor::new("1,2,3\na,b,c\na,,c"); let mut reader = ReaderBuilder::new().from_reader(file); - let fields = infer_schema(&mut reader, Some(10), false, &infer)?; + let (fields, _) = infer_schema(&mut reader, Some(10), false, &infer)?; assert_eq!( fields, @@ -76,7 +76,7 @@ fn infer_ints() ->
Result<()> { let file = Cursor::new("1,2,3\n1,a,5\n2,,4"); let mut reader = ReaderBuilder::new().from_reader(file); - let fields = infer_schema(&mut reader, Some(10), false, &infer)?; + let (fields, _) = infer_schema(&mut reader, Some(10), false, &infer)?; assert_eq!( fields, diff --git a/tests/it/io/csv/read_async.rs b/tests/it/io/csv/read_async.rs index 4f4a51d95d8..a31319cc117 100644 --- a/tests/it/io/csv/read_async.rs +++ b/tests/it/io/csv/read_async.rs @@ -23,7 +23,7 @@ async fn read() -> Result<()> { "Aberdeen, Aberdeen City, UK",57.149651,-2.099075"#; let mut reader = AsyncReaderBuilder::new().create_reader(Cursor::new(data.as_bytes())); - let fields = infer_schema(&mut reader, None, true, &infer).await?; + let (fields, _) = infer_schema(&mut reader, None, true, &infer).await?; let mut rows = vec![ByteRecord::default(); 100]; let rows_read = read_rows(&mut reader, 0, &mut rows).await?;