From 0458d30b4bf078251ac4c4dfc4669e8beec8bc3b Mon Sep 17 00:00:00 2001 From: Max Norfolk <66913041+mnorfolk03@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:39:40 -0500 Subject: [PATCH] fix: CSV Infer Schema now properly supports escaped characters. (#13214) --- .../core/src/datasource/file_format/csv.rs | 56 ++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 2aaef2cda1c8..0335c8aa3ff6 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -454,7 +454,12 @@ impl CsvFormat { .has_header .unwrap_or(state.config_options().catalog.has_header), ) - .with_delimiter(self.options.delimiter); + .with_delimiter(self.options.delimiter) + .with_quote(self.options.quote); + + if let Some(escape) = self.options.escape { + format = format.with_escape(escape); + } if let Some(comment) = self.options.comment { format = format.with_comment(comment); @@ -867,6 +872,55 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_infer_schema_escape_chars() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let variable_object_store = Arc::new(VariableStream::new( + Bytes::from( + r#"c1,c2,c3,c4 +0.3,"Here, is a comma\"",third,3 +0.31,"double quotes are ok, "" quote",third again,9 +0.314,abc,xyz,27"#, + ), + 1, + )); + let object_meta = ObjectMeta { + location: Path::parse("/")?, + last_modified: DateTime::default(), + size: usize::MAX, + e_tag: None, + version: None, + }; + + let num_rows_to_read = 3; + let csv_format = CsvFormat::default() + .with_has_header(true) + .with_schema_infer_max_rec(num_rows_to_read) + .with_quote(b'"') + .with_escape(Some(b'\\')); + + let inferred_schema = csv_format + .infer_schema( + &state, + &(variable_object_store.clone() as Arc), + &[object_meta], + ) + .await?; + + let actual_fields: Vec<_> = inferred_schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + + assert_eq!( + vec!["c1: Float64", "c2: Utf8", "c3: Utf8", "c4: Int64",], + actual_fields + ); + Ok(()) + } + #[rstest( file_compression_type, case(FileCompressionType::UNCOMPRESSED),