diff --git a/crates/polars-lazy/src/frame/ndjson.rs b/crates/polars-lazy/src/frame/ndjson.rs index 45a5edbe4656..8a584ba12876 100644 --- a/crates/polars-lazy/src/frame/ndjson.rs +++ b/crates/polars-lazy/src/frame/ndjson.rs @@ -45,6 +45,7 @@ impl LazyJsonLineReader { } /// Set the number of rows to use when inferring the json schema. /// the default is 100 rows. + /// Ignored when the schema is specified explicitly using [`Self::with_schema`]. /// Setting to `None` will do a full table scan, very slow. #[must_use] pub fn with_infer_schema_length(mut self, num_rows: Option) -> Self { diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs index a6f6cedda2f7..e0bd5caf59d8 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs @@ -20,6 +20,11 @@ impl AnonymousScan for LazyJsonLineReader { } fn schema(&self, infer_schema_length: Option) -> PolarsResult { + // Short-circuit schema inference if the schema has been explicitly provided. + if let Some(schema) = &self.schema { + return Ok(schema.clone()); + } + let f = polars_utils::open_file(&self.path)?; let mut reader = std::io::BufReader::new(f); diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index 1f0e05e7dbb9..cb0d4802120a 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -66,6 +66,7 @@ def scan_ndjson( rechunk: bool = True, row_count_name: str | None = None, row_count_offset: int = 0, + schema: SchemaDefinition | None = None, ) -> LazyFrame: """ Lazily read from a newline delimited JSON file or multiple files via glob patterns. @@ -92,6 +93,16 @@ def scan_ndjson( DataFrame row_count_offset Offset to start the row_count column (only use if the name is set) + schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict + The DataFrame schema may be declared in several ways: + + * As a dict of {name:type} pairs; if type is None, it will be auto-inferred. + * As a list of column names; in this case types are automatically inferred. + * As a list of (name,type) pairs; this is equivalent to the dictionary form. + + If you supply a list of column names that does not match the names in the + underlying data, the names given here will overwrite them. The number + of names given in the schema should match the underlying data dimensions. """ if isinstance(source, (str, Path)): @@ -100,6 +111,7 @@ def scan_ndjson( return pl.LazyFrame._scan_ndjson( source, infer_schema_length=infer_schema_length, + schema=schema, batch_size=batch_size, n_rows=n_rows, low_memory=low_memory, diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index e7b3757a5a10..138f0244b587 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -482,6 +482,7 @@ def _scan_ndjson( source: str, *, infer_schema_length: int | None = None, + schema: SchemaDefinition | None = None, batch_size: int | None = None, n_rows: int | None = None, low_memory: bool = False, @@ -503,6 +504,7 @@ def _scan_ndjson( self._ldf = PyLazyFrame.new_from_ndjson( source, infer_schema_length, + schema, batch_size, n_rows, low_memory, diff --git a/py-polars/src/lazyframe.rs b/py-polars/src/lazyframe.rs index 68daeba69cfe..719a38989f61 100644 --- a/py-polars/src/lazyframe.rs +++ b/py-polars/src/lazyframe.rs @@ -119,10 +119,11 @@ impl PyLazyFrame { #[staticmethod] #[cfg(feature = "json")] #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (path, infer_schema_length, batch_size, n_rows, low_memory, rechunk, row_count))] + #[pyo3(signature = (path, infer_schema_length, schema, batch_size, n_rows, low_memory, rechunk, row_count))] fn new_from_ndjson( path: String, infer_schema_length: Option, + schema: Option>, batch_size: Option, n_rows: Option, low_memory: bool, @@ -131,15 +132,17 @@ impl PyLazyFrame { ) -> PyResult { let row_count = row_count.map(|(name, offset)| RowCount { name, offset }); - let lf = LazyJsonLineReader::new(path) + let mut lf = LazyJsonLineReader::new(path) .with_infer_schema_length(infer_schema_length) .with_batch_size(batch_size) .with_n_rows(n_rows) .low_memory(low_memory) .with_rechunk(rechunk) - .with_row_count(row_count) - .finish() - .map_err(PyPolarsErr::from)?; + .with_row_count(row_count); + if let Some(schema) = schema { + lf = lf.with_schema(schema.0); + } + let lf = lf.finish().map_err(PyPolarsErr::from)?; Ok(lf.into()) } diff --git a/py-polars/tests/unit/io/test_lazy_json.py b/py-polars/tests/unit/io/test_lazy_json.py index 70c3b7188697..4b3907fc82ba 100644 --- a/py-polars/tests/unit/io/test_lazy_json.py +++ b/py-polars/tests/unit/io/test_lazy_json.py @@ -38,6 +38,24 @@ def test_scan_ndjson(foods_ndjson_path: Path) -> None: assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35] +def test_scan_ndjson_with_schema(foods_ndjson_path: Path) -> None: + schema = { + "category": pl.Categorical, + "calories": pl.Int64, + "fats_g": pl.Float64, + "sugars_g": pl.Int64, + } + df = pl.scan_ndjson(foods_ndjson_path, schema=schema).collect() + assert df["category"].dtype == pl.Categorical + assert df["calories"].dtype == pl.Int64 + assert df["fats_g"].dtype == pl.Float64 + assert df["sugars_g"].dtype == pl.Int64 + + schema["sugars_g"] = pl.Float64 + df = pl.scan_ndjson(foods_ndjson_path, schema=schema).collect() + assert df["sugars_g"].dtype == pl.Float64 + + @pytest.mark.write_disk() def test_scan_with_projection(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True)