From 6e1fc2b4ab22d18949436b27d9d7efb4df3c8053 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 29 Jan 2022 19:20:22 -0800 Subject: [PATCH 1/3] Draft PyArrowDataset reader impl --- Cargo.lock | 2 + Cargo.toml | 15 +- src/context.rs | 13 +- src/dataset.rs | 427 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 5 files changed, 452 insertions(+), 6 deletions(-) create mode 100644 src/dataset.rs diff --git a/Cargo.lock b/Cargo.lock index 654dedd..ddc46f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -314,7 +314,9 @@ dependencies = [ name = "datafusion-python" version = "0.4.0" dependencies = [ + "async-trait", "datafusion", + "futures", "pyo3", "rand 0.7.3", "tokio", diff --git a/Cargo.toml b/Cargo.toml index aa16236..e855d41 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,11 +28,22 @@ edition = "2021" rust-version = "1.57" [dependencies] -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } +tokio = { version = "1.0", features = [ + "macros", + "rt", + "rt-multi-thread", + "sync", +] } rand = "0.7" -pyo3 = { version = "0.14", features = ["extension-module", "abi3", "abi3-py36"] } +pyo3 = { version = "0.14", features = [ + "extension-module", + "abi3", + "abi3-py36", +] } datafusion = { version = "6.0.0", features = ["pyarrow"] } uuid = { version = "0.8", features = ["v4"] } +async-trait = "0.1.41" +futures = "0.3" [lib] name = "_internal" diff --git a/src/context.rs b/src/context.rs index 7f386ba..0672d35 100644 --- a/src/context.rs +++ b/src/context.rs @@ -31,6 +31,7 @@ use datafusion::prelude::CsvReadOptions; use crate::catalog::PyCatalog; use crate::dataframe::PyDataFrame; +use crate::dataset::PyArrowDatasetTable; use crate::errors::DataFusionError; use crate::udf::PyScalarUDF; use crate::utils::wait_for_future; @@ -60,10 +61,7 @@ impl PyExecutionContext { Ok(PyDataFrame::new(df)) } - fn create_dataframe( - &mut self, - partitions: Vec>, - ) -> PyResult { + fn create_dataframe(&mut self, partitions: Vec>) -> PyResult { let table = MemTable::try_new(partitions[0][0].schema(), partitions) .map_err(DataFusionError::from)?; @@ -143,6 +141,13 @@ impl PyExecutionContext { Ok(()) } + fn register_dataset(&mut self, name: &str, dataset: PyArrowDatasetTable) -> PyResult<()> { + self.ctx + .register_table(name, Arc::new(dataset)) + .map_err(DataFusionError::from)?; + Ok(()) + } + fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> { self.ctx.register_udf(udf.function); Ok(()) diff --git a/src/dataset.rs b/src/dataset.rs new file mode 100644 index 0000000..b6332ea --- /dev/null +++ b/src/dataset.rs @@ -0,0 +1,427 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+use async_trait::async_trait; +use datafusion::arrow::datatypes::Schema; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::error::{ArrowError, Result as ArrowResult}; +use datafusion::arrow::pyarrow::PyArrowConvert; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::datasource::TableProviderFilterPushDown; +use datafusion::datasource::TableProvider; +use datafusion::error::DataFusionError; +use datafusion::error::Result; +use datafusion::logical_plan::Expr; +use datafusion::logical_plan::Operator; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::stream::RecordBatchReceiverStream; +use datafusion::physical_plan::{ + DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, +}; +use datafusion::scalar::ScalarValue::*; +use pyo3::conversion::ToPyObject; +use pyo3::exceptions::{PyAssertionError, PyNotImplementedError, PyStopIteration}; +use pyo3::prelude::*; +use pyo3::types::PyDict; +use std::any::Any; +use std::fmt; +use std::sync::Arc; +use tokio::{ + sync::mpsc::{channel, Receiver, Sender}, + task, +}; + +pub struct PyArrowDatasetTable { + dataset: Py, + schema: SchemaRef, +} + +impl<'py> FromPyObject<'py> for PyArrowDatasetTable { + fn extract(ob: &'py PyAny) -> PyResult { + // Check it's a PyArrow dataset + // "pyarrow.dataset.FileSystemDataset" + + let dataset: Py = ob.extract()?; + let schema = Python::with_gil(|py| -> PyResult { + Schema::from_pyarrow(dataset.getattr(py, "schema")?.as_ref(py)) + })?; + + Ok(PyArrowDatasetTable { + dataset, + schema: Arc::new(schema), + }) + } +} + +#[async_trait] +impl TableProvider for PyArrowDatasetTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + async fn scan( + &self, + projection: &Option>, + batch_size: usize, + filters: &[Expr], + limit: Option, + ) -> Result> { + let scanner = Python::with_gil(|py| -> PyResult> { + let scanner_kwargs = PyDict::new(py); + scanner_kwargs.set_item("batch_size", batch_size)?; + + let combined_filter = filters + .iter() + .map(|f| f.clone()) + .reduce(|acc, item| acc.and(item)); + if let Some(expr) = combined_filter { + scanner_kwargs.set_item("filter", expr_to_pyarrow(&expr)?)?; + }; + + if let Some(indices) = projection { + let column_names: Vec = self + .schema + .project(indices)? + .fields() + .iter() + .map(|field| field.name().clone()) + .collect(); + scanner_kwargs.set_item("columns", column_names)?; + } + + Ok(self + .dataset + .call_method(py, "scanner", (), Some(scanner_kwargs))? + .extract(py)?) 
+        });
+        match scanner {
+            Ok(scanner) => Ok(Arc::new(PyArrowDatasetExec {
+                scanner: PyArrowDatasetScanner {
+                    scanner: Arc::new(scanner),
+                    limit,
+                    schema: self.schema.clone(),
+                },
+                projected_statistics: Statistics::default(),
+                metrics: ExecutionPlanMetricsSet::new(),
+            })),
+            Err(err) => Err(DataFusionError::Execution(err.to_string())),
+        }
+    }
+
+    fn supports_filter_pushdown(&self, _: &Expr) -> Result<TableProviderFilterPushDown> {
+        Ok(TableProviderFilterPushDown::Exact)
+    }
+}
+
+pub struct PyArrowDatasetScanner {
+    scanner: Arc<Py<PyAny>>,
+    limit: Option<usize>,
+    schema: SchemaRef,
+}
+
+impl fmt::Debug for PyArrowDatasetScanner {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("PyArrowDatasetScanner")
+            .field("scanner", &"pyarrow.dataset.Scanner")
+            .field("limit", &self.limit)
+            .field("schema", &self.schema)
+            .finish()
+    }
+}
+
+impl Clone for PyArrowDatasetScanner {
+    fn clone(&self) -> Self {
+        PyArrowDatasetScanner {
+            // TODO: Is this a bad way to clone?
+            scanner: self.scanner.clone(),
+            limit: self.limit.clone(),
+            schema: self.schema.clone(),
+        }
+    }
+}
+
+impl PyArrowDatasetScanner {
+    fn projected_schema(self) -> SchemaRef {
+        self.schema
+    }
+
+    fn get_batches(&self, response_tx: Sender<ArrowResult<RecordBatch>>) -> Result<()> {
+        let mut count = 0;
+
+        loop {
+            // TODO: Avoid Python GIL with Arrow C Stream interface?
+            // https://arrow.apache.org/docs/dev/format/CStreamInterface.html
+            // https://github.com/apache/arrow/blob/cc4e2a54309813e6bbbb36ba50bcd22a7b71d3d9/python/pyarrow/ipc.pxi#L620
+            let res = Python::with_gil(|py| -> PyResult<Option<RecordBatch>> {
+                let batch_iter = self.scanner.call_method0(py, "to_batches")?;
+                let py_batch_res = batch_iter.call_method0(py, "__next__");
+                match py_batch_res {
+                    Ok(py_batch) => Ok(Some(RecordBatch::from_pyarrow(py_batch.extract(py)?)?)),
+                    Err(error) if error.is_instance::<PyStopIteration>(py) => Ok(None),
+                    Err(error) => Err(error),
+                }
+            });
+
+            match (self.limit, res) {
+                (Some(limit), Ok(Some(batch))) => {
+                    // Handle limit parameter by stopping iterator early
+                    let next_total = count + batch.num_rows();
+                    if next_total == limit {
+                        send_result(&response_tx, Ok(batch))?;
+                        break;
+                    } else if next_total < limit {
+                        count += batch.num_rows();
+                        send_result(&response_tx, Ok(batch))?;
+                    } else {
+                        send_result(&response_tx, Ok(batch.slice(0, limit - count)))?;
+                        count = limit;
+                        break;
+                    }
+                }
+                (None, Ok(Some(batch))) => {
+                    count += batch.num_rows();
+                    send_result(&response_tx, Ok(batch))?;
+                }
+                (_, Ok(None)) => {
+                    break;
+                }
+                (_, Err(err)) => {
+                    send_result(&response_tx, Err(ArrowError::IoError(err.to_string())))?;
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
+
+fn send_result(
+    response_tx: &Sender<ArrowResult<RecordBatch>>,
+    result: ArrowResult<RecordBatch>,
+) -> Result<()> {
+    // Note this function is running on its own blocking tokio thread so blocking here is ok.
+ response_tx + .blocking_send(result) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + Ok(()) +} + +/// Execution plan for scanning a PyArrow dataset +#[derive(Debug, Clone)] +pub struct PyArrowDatasetExec { + scanner: PyArrowDatasetScanner, + projected_statistics: Statistics, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +#[async_trait] +impl ExecutionPlan for PyArrowDatasetExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.scanner.clone().projected_schema().clone() + } + + fn children(&self) -> Vec> { + // this is a leaf node and has no children + vec![] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + fn with_new_children( + &self, + children: Vec>, + ) -> Result> { + if children.is_empty() { + Ok(Arc::new(self.clone())) + } else { + Err(DataFusionError::Internal(format!( + "Children cannot be replaced in {:?}", + self + ))) + } + } + + async fn execute(&self, _partition_index: usize) -> Result { + let (response_tx, response_rx): ( + Sender>, + Receiver>, + ) = channel(2); + + let cloned = self.scanner.clone(); + + let join_handle = task::spawn_blocking(move || { + if let Err(e) = cloned.get_batches(response_tx) { + println!("Dataset scanner thread terminated due to error: {:?}", e); + } + }); + + Ok(RecordBatchReceiverStream::create( + &self.scanner.clone().projected_schema(), + response_rx, + join_handle, + )) + } + + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + write!( + f, + // TODO: better fmt + "PyArrowDatasetExec: limit={:?}, partitions=...", + self.scanner.limit + ) + } + } + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Statistics { + self.projected_statistics.clone() + } +} + +// TODO: replace with impl PyArrowConvert for Expr +// https://github.com/apache/arrow-rs/blob/master/arrow/src/pyarrow.rs +fn expr_to_pyarrow(expr: &Expr) -> PyResult { + Python::with_gil(|py| -> PyResult { + let ds = PyModule::import(py, "pyarrow.dataset")?; + let field = ds.getattr("field")?; + + let mut worklist: Vec<&Expr> = Vec::new(); // Expressions to parse + let mut result_list: Vec = Vec::new(); // Expressions that have been parsed + worklist.push(expr); + + while let Some(parent) = worklist.pop() { + match parent { + Expr::Column(col) => { + result_list.push(field.call1((col.name.clone(),))?.into()); + } + // TODO: finish implementing PyArrowConvert for ScalarValue? 
+ // https://github.com/apache/arrow-datafusion/blob/master/datafusion/src/pyarrow.rs + Expr::Literal(scalar) => { + match scalar { + Boolean(val) => { + result_list.push(val.to_object(py)); + } + Float32(val) => { + result_list.push(val.to_object(py)); + } + Float64(val) => { + result_list.push(val.to_object(py)); + } + Int8(val) => { + result_list.push(val.to_object(py)); + } + Int16(val) => { + result_list.push(val.to_object(py)); + } + Int32(val) => { + result_list.push(val.to_object(py)); + } + Int64(val) => { + result_list.push(val.to_object(py)); + } + UInt8(val) => { + result_list.push(val.to_object(py)); + } + UInt16(val) => { + result_list.push(val.to_object(py)); + } + UInt32(val) => { + result_list.push(val.to_object(py)); + } + UInt64(val) => { + result_list.push(val.to_object(py)); + } + Utf8(val) => { + result_list.push(val.to_object(py)); + } + // TODO: indicate which somehow? + _ => { + return Err(PyNotImplementedError::new_err( + "Scalar type not yet supported", + )); + } + } + } + Expr::BinaryExpr { left, right, op } => { + let left_val = result_list.pop(); + let right_val = result_list.pop(); + match (left_val, right_val) { + (Some(left_val), Some(right_val)) => { + match op { + // pull children off of result_list + Operator::Eq => result_list.push(left_val.call_method1( + py, + "__eq__", + (right_val,), + )?), + Operator::NotEq => result_list.push(left_val.call_method1( + py, + "__ne__", + (right_val,), + )?), + _ => { + return Err(PyNotImplementedError::new_err( + "Operation not yet supported", + )); + } + } + } + (None, None) => { + // Need to process children first + worklist.push(parent); + worklist.push(&**left); + worklist.push(&**right); + } + _ => { + return Err(PyNotImplementedError::new_err( + "Operation not yet supported", + )); + } + } + } + _ => { + return Err(PyNotImplementedError::new_err( + "Expression not yet supported", + )); + } + } + } + + match result_list.len() { + 1 => Ok(result_list.pop().unwrap()), + _ => Err(PyAssertionError::new_err("something went wrong")), + } + }) +} diff --git a/src/lib.rs b/src/lib.rs index d40bae2..f9ba393 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ use pyo3::prelude::*; mod catalog; mod context; mod dataframe; +mod dataset; mod errors; mod expression; mod functions; From b8cb98b9d49c99a46d12e8d779c5dd901f1cbbbf Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 19 Feb 2022 15:23:13 -0800 Subject: [PATCH 2/3] Update datafusion and create some dataset tests --- Cargo.lock | 227 +++++++++++----- Cargo.toml | 7 +- datafusion/tests/test_pyarrow_dataset.py | 77 ++++++ src/dataframe.rs | 29 ++ src/dataset.rs | 320 +++++++++++++---------- src/udaf.rs | 11 - src/udf.rs | 5 +- 7 files changed, 456 insertions(+), 220 deletions(-) create mode 100644 datafusion/tests/test_pyarrow_dataset.py diff --git a/Cargo.lock b/Cargo.lock index ddc46f4..8a9af71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,15 +57,16 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "arrow" -version = "6.5.0" +version = "8.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "216c6846a292bdd93c2b93c1baab58c32ff50e2ab5e8d50db333ab518535dd8b" +checksum = "ce240772a007c63658c1d335bb424fd1019b87895dee899b7bf70e85b2d24e5f" dependencies = [ "bitflags", "chrono", "comfy-table", "csv", "flatbuffers", + "half", "hex", "indexmap", "lazy_static", @@ -111,13 +112,11 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "blake2" 
-version = "0.9.2" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e37d16930f5459780f5621038b6382b9bb37c19016f39fb6b5808d831f174" +checksum = "b94ba84325db59637ffc528bbe8c7f86c02c57cff5c0e2b9b00f9a851f42f309" dependencies = [ - "crypto-mac", - "digest", - "opaque-debug", + "digest 0.10.1", ] [[package]] @@ -131,14 +130,14 @@ dependencies = [ "cc", "cfg-if", "constant_time_eq", - "digest", + "digest 0.9.0", ] [[package]] name = "block-buffer" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +checksum = "03588e54c62ae6d763e2a80090d50353b785795361b4ff5b3bf0a5097fc31c0b" dependencies = [ "generic-array", ] @@ -206,7 +205,6 @@ dependencies = [ "libc", "num-integer", "num-traits", - "time", "winapi", ] @@ -246,13 +244,12 @@ dependencies = [ ] [[package]] -name = "crypto-mac" -version = "0.8.0" +name = "crypto-common" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" +checksum = "683d6b536309245c849479fba3da410962a43ed8e51c26b729208ec0ac2798d0" dependencies = [ "generic-array", - "subtle", ] [[package]] @@ -280,8 +277,7 @@ dependencies = [ [[package]] name = "datafusion" version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e4a8a1f1ee057b2c27a01f050b9dffe56e8d43605d0201234b353a3cc1eb2f" +source = "git+https://github.com/apache/arrow-datafusion.git#15cfcbc28305e82891a5a52d252fb23c72fd8458" dependencies = [ "ahash", "arrow", @@ -290,12 +286,13 @@ dependencies = [ "blake3", "chrono", "futures", - "hashbrown", + "hashbrown 0.12.0", "lazy_static", "log", "md-5", "num_cpus", "ordered-float 2.10.0", + "parking_lot 0.12.0", "parquet", "paste 1.0.6", "pin-project-lite", @@ -305,6 +302,7 @@ dependencies = [ "sha2", "smallvec", "sqlparser", + "tempfile", "tokio", "tokio-stream", "unicode-segmentation", @@ -332,6 +330,27 @@ dependencies = [ "generic-array", ] +[[package]] +name = "digest" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b697d66081d42af4fba142d56918a3cb21dc8eb63372c6b85d14f44fb9c5979b" +dependencies = [ + "block-buffer", + "crypto-common", + "generic-array", + "subtle", +] + +[[package]] +name = "fastrand" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3fcf0cee53519c866c09b5de1f6c56ff9d647101f81c1964fa632e148896cdf" +dependencies = [ + "instant", +] + [[package]] name = "flatbuffers" version = "2.0.0" @@ -476,11 +495,23 @@ dependencies = [ "wasi 0.10.2+wasi-snapshot-preview1", ] +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + [[package]] name = "hashbrown" version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" + +[[package]] +name = "hashbrown" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c21d40587b92fa6a6c6e3c1bdbf87d75511db5672f9c93175574b3a00df1758" dependencies = [ "ahash", ] @@ -516,7 +547,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5" dependencies 
= [ "autocfg", - "hashbrown", + "hashbrown 0.11.2", ] [[package]] @@ -656,9 +687,9 @@ checksum = "1b03d17f364a3a042d5e5d46b053bbbf82c92c9430c592dd4c064dc6ee997125" [[package]] name = "lock_api" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712a4d093c9976e24e7dbca41db895dabcbac38eb5f4045393d17a95bdfb1109" +checksum = "88943dd7ef4a2e5a4bfa2753aaab3013e34ce2533d1996fb18ef591e315e2b3b" dependencies = [ "scopeguard", ] @@ -694,13 +725,11 @@ dependencies = [ [[package]] name = "md-5" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15" +checksum = "e6a38fc55c8bbc10058782919516f88826e70320db6d206aebc49611d24216ae" dependencies = [ - "block-buffer", - "digest", - "opaque-debug", + "digest 0.10.1", ] [[package]] @@ -831,12 +860,6 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da32515d9f6e6e489d7bc9d84c71b060db7247dc035bbe44eac88cf87486d8d5" -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - [[package]] name = "ordered-float" version = "1.1.1" @@ -863,7 +886,17 @@ checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" dependencies = [ "instant", "lock_api", - "parking_lot_core", + "parking_lot_core 0.8.5", +] + +[[package]] +name = "parking_lot" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58" +dependencies = [ + "lock_api", + "parking_lot_core 0.9.0", ] [[package]] @@ -880,11 +913,24 @@ dependencies = [ "winapi", ] +[[package]] +name = "parking_lot_core" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2f4f894f3865f6c0e02810fc597300f34dc2510f66400da262d8ae10e75767d" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + [[package]] name = "parquet" -version = "6.5.0" +version = "8.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "788d9953f4cfbe9db1beff7bebd54299d105e34680d78b82b1ddc85d432cac9d" +checksum = "2d5a6492e0b849fd458bc9364aee4c8a9882b3cc21b2576767162725f69d2ad8" dependencies = [ "arrow", "base64", @@ -903,9 +949,9 @@ dependencies = [ [[package]] name = "parquet-format" -version = "2.6.1" +version = "4.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5bc6b23543b5dedc8f6cce50758a35e5582e148e0cfa26bd0cacd569cda5b71" +checksum = "1f0c06cdcd5460967c485f9c40a821746f5955ad81990533c7fae95dbd9bc0b5" dependencies = [ "thrift", ] @@ -970,14 +1016,14 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.14.5" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35100f9347670a566a67aa623369293703322bb9db77d99d7df7313b575ae0c8" +checksum = "7cf01dbf1c05af0a14c7779ed6f3aa9deac9c3419606ac9de537a2d649005720" dependencies = [ "cfg-if", "indoc", "libc", - "parking_lot", + "parking_lot 0.11.2", "paste 0.1.18", "pyo3-build-config", "pyo3-macros", @@ -986,18 +1032,18 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.14.5" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d12961738cacbd7f91b7c43bc25cfeeaa2698ad07a04b3be0aa88b950865738f" +checksum = "dbf9e4d128bfbddc898ad3409900080d8d5095c379632fbbfbb9c8cfb1fb852b" dependencies = [ "once_cell", ] [[package]] name = "pyo3-macros" -version = "0.14.5" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0bc5215d704824dfddddc03f93cb572e1155c68b6761c37005e1c288808ea8" +checksum = "67701eb32b1f9a9722b4bc54b548ff9d7ebfded011c12daece7b9063be1fd755" dependencies = [ "pyo3-macros-backend", "quote", @@ -1006,9 +1052,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.14.5" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71623fc593224afaab918aa3afcaf86ed2f43d34f6afde7f3922608f253240df" +checksum = "f44f09e825ee49a105f2c7b23ebee50886a9aee0746f4dd5a704138a64b0218a" dependencies = [ "proc-macro2", "pyo3-build-config", @@ -1138,6 +1184,15 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +[[package]] +name = "remove_dir_all" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" +dependencies = [ + "winapi", +] + [[package]] name = "ryu" version = "1.0.9" @@ -1181,15 +1236,13 @@ dependencies = [ [[package]] name = "sha2" -version = "0.9.8" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b69f9a4c9740d74c5baa3fd2e547f9525fa8088a8a958e0ca2409a514e33f5fa" +checksum = "99c3bd8169c58782adad9290a9af5939994036b76187f7b4f0e6de91dbbfc0ec" dependencies = [ - "block-buffer", "cfg-if", "cpufeatures", - "digest", - "opaque-debug", + "digest 0.10.1", ] [[package]] @@ -1212,9 +1265,9 @@ checksum = "45456094d1983e2ee2a18fdfebce3189fa451699d0502cb8e3b49dba5ba41451" [[package]] name = "sqlparser" -version = "0.12.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "760e624412a15d5838ae04fad01037beeff1047781431d74360cddd6b3c1c784" +checksum = "b9907f54bd0f7b6ce72c2be1e570a614819ee08e3deb66d90480df341d8a12a8" dependencies = [ "log", ] @@ -1260,6 +1313,20 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "tempfile" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" +dependencies = [ + "cfg-if", + "fastrand", + "libc", + "redox_syscall", + "remove_dir_all", + "winapi", +] + [[package]] name = "thiserror" version = "1.0.30" @@ -1302,16 +1369,6 @@ dependencies = [ "threadpool", ] -[[package]] -name = "time" -version = "0.1.43" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "tokio" version = "1.15.0" @@ -1319,6 +1376,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbbf1c778ec206785635ce8ad57fe52b3009ae9e0c9f574a728f3049d3e55838" dependencies = [ "num_cpus", + "parking_lot 0.11.2", "pin-project-lite", "tokio-macros", ] @@ -1424,6 +1482,49 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.29.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ceb069ac8b2117d36924190469735767f0990833935ab430155e71a44bafe148" +dependencies = [ + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_msvc" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d027175d00b01e0cbeb97d6ab6ebe03b12330a35786cbaca5252b1c4bf5d9b" + +[[package]] +name = "windows_i686_gnu" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8793f59f7b8e8b01eda1a652b2697d87b93097198ae85f823b969ca5b89bba58" + +[[package]] +name = "windows_i686_msvc" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8602f6c418b67024be2996c512f5f995de3ba417f4c75af68401ab8756796ae4" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3d615f419543e0bd7d2b3323af0d86ff19cbc4f816e6453f36a2c2ce889c354" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d95421d9ed3672c280884da53201a5c46b7b2765ca6faf34b0d71cf34a3561" + [[package]] name = "zstd" version = "0.9.1+zstd.1.5.1" diff --git a/Cargo.toml b/Cargo.toml index e855d41..b80d124 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,12 +35,15 @@ tokio = { version = "1.0", features = [ "sync", ] } rand = "0.7" -pyo3 = { version = "0.14", features = [ +pyo3 = { version = "0.15", features = [ "extension-module", "abi3", "abi3-py36", ] } -datafusion = { version = "6.0.0", features = ["pyarrow"] } +# datafusion = { version = "6.0.0", features = ["pyarrow"] } +datafusion = { git = "https://github.com/apache/arrow-datafusion.git", features = [ + "pyarrow", +] } uuid = { version = "0.8", features = ["v4"] } async-trait = "0.1.41" futures = "0.3" diff --git a/datafusion/tests/test_pyarrow_dataset.py b/datafusion/tests/test_pyarrow_dataset.py new file mode 100644 index 0000000..5aafabc --- /dev/null +++ b/datafusion/tests/test_pyarrow_dataset.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from datetime import date, timedelta +from tempfile import mkdtemp + +import pyarrow as pa +import pyarrow.dataset as ds +import pytest + +from datafusion import ExecutionContext + + +@pytest.fixture +def ctx(): + return ExecutionContext() + + +@pytest.fixture +def table(): + table = pa.table({ + 'z': pa.array([x / 3 for x in range(8)]), + 'x': pa.array(['a'] * 3 + ['b'] * 5), + 'y': pa.array([date(2020, 1, 1) + timedelta(days=x) for x in range(8)]), + }) + return table + + +@pytest.fixture +def dataset(ctx, table): + tmp_dir = mkdtemp() + + part = ds.partitioning( + pa.schema([('x', pa.string()), ('y', pa.date32())]), + flavor="hive", + ) + + ds.write_dataset(table, tmp_dir, partitioning=part, format="parquet") + + dataset = ds.dataset(tmp_dir, partitioning=part) + ctx.register_dataset("ds", dataset) + return dataset + + +def test_catalog(ctx, table, dataset): + catalog_table = ctx.catalog().database().table("ds") + assert catalog_table.kind == "physical" + assert catalog_table.schema == table.schema + + +def test_scan_full(ctx, table, dataset): + result = ctx.sql("SELECT * FROM ds").collect() + assert pa.Table.from_batches(result) == table + + +def test_dataset_filter(ctx: ExecutionContext, table: pa.Table, dataset): + result = ctx.sql("SELECT * FROM ds WHERE y BETWEEN 2020-01-02 AND 2020-01-06 AND x = 'b'").collect() + assert result.record_count() == 3 + + +def test_dataset_project(ctx: ExecutionContext, table: pa.Table, dataset): + result = ctx.sql("SELECT z, y FROM ds").collect() + assert result.col_names() == ['z', 'y'] diff --git a/src/dataframe.rs b/src/dataframe.rs index 9050df9..171c53c 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -19,8 +19,10 @@ use std::sync::Arc; use pyo3::prelude::*; +use datafusion::arrow::array::StringArray; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::PyArrowConvert; +use datafusion::arrow::record_batch::RecordBatch; use datafusion::arrow::util::pretty; use datafusion::dataframe::DataFrame; use datafusion::logical_plan::JoinType; @@ -100,6 +102,33 @@ impl PyDataFrame { Ok(pretty::print_batches(&batches)?) 
} + #[args(verbose = false, analyze = false)] + fn explain(&self, verbose: bool, analyze: bool, py: Python) -> PyResult<()> { + let df = self.df.explain(verbose, analyze)?; + let batches = wait_for_future(py, df.collect())?; + let batch = RecordBatch::concat(&batches[0].schema(), &batches)?; + + let plan_types = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("Plan types is not a String anymore"); + let plans = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("Plan is not a String anymore"); + + for (plan_type, plan) in plan_types.iter().zip(plans.iter()) { + if plan_type.is_some() && plan.is_some() { + println!("{}", plan_type.unwrap()); + println!("{}", plan.unwrap()); + } + } + + Ok(()) + } + fn join( &self, right: PyDataFrame, diff --git a/src/dataset.rs b/src/dataset.rs index b6332ea..54ad8f0 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -24,15 +24,14 @@ use datafusion::datasource::datasource::TableProviderFilterPushDown; use datafusion::datasource::TableProvider; use datafusion::error::DataFusionError; use datafusion::error::Result; -use datafusion::logical_plan::Expr; -use datafusion::logical_plan::Operator; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::logical_plan::{Expr, ExpressionVisitor, Operator, Recursion}; use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use datafusion::physical_plan::stream::RecordBatchReceiverStream; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; -use datafusion::scalar::ScalarValue::*; -use pyo3::conversion::ToPyObject; +use pyo3::conversion::PyArrowConvert; use pyo3::exceptions::{PyAssertionError, PyNotImplementedError, PyStopIteration}; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -79,45 +78,30 @@ impl TableProvider for PyArrowDatasetTable { async fn scan( &self, projection: &Option>, - batch_size: usize, filters: &[Expr], limit: Option, ) -> Result> { - let scanner = Python::with_gil(|py| -> PyResult> { - let scanner_kwargs = PyDict::new(py); - scanner_kwargs.set_item("batch_size", batch_size)?; - - let combined_filter = filters - .iter() - .map(|f| f.clone()) - .reduce(|acc, item| acc.and(item)); - if let Some(expr) = combined_filter { - scanner_kwargs.set_item("filter", expr_to_pyarrow(&expr)?)?; - }; - - if let Some(indices) = projection { - let column_names: Vec = self - .schema - .project(indices)? - .fields() - .iter() - .map(|field| field.name().clone()) - .collect(); - scanner_kwargs.set_item("columns", column_names)?; - } + let combined_filter = filters + .iter() + .map(|f| f.clone()) + .reduce(|acc, item| acc.and(item)); + let scanner = PyArrowDatasetScanner::make( + self.dataset.clone(), + self.schema.clone(), + projection, + combined_filter.clone(), + limit, + 10, // Dummy value; scanner recreated later with runtime batch_size. + ); - Ok(self - .dataset - .call_method(py, "scanner", (), Some(scanner_kwargs))? - .extract(py)?) 
- }); match scanner { Ok(scanner) => Ok(Arc::new(PyArrowDatasetExec { - scanner: PyArrowDatasetScanner { - scanner: Arc::new(scanner), - limit, - schema: self.schema.clone(), - }, + dataset: self.dataset.clone(), + scanner, + projection: projection.clone(), + filter: combined_filter, + limit, + schema: self.schema.clone(), projected_statistics: Statistics::default(), metrics: ExecutionPlanMetricsSet::new(), })), @@ -158,19 +142,63 @@ impl Clone for PyArrowDatasetScanner { } impl PyArrowDatasetScanner { - fn projected_schema(self) -> SchemaRef { - self.schema + fn make( + dataset: Py, + schema: SchemaRef, + projection: &Option>, + filter: Option, + limit: Option, + batch_size: usize, + ) -> Result { + let scanner = Python::with_gil(|py| -> PyResult> { + let scanner_kwargs = PyDict::new(py); + scanner_kwargs.set_item("batch_size", batch_size)?; + if let Some(expr) = filter { + scanner_kwargs.set_item("filter", expr_to_pyarrow(&expr)?)?; + }; + + if let Some(indices) = projection { + let column_names: Vec = schema + .project(indices)? + .fields() + .iter() + .map(|field| field.name().clone()) + .collect(); + scanner_kwargs.set_item("columns", column_names)?; + } + + Ok(dataset + .call_method(py, "scanner", (), Some(scanner_kwargs))? + .extract(py)?) + }); + match scanner { + Ok(scanner) => Ok(Self { + scanner: Arc::new(scanner), + limit, + schema, + }), + Err(err) => Err(DataFusionError::Execution(err.to_string())), + } + } + + fn projected_schema(&self) -> SchemaRef { + self.schema.clone() } fn get_batches(&self, response_tx: Sender>) -> Result<()> { let mut count = 0; + // TODO: Avoid Python GIL with Arrow C Stream interface? + // https://arrow.apache.org/docs/dev/format/CStreamInterface.html + // https://github.com/apache/arrow/blob/cc4e2a54309813e6bbbb36ba50bcd22a7b71d3d9/python/pyarrow/ipc.pxi#L620 + let batch_iter = Python::with_gil(|py| self.scanner.call_method0(py, "to_batches")) + .map_err(|err| DataFusionError::Execution(err.to_string()))?; + loop { // TODO: Avoid Python GIL with Arrow C Stream interface? 
// https://arrow.apache.org/docs/dev/format/CStreamInterface.html // https://github.com/apache/arrow/blob/cc4e2a54309813e6bbbb36ba50bcd22a7b71d3d9/python/pyarrow/ipc.pxi#L620 let res = Python::with_gil(|py| -> PyResult> { - let batch_iter = self.scanner.call_method0(py, "to_batches")?; let py_batch_res = batch_iter.call_method0(py, "__next__"); match py_batch_res { Ok(py_batch) => Ok(Some(RecordBatch::from_pyarrow(py_batch.extract(py)?)?)), @@ -226,7 +254,12 @@ fn send_result( /// Execution plan for scanning a PyArrow dataset #[derive(Debug, Clone)] pub struct PyArrowDatasetExec { + dataset: Py, scanner: PyArrowDatasetScanner, + projection: Option>, + filter: Option, + limit: Option, + schema: SchemaRef, projected_statistics: Statistics, /// Execution metrics metrics: ExecutionPlanMetricsSet, @@ -240,7 +273,7 @@ impl ExecutionPlan for PyArrowDatasetExec { } fn schema(&self) -> SchemaRef { - self.scanner.clone().projected_schema().clone() + self.scanner.projected_schema() } fn children(&self) -> Vec> { @@ -267,22 +300,35 @@ impl ExecutionPlan for PyArrowDatasetExec { } } - async fn execute(&self, _partition_index: usize) -> Result { + async fn execute( + &self, + _partition_index: usize, + runtime: Arc, + ) -> Result { + // need to use runtime.batch_size let (response_tx, response_rx): ( Sender>, Receiver>, ) = channel(2); - let cloned = self.scanner.clone(); + // Have to recreate with correct batch size + let scanner = PyArrowDatasetScanner::make( + self.dataset.clone(), + self.schema.clone(), + &self.projection, + self.filter.clone(), + self.limit, + runtime.batch_size, + )?; let join_handle = task::spawn_blocking(move || { - if let Err(e) = cloned.get_batches(response_tx) { + if let Err(e) = scanner.get_batches(response_tx) { println!("Dataset scanner thread terminated due to error: {:?}", e); } }); Ok(RecordBatchReceiverStream::create( - &self.scanner.clone().projected_schema(), + &self.scanner.projected_schema(), response_rx, join_handle, )) @@ -310,106 +356,82 @@ impl ExecutionPlan for PyArrowDatasetExec { } } -// TODO: replace with impl PyArrowConvert for Expr -// https://github.com/apache/arrow-rs/blob/master/arrow/src/pyarrow.rs -fn expr_to_pyarrow(expr: &Expr) -> PyResult { - Python::with_gil(|py| -> PyResult { - let ds = PyModule::import(py, "pyarrow.dataset")?; - let field = ds.getattr("field")?; +struct PyArrowExprVisitor { + result_stack: Vec, +} - let mut worklist: Vec<&Expr> = Vec::new(); // Expressions to parse - let mut result_list: Vec = Vec::new(); // Expressions that have been parsed - worklist.push(expr); +impl ExpressionVisitor for PyArrowExprVisitor { + fn pre_visit(mut self, _expr: &Expr) -> Result> { + Ok(Recursion::Continue(self)) + } - while let Some(parent) = worklist.pop() { - match parent { + fn post_visit(mut self, expr: &Expr) -> Result { + let res = Python::with_gil(|py| -> PyResult<()> { + let ds = PyModule::import(py, "pyarrow.dataset")?; + let field = ds.getattr("field")?; + + match expr { Expr::Column(col) => { - result_list.push(field.call1((col.name.clone(),))?.into()); + self.result_stack + .push(field.call1((col.name.clone(),))?.into()); } - // TODO: finish implementing PyArrowConvert for ScalarValue? 
- // https://github.com/apache/arrow-datafusion/blob/master/datafusion/src/pyarrow.rs Expr::Literal(scalar) => { - match scalar { - Boolean(val) => { - result_list.push(val.to_object(py)); - } - Float32(val) => { - result_list.push(val.to_object(py)); - } - Float64(val) => { - result_list.push(val.to_object(py)); - } - Int8(val) => { - result_list.push(val.to_object(py)); - } - Int16(val) => { - result_list.push(val.to_object(py)); - } - Int32(val) => { - result_list.push(val.to_object(py)); - } - Int64(val) => { - result_list.push(val.to_object(py)); - } - UInt8(val) => { - result_list.push(val.to_object(py)); - } - UInt16(val) => { - result_list.push(val.to_object(py)); - } - UInt32(val) => { - result_list.push(val.to_object(py)); - } - UInt64(val) => { - result_list.push(val.to_object(py)); - } - Utf8(val) => { - result_list.push(val.to_object(py)); - } - // TODO: indicate which somehow? - _ => { - return Err(PyNotImplementedError::new_err( - "Scalar type not yet supported", - )); - } - } + self.result_stack.push(scalar.to_pyarrow(py)?); } - Expr::BinaryExpr { left, right, op } => { - let left_val = result_list.pop(); - let right_val = result_list.pop(); - match (left_val, right_val) { - (Some(left_val), Some(right_val)) => { - match op { - // pull children off of result_list - Operator::Eq => result_list.push(left_val.call_method1( - py, - "__eq__", - (right_val,), - )?), - Operator::NotEq => result_list.push(left_val.call_method1( - py, - "__ne__", - (right_val,), - )?), - _ => { - return Err(PyNotImplementedError::new_err( - "Operation not yet supported", - )); - } - } - } - (None, None) => { - // Need to process children first - worklist.push(parent); - worklist.push(&**left); - worklist.push(&**right); - } - _ => { - return Err(PyNotImplementedError::new_err( - "Operation not yet supported", - )); - } + Expr::BinaryExpr { + left: _, + right: _, + op, + } => { + // Must be pop'd in reverse order of visitation + let right_val = self.result_stack.pop().unwrap(); + let left_val = self.result_stack.pop().unwrap(); + + let method = match op { + Operator::Eq => Ok("__eq__"), + Operator::NotEq => Ok("__ne__"), + Operator::Lt => Ok("__lt__"), + Operator::LtEq => Ok("__le__"), + Operator::Gt => Ok("__gt__"), + Operator::GtEq => Ok("__gt__"), + Operator::Plus => Ok("__add__"), + Operator::Minus => Ok("__sub__"), + Operator::Multiply => Ok("__mul__"), + Operator::Divide => Ok("__div__"), + Operator::Modulo => Ok("__mod__"), + Operator::Or => Ok("__or__"), + Operator::And => Ok("__and__"), + _ => Err(PyNotImplementedError::new_err( + "Operation not yet supported", + )), + }; + + self.result_stack + .push(left_val.call_method1(py, method?, (right_val,))?); + } + Expr::Not(expr) => { + let val = self.result_stack.pop().unwrap(); + + self.result_stack.push(val.call_method0(py, "__not__")?); + } + Expr::Between { + expr: _, + negated, + low: _, + high: _, + } => { + // Must be pop'd in reverse order of visitation + let high_val = self.result_stack.pop().unwrap(); + let low_val = self.result_stack.pop().unwrap(); + let expr_val = self.result_stack.pop().unwrap(); + + let gte_val = expr_val.call_method1(py, "__ge__", (low_val,))?; + let lte_val = expr_val.call_method1(py, "__le__", (high_val,))?; + let mut val = gte_val.call_method1(py, "__and__", (lte_val,))?; + if *negated { + val = val.call_method0(py, "__not__")?; } + self.result_stack.push(val); } _ => { return Err(PyNotImplementedError::new_err( @@ -417,10 +439,28 @@ fn expr_to_pyarrow(expr: &Expr) -> PyResult { )); } } + Ok(()) + }); + + match 
res { + Ok(_) => Ok(self), + Err(err) => Err(DataFusionError::External(Box::new(err))), } + } +} + +// TODO: replace with some Substrait conversion? +// https://github.com/apache/arrow-rs/blob/master/arrow/src/pyarrow.rs +fn expr_to_pyarrow(expr: &Expr) -> PyResult { + Python::with_gil(|py| -> PyResult { + let visitor = PyArrowExprVisitor { + result_stack: Vec::new(), + }; + + let mut final_visitor = expr.accept(visitor)?; - match result_list.len() { - 1 => Ok(result_list.pop().unwrap()), + match final_visitor.result_stack.len() { + 1 => Ok(final_visitor.result_stack.pop().unwrap()), _ => Err(PyAssertionError::new_err("something went wrong")), } }) diff --git a/src/udaf.rs b/src/udaf.rs index 1de6e63..c25fd1f 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -49,16 +49,6 @@ impl Accumulator for RustAccumulator { .map_err(|e| DataFusionError::Execution(format!("{}", e))) } - fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { - // no need to implement as datafusion does not use it - todo!() - } - - fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> { - // no need to implement as datafusion does not use it - todo!() - } - fn evaluate(&self) -> Result { Python::with_gil(|py| self.accum.as_ref(py).call_method0("evaluate")?.extract()) .map_err(|e| DataFusionError::Execution(format!("{}", e))) @@ -144,7 +134,6 @@ impl PyAggregateUDF { } /// creates a new PyExpr with the call of the udf - #[call] #[args(args = "*")] fn __call__(&self, args: Vec) -> PyResult { let args = args.iter().map(|e| e.expr.clone()).collect(); diff --git a/src/udf.rs b/src/udf.rs index 379c449..8251739 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -24,9 +24,7 @@ use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::PyArrowConvert; use datafusion::error::DataFusionError; use datafusion::logical_plan; -use datafusion::physical_plan::functions::{ - make_scalar_function, ScalarFunctionImplementation, -}; +use datafusion::physical_plan::functions::{make_scalar_function, ScalarFunctionImplementation}; use datafusion::physical_plan::udf::ScalarUDF; use crate::expression::PyExpr; @@ -89,7 +87,6 @@ impl PyScalarUDF { } /// creates a new PyExpr with the call of the udf - #[call] #[args(args = "*")] fn __call__(&self, args: Vec) -> PyResult { let args = args.iter().map(|e| e.expr.clone()).collect(); From 48cda7e1c02a32e130047e942cf39862346496b3 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 20 Feb 2022 12:52:03 -0800 Subject: [PATCH 3/3] Make filter pushdown partial --- src/dataset.rs | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/dataset.rs b/src/dataset.rs index 54ad8f0..b7935ce 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -31,7 +31,6 @@ use datafusion::physical_plan::stream::RecordBatchReceiverStream; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; -use pyo3::conversion::PyArrowConvert; use pyo3::exceptions::{PyAssertionError, PyNotImplementedError, PyStopIteration}; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -48,6 +47,24 @@ pub struct PyArrowDatasetTable { schema: SchemaRef, } +impl PyArrowDatasetTable { + /// Returns true if expression can by evaluated by pyarrow against this dataset + fn expression_valid(&self, expr: &Expr) -> bool { + if let Ok(pyarrow_expr) = expr_to_pyarrow(expr) { + let res = Python::with_gil(|py| -> PyResult<()> { + let scanner_kwargs = PyDict::new(py); + scanner_kwargs.set_item("filter", pyarrow_expr)?; + self.dataset + 
                .call_method(py, "scanner", (), Some(scanner_kwargs))?;
+                Ok(())
+            });
+            res.is_ok()
+        } else {
+            false
+        }
+    }
+}
+
 impl<'py> FromPyObject<'py> for PyArrowDatasetTable {
     fn extract(ob: &'py PyAny) -> PyResult<Self> {
         // Check it's a PyArrow dataset
@@ -81,8 +98,11 @@ impl TableProvider for PyArrowDatasetTable {
         filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        // Filtering is only inexact because of expression conversion, but the
+        // PyArrow scanner does apply all filters given to it.
         let combined_filter = filters
             .iter()
+            .filter(|expr| self.expression_valid(expr))
             .map(|f| f.clone())
             .reduce(|acc, item| acc.and(item));
         let scanner = PyArrowDatasetScanner::make(
@@ -110,7 +130,7 @@ impl TableProvider for PyArrowDatasetTable {
     }
 
     fn supports_filter_pushdown(&self, _: &Expr) -> Result<TableProviderFilterPushDown> {
-        Ok(TableProviderFilterPushDown::Exact)
+        Ok(TableProviderFilterPushDown::Inexact)
     }
 }
 
@@ -339,9 +359,8 @@ impl ExecutionPlan for PyArrowDatasetExec {
             DisplayFormatType::Default => {
                 write!(
                     f,
-                    // TODO: better fmt
-                    "PyArrowDatasetExec: limit={:?}, partitions=...",
-                    self.scanner.limit
+                    "PyArrowDatasetExec: filter={:?}, limit={:?}, projection={:?} partitions=...",
+                    self.filter, self.scanner.limit, self.projection
                 )
            }
        }
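
For reviewers, a minimal end-to-end usage sketch of the API this series adds, adapted from the fixture and queries in datafusion/tests/test_pyarrow_dataset.py (patch 2). The temporary-directory handling and the final print are illustrative assumptions rather than part of the patches.

# Register a pyarrow.dataset.Dataset with DataFusion and query it with SQL.
# Assumes the register_dataset() API introduced in this patch series.
from tempfile import mkdtemp

import pyarrow as pa
import pyarrow.dataset as ds

from datafusion import ExecutionContext

# Write a small Parquet dataset, mirroring the test fixture.
table = pa.table({
    "x": pa.array(["a"] * 3 + ["b"] * 5),
    "z": pa.array([x / 3 for x in range(8)]),
})
tmp_dir = mkdtemp()
ds.write_dataset(table, tmp_dir, format="parquet")

# Register the dataset as a table; scans go through PyArrowDatasetExec.
ctx = ExecutionContext()
ctx.register_dataset("ds", ds.dataset(tmp_dir))

# Projections and convertible filters are handed to the pyarrow scanner;
# after patch 3 the pushdown is reported to DataFusion as Inexact.
batches = ctx.sql("SELECT z FROM ds WHERE x = 'b'").collect()
print(pa.Table.from_batches(batches))

Whether a given predicate was pushed down can be inspected with the explain() method added to the DataFrame in patch 2, e.g. ctx.sql("SELECT z FROM ds WHERE x = 'b'").explain().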