From 14bf39d4dbd28368b267aec728fba345764f759c Mon Sep 17 00:00:00 2001
From: Guillaume Balaine
Date: Tue, 15 Feb 2022 20:00:04 +0100
Subject: [PATCH] Arrow2 02092022 (#1795)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat: add join type for logical plan display (#1674)
* (minor) Reduce memory manager and disk manager logs from `info!` to `debug!` (#1689)
* Move `information_schema` tests out of execution/context.rs to `sql_integration` tests (#1684)
* Move tests from context.rs to information_schema.rs
* Fix up tests to compile
* Move timestamp related tests out of context.rs and into sql integration test (#1696)
* Move some tests out of context.rs and into sql
* Move support test out of context.rs and into sql tests
* Fixup tests and make them compile
* Add `MemTrackingMetrics` to ease memory tracking for non-limited memory consumers (#1691)
* Memory manager no longer tracks consumers, update aggregatedMetricsSet
* Easy memory tracking with metrics
* use tracking metrics in SPMS
* tests
* fix
* doc
* Update datafusion/src/physical_plan/sorts/sort.rs
Co-authored-by: Andrew Lamb
* make tracker AtomicUsize
Co-authored-by: Andrew Lamb
* Implement TableProvider for DataFrameImpl (#1699)
* Add TableProvider impl for DataFrameImpl
* Add physical plan in
* Clean up plan construction and names construction
* Remove duplicate comments
* Remove unused parameter
* Add test
* Remove duplicate limit comment
* Use cloned instead of individual clone
* Reduce the amount of code to get a schema
Co-authored-by: Andrew Lamb
* Add comments to test
* Fix plan comparison
* Compare only the results of execution
* Remove println
* Refer to df_impl instead of table in test
Co-authored-by: Andrew Lamb
* Fix the register_table test to use the correct result set for comparison
* Consolidate group/agg exprs
* Format
* Remove outdated comment
Co-authored-by: Andrew Lamb
* refine test in repartition.rs & coalesce_batches.rs (#1707)
* Fuzz test for spillable sort (#1706)
* Lazy TempDir creation in DiskManager (#1695)
* Incorporate dyn scalar kernels (#1685)
* Rebase
* impl ToNumeric for ScalarValue
* Update macro to be based on
* Add floats
* Cleanup
* Newline
* add annotation for select_to_plan (#1714)
* Support `create_physical_expr` and `ExecutionContextState` or `DefaultPhysicalPlanner` for faster speed (#1700)
* Change physical_expr creation API
* Refactor API usage to avoid creating ExecutionContextState
* Fixup ballista
* clippy!
* Fix cannot load parquet table from Spark in datafusion-cli. (#1665)
* fix cannot load parquet table from Spark
* add Invalid file in log.
* fix fmt
* add upper bound for pub fn (#1713)
Signed-off-by: remzi <13716567376yh@gmail.com>
* Create SchemaAdapter trait to map table schema to file schemas (#1709)
* Create SchemaAdapter trait to map table schema to file schemas
* Linting fix
* Remove commented code
* approx_quantile() aggregation function (#1539)
* feat: implement TDigest for approx quantile

Adds a [TDigest] implementation providing approximate quantile estimations of large inputs using a small amount of (bounded) memory. A TDigest is most accurate near either "end" of the quantile range (that is, 0.1, 0.9, 0.95, etc.) due to the use of a scaling function that increases resolution at the tails. The paper claims single-digit part-per-million errors for q ≤ 0.001 or q ≥ 0.999 using 100 centroids, and in practice I have found accuracy to be more than acceptable for an approximate function across the entire quantile range.
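For intuition, here is a minimal Rust sketch of that scaling-function idea (the `k_1` function from the cited TDigest paper); the names `k1` and `delta` are illustrative assumptions, not the API added by this patch:

```rust
/// Illustrative sketch only: the `k_1` scaling function from the cited
/// TDigest paper (https://arxiv.org/abs/1902.04023). A digest built with
/// compression parameter `delta` keeps each centroid small enough that
/// k(q_right) - k(q_left) <= 1. Because `asin(2q - 1)` is steep near
/// q = 0 and q = 1, centroids there must stay tiny, which is what gives
/// the accuracy at the tails described above.
fn k1(q: f64, delta: f64) -> f64 {
    delta / (2.0 * std::f64::consts::PI) * (2.0 * q - 1.0).asin()
}

fn main() {
    // The same 0.01-wide step in q "costs" far more k-distance at the
    // tail than at the median, so far fewer points may share a centroid there.
    let delta = 100.0;
    println!("k(0.00..0.01) = {:.3}", k1(0.01, delta) - k1(0.0, delta));
    println!("k(0.50..0.51) = {:.3}", k1(0.51, delta) - k1(0.50, delta));
}
```

Running this shows the tail step costing roughly ten times more k-distance than the step at the median, which is why centroids (and therefore error) stay small near q = 0 and q = 1.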
The implementation is a modified copy of https://github.com/MnO2/t-digest, itself a Rust port of [Facebook's C++ implementation]. Both Facebook's implementation and MnO2's Rust port are Apache 2.0 licensed.

[TDigest]: https://arxiv.org/abs/1902.04023
[Facebook's C++ implementation]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h

* feat: approx_quantile aggregation

Adds the ApproxQuantile physical expression, plumbing & test cases. The function signature is:

approx_quantile(column, quantile)

Where column can be any numeric type (that can be cast to a float64) and quantile is a float64 literal between 0 and 1.

* feat: approx_quantile dataframe function

Adds the approx_quantile() dataframe function, and exports it in the prelude.

* refactor: ballista approx_quantile support

Adds ballista wire encoding for approx_quantile. Adding support for this required modifying the AggregateExprNode proto message to support propagating multiple LogicalExprNode aggregate arguments - all the existing aggregations take a single argument, so this wasn't needed before.

This commit adds "repeated" to the expr field, which I believe is backwards compatible as described here: https://developers.google.com/protocol-buffers/docs/proto3#updating

Specifically, adding "repeated" to an existing message field: "For ... message fields, optional is compatible with repeated"

No existing tests needed fixing, and a new roundtrip test is included that covers the change to allow multiple expr.

* refactor: use input type as return type

Casts the calculated quantile value to the same type as the input data.

* fixup! refactor: ballista approx_quantile support
* refactor: rebase onto main
* refactor: validate quantile value

Ensures the quantile value is between 0 and 1, emitting a plan error if not.

* refactor: rename to approx_percentile_cont
* refactor: clippy lints
* support bitwise AND as an example (#1653)
* support bitwise AND as an example
* Use $OP in macro rather than `&`
* fix: change signature to &dyn Array
* fmt
Co-authored-by: Andrew Lamb
* fix: substr - correct behaviour with negative start pos (#1660)
* minor: fix cargo run --release error (#1723)
* Convert boolean case expressions to boolean logic (#1719)
* Convert boolean case expressions to boolean logic
* Review feedback
* substitute `parking_lot::Mutex` for `std::sync::Mutex` (#1720)
* Substitute parking_lot::Mutex for std::sync::Mutex
* enable parking_lot feature in tokio
* Add Expression Simplification API (#1717)
* Add Expression Simplification API
* fmt
* Add tests and CI for optional pyarrow module (#1711)
* Implement other side of conversion
* Add test workflow
* Add (failing) tests
* Get unit tests passing
* Use python -m pip
* Debug LD_LIBRARY_PATH
* Set LIBRARY_PATH
* Update help with better info
* Update parking_lot requirement from 0.11 to 0.12 (#1735)

Updates the requirements on [parking_lot](https://github.com/Amanieu/parking_lot) to permit the latest version.
- [Release notes](https://github.com/Amanieu/parking_lot/releases)
- [Changelog](https://github.com/Amanieu/parking_lot/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Amanieu/parking_lot/compare/0.11.0...0.12.0)

---
updated-dependencies:
- dependency-name: parking_lot
  dependency-type: direct:production
...
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Prevent repartitioning of certain operator's direct children (#1731) (#1732)
* Prevent repartitioning of certain operator's direct children (#1731)
* Update ballista tests
* Don't repartition children of RepartitionExec
* Revert partition restriction on Repartition and Projection
* Review feedback
* Lint
* API to get Expr's type and nullability without a `DFSchema` (#1726)
* API to get Expr type and nullability without a `DFSchema`
* Add test
* publicly export
* Improve docs
* Fix typos in crate documentation (#1739)
* add `cargo check --release` to ci (#1737)
* remote test
* Update .github/workflows/rust.yml
Co-authored-by: Andrew Lamb
Co-authored-by: Andrew Lamb
* Move optimize test out of context.rs (#1742)
* Move optimize test out of context.rs
* Update
* use clap 3 style args parsing for datafusion cli (#1749)
* use clap 3 style args parsing for datafusion cli
* upgrade cli version
* Add partitioned_csv setup code to sql_integration test (#1743)
* use ordered-float 2.10 (#1756)
Signed-off-by: Andy Grove
* #1768 Support TimeUnit::Second in hasher (#1769)
* Support TimeUnit::Second in hasher
* fix linter
* format (#1745)
* Create built-in scalar functions programmatically (#1734)
* create built-in scalar functions programmatically
Signed-off-by: remzi <13716567376yh@gmail.com>
* solve conflict
Signed-off-by: remzi <13716567376yh@gmail.com>
* fix spelling mistake
Signed-off-by: remzi <13716567376yh@gmail.com>
* rename to call_fn
Signed-off-by: remzi <13716567376yh@gmail.com>
* [split/1] split datafusion-common module (#1751)
* split datafusion-common module
* pyarrow
* Update datafusion-common/README.md
Co-authored-by: Andy Grove
* Update datafusion/Cargo.toml
* include publishing
Co-authored-by: Andy Grove
* fix: Case insensitive unquoted identifiers (#1747)
* move dfschema and column (#1758)
* add datafusion-expr module (#1759)
* move column, dfschema, etc. to common module (#1760)
* include window frames and operator into datafusion-expr (#1761)
* move signature, type signature, and volatility to split module (#1763)
* [split/10] split up expr for rewriting, visiting, and simplification traits (#1774)
* split up expr for rewriting, visiting, and simplification
* add docs
* move built-in scalar functions (#1764)
* split expr type and null info to be expr-schemable (#1784)
* rewrite predicates before pushing to union inputs (#1781)
* move accumulator and columnar value (#1765)
* move accumulator and columnar value (#1762)
* fix bad data type in test_try_cast_decimal_to_decimal
* added projections for avro columns

Co-authored-by: xudong.w
Co-authored-by: Andrew Lamb
Co-authored-by: Yijie Shen
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Co-authored-by: Matthew Turner
Co-authored-by: Yang <37145547+Ted-Jiang@users.noreply.github.com>
Co-authored-by: Remzi Yang <59198230+HaoYang670@users.noreply.github.com>
Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com>
Co-authored-by: Dom
Co-authored-by: Kun Liu
Co-authored-by: Dmitry Patsura
Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>
Co-authored-by: Will Jones
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: r.4ntix
Co-authored-by: Jiayu Liu
Co-authored-by: Andy Grove
Co-authored-by: Rich
Co-authored-by: Marko Mikulicic
Co-authored-by: Eduard Karacharov <13005055+korowa@users.noreply.github.com>
---
.github/workflows/rust.yml | 59 +- Cargo.toml | 6 +- README.md | 354 +- ballista/rust/client/Cargo.toml | 2 +- ballista/rust/core/Cargo.toml | 2 +- ballista/rust/executor/Cargo.toml | 2 +- ballista/rust/scheduler/Cargo.toml | 2 +- datafusion-cli/Cargo.toml | 3 +- datafusion-cli/src/command.rs | 11 +- datafusion-cli/src/exec.rs | 9 +- datafusion-cli/src/functions.rs | 4 +- datafusion-cli/src/lib.rs | 1 - datafusion-cli/src/main.rs | 162 +- datafusion-cli/src/print_format.rs | 72 +- datafusion-common/Cargo.toml | 44 + datafusion-common/README.md | 24 + datafusion-common/src/column.rs | 150 + datafusion-common/src/dfschema.rs | 720 ++++ datafusion-common/src/error.rs | 184 + datafusion-common/src/field_util.rs | 490 +++ datafusion-common/src/lib.rs | 30 + datafusion-common/src/pyarrow.rs | 247 ++ datafusion-common/src/record_batch.rs | 449 +++ datafusion-common/src/scalar.rs | 2992 +++++++++++++++++ datafusion-common/src/scalar_tmp.rs | 2992 +++++++++++++++++ datafusion-expr/Cargo.toml | 41 + datafusion-expr/README.md | 24 + datafusion-expr/src/accumulator.rs | 44 + datafusion-expr/src/aggregate_function.rs | 93 + datafusion-expr/src/built_in_function.rs | 330 ++ datafusion-expr/src/columnar_value.rs | 63 + datafusion-expr/src/expr.rs | 698 ++++ datafusion-expr/src/expr_fn.rs | 32 + datafusion-expr/src/function.rs | 46 + datafusion-expr/src/lib.rs | 49 + datafusion-expr/src/literal.rs | 138 + datafusion-expr/src/operator.rs | 140 + datafusion-expr/src/signature.rs | 116 + datafusion-expr/src/udaf.rs | 92 + datafusion-expr/src/udf.rs | 93 + datafusion-expr/src/window_frame.rs | 381 +++ datafusion-expr/src/window_function.rs | 204 ++ datafusion/Cargo.toml | 11 +- datafusion/benches/sort_limit_query_sql.rs | 3 + datafusion/fuzz-utils/Cargo.toml | 2 +- datafusion/fuzz-utils/src/lib.rs | 11 +- .../src/avro_to_arrow/arrow_array_reader.rs | 2 + datafusion/src/avro_to_arrow/reader.rs | 23 +- datafusion/src/avro_to_arrow/schema.rs | 1 - datafusion/src/dataframe.rs | 3 +-
.../src/datasource/file_format/parquet.rs | 6 +- datafusion/src/datasource/listing/helpers.rs | 2 +- datafusion/src/datasource/memory.rs | 1 - datafusion/src/error.rs | 171 +- datafusion/src/execution/context.rs | 561 ++-- datafusion/src/execution/dataframe_impl.rs | 4 +- datafusion/src/field_util.rs | 474 +-- datafusion/src/lib.rs | 4 +- datafusion/src/logical_plan/builder.rs | 49 +- datafusion/src/logical_plan/dfschema.rs | 667 +--- datafusion/src/logical_plan/expr.rs | 1966 +---------- datafusion/src/logical_plan/expr_rewriter.rs | 592 ++++ datafusion/src/logical_plan/expr_schema.rs | 232 ++ datafusion/src/logical_plan/expr_simplier.rs | 97 + datafusion/src/logical_plan/expr_visitor.rs | 176 + datafusion/src/logical_plan/mod.rs | 26 +- datafusion/src/logical_plan/operators.rs | 123 +- datafusion/src/logical_plan/window_frames.rs | 363 +- .../src/optimizer/common_subexpr_eliminate.rs | 4 +- datafusion/src/optimizer/filter_push_down.rs | 54 +- .../src/optimizer/simplify_expressions.rs | 62 +- .../optimizer/single_distinct_to_groupby.rs | 1 + datafusion/src/optimizer/utils.rs | 6 +- .../src/physical_optimizer/repartition.rs | 212 +- datafusion/src/physical_plan/aggregates.rs | 87 +- .../src/physical_plan/expressions/try_cast.rs | 2 +- .../src/physical_plan/file_format/parquet.rs | 52 +- datafusion/src/physical_plan/functions.rs | 434 +-- datafusion/src/physical_plan/hash_utils.rs | 10 + datafusion/src/physical_plan/limit.rs | 5 + datafusion/src/physical_plan/mod.rs | 70 +- datafusion/src/physical_plan/udaf.rs | 83 +- datafusion/src/physical_plan/udf.rs | 85 +- datafusion/src/physical_plan/union.rs | 4 + .../src/physical_plan/window_functions.rs | 186 +- datafusion/src/pyarrow.rs | 96 - datafusion/src/record_batch.rs | 452 +-- datafusion/src/scalar.rs | 1904 +---------- datafusion/src/sql/planner.rs | 21 +- datafusion/src/sql/utils.rs | 10 + datafusion/tests/order_spill_fuzz.rs | 6 +- datafusion/tests/parquet_pruning.rs | 75 +- datafusion/tests/simplification.rs | 2 + datafusion/tests/sql/explain.rs | 60 + datafusion/tests/sql/mod.rs | 17 + datafusion/tests/sql/partitioned_csv.rs | 95 + datafusion/tests/sql/projection.rs | 192 ++ datafusion/tests/sql/select.rs | 58 + docs/source/index.rst | 1 + .../source/specification/quarterly_roadmap.md | 72 + docs/source/user-guide/sql/index.rst | 1 + docs/source/user-guide/sql/sql_status.md | 241 ++ 102 files changed, 13702 insertions(+), 8123 deletions(-) create mode 100644 datafusion-common/Cargo.toml create mode 100644 datafusion-common/README.md create mode 100644 datafusion-common/src/column.rs create mode 100644 datafusion-common/src/dfschema.rs create mode 100644 datafusion-common/src/error.rs create mode 100644 datafusion-common/src/field_util.rs create mode 100644 datafusion-common/src/lib.rs create mode 100644 datafusion-common/src/pyarrow.rs create mode 100644 datafusion-common/src/record_batch.rs create mode 100644 datafusion-common/src/scalar.rs create mode 100644 datafusion-common/src/scalar_tmp.rs create mode 100644 datafusion-expr/Cargo.toml create mode 100644 datafusion-expr/README.md create mode 100644 datafusion-expr/src/accumulator.rs create mode 100644 datafusion-expr/src/aggregate_function.rs create mode 100644 datafusion-expr/src/built_in_function.rs create mode 100644 datafusion-expr/src/columnar_value.rs create mode 100644 datafusion-expr/src/expr.rs create mode 100644 datafusion-expr/src/expr_fn.rs create mode 100644 datafusion-expr/src/function.rs create mode 100644 datafusion-expr/src/lib.rs create mode 100644 
datafusion-expr/src/literal.rs create mode 100644 datafusion-expr/src/operator.rs create mode 100644 datafusion-expr/src/signature.rs create mode 100644 datafusion-expr/src/udaf.rs create mode 100644 datafusion-expr/src/udf.rs create mode 100644 datafusion-expr/src/window_frame.rs create mode 100644 datafusion-expr/src/window_function.rs delete mode 100644 datafusion/src/avro_to_arrow/schema.rs create mode 100644 datafusion/src/logical_plan/expr_rewriter.rs create mode 100644 datafusion/src/logical_plan/expr_schema.rs create mode 100644 datafusion/src/logical_plan/expr_simplier.rs create mode 100644 datafusion/src/logical_plan/expr_visitor.rs delete mode 100644 datafusion/src/pyarrow.rs create mode 100644 datafusion/tests/sql/explain.rs create mode 100644 datafusion/tests/sql/partitioned_csv.rs create mode 100644 docs/source/specification/quarterly_roadmap.md create mode 100644 docs/source/user-guide/sql/sql_status.md diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 8a7f6737ded5..a9ff52e65760 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -58,12 +58,18 @@ jobs: rustup toolchain install ${{ matrix.rust }} rustup default ${{ matrix.rust }} rustup component add rustfmt - - name: Build Workspace + - name: Build workspace in debug mode run: | cargo build env: CARGO_HOME: "/github/home/.cargo" - CARGO_TARGET_DIR: "/github/home/target" + CARGO_TARGET_DIR: "/github/home/target/debug" + - name: Build workspace in release mode + run: | + cargo check --release + env: + CARGO_HOME: "/github/home/.cargo" + CARGO_TARGET_DIR: "/github/home/target/release" - name: Check DataFusion Build without default features run: | cargo check --no-default-features -p datafusion @@ -230,6 +236,55 @@ jobs: # do not produce debug symbols to keep memory usage down RUSTFLAGS: "-C debuginfo=0" + test-datafusion-pyarrow: + needs: [linux-build-lib] + runs-on: ubuntu-latest + strategy: + matrix: + arch: [amd64] + rust: [stable] + container: + image: ${{ matrix.arch }}/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. 
+ RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v2 + with: + submodules: true + - name: Cache Cargo + uses: actions/cache@v2 + with: + path: /github/home/.cargo + # this key equals the ones on `linux-build-lib` for re-use + key: cargo-cache- + - name: Cache Rust dependencies + uses: actions/cache@v2 + with: + path: /github/home/target + # this key equals the ones on `linux-build-lib` for re-use + key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }} + - uses: actions/setup-python@v2 + with: + python-version: "3.8" + - name: Install PyArrow + run: | + echo "LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV + python -m pip install pyarrow + - name: Setup Rust toolchain + run: | + rustup toolchain install ${{ matrix.rust }} + rustup default ${{ matrix.rust }} + rustup component add rustfmt + - name: Run tests + run: | + cd datafusion + cargo test --features=pyarrow + env: + CARGO_HOME: "/github/home/.cargo" + CARGO_TARGET_DIR: "/github/home/target" + lint: name: Lint runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index 5af182e873db..a988927a1f97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,8 @@ [workspace] members = [ "datafusion", + "datafusion-common", + "datafusion-expr", "datafusion-cli", "datafusion-examples", "benchmarks", @@ -33,5 +35,5 @@ lto = true codegen-units = 1 [patch.crates-io] -#arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", branch = "main" } -#parquet2 = { git = "https://github.com/jorgecarleitao/parquet2.git", branch = "main" } +arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", branch = "main" } +parquet2 = { git = "https://github.com/jorgecarleitao/parquet2.git", branch = "main" } diff --git a/README.md b/README.md index 25fc16c8956c..dc350f69bb9c 100644 --- a/README.md +++ b/README.md @@ -73,361 +73,23 @@ Here are some of the projects known to use DataFusion: ## Example Usage -Run a SQL query against data stored in a CSV: +Please see [example usage](https://arrow.apache.org/datafusion/user-guide/example-usage.html) to find how to use DataFusion. -```rust -use datafusion::prelude::*; -use datafusion::arrow::record_batch::RecordBatch; - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - // register the table - let mut ctx = ExecutionContext::new(); - ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?; - - // create a plan to run a SQL query - let df = ctx.sql("SELECT a, MIN(b) FROM example GROUP BY a LIMIT 100").await?; - - // execute and print results - df.show().await?; - Ok(()) -} -``` - -Use the DataFrame API to process data stored in a CSV: - -```rust -use datafusion::prelude::*; -use datafusion::arrow::record_batch::RecordBatch; - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - // create the dataframe - let mut ctx = ExecutionContext::new(); - let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; - - let df = df.filter(col("a").lt_eq(col("b")))? - .aggregate(vec![col("a")], vec![min(col("b"))])?; - - // execute and print results - df.show_limit(100).await?; - Ok(()) -} -``` - -Both of these examples will produce - -```text -+---+--------+ -| a | MIN(b) | -+---+--------+ -| 1 | 2 | -+---+--------+ -``` - -## Using DataFusion as a library - -DataFusion is [published on crates.io](https://crates.io/crates/datafusion), and is [well documented on docs.rs](https://docs.rs/datafusion/). 
- -To get started, add the following to your `Cargo.toml` file: - -```toml -[dependencies] -datafusion = "6.0.0" -``` - -## Using DataFusion as a binary - -DataFusion also includes a simple command-line interactive SQL utility. See the [CLI reference](https://arrow.apache.org/datafusion/cli/index.html) for more information. - -# Roadmap - -A quarterly roadmap will be published to give the DataFusion community visibility into the priorities of the projects contributors. This roadmap is not binding. - -## 2022 Q1 - -### DataFusion Core - -- Publish official Arrow2 branch -- Implementation of memory manager (i.e. to enable spilling to disk as needed) - -### Benchmarking - -- Inclusion in Db-Benchmark with all quries covered -- All TPCH queries covered - -### Performance Improvements - -- Predicate evaluation -- Improve multi-column comparisons (that can't be vectorized at the moment) -- Null constant support - -### New Features - -- Read JSON as table -- Simplify DDL with Datafusion-Cli -- Add Decimal128 data type and the attendant features such as Arrow Kernel and UDF support -- Add new experimental e-graph based optimizer - -### Ballista - -- Begin work on design documents and plan / priorities for development - -### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib])) - -- Stable S3 support -- Begin design discussions and prototyping of a stream provider - -## Beyond 2022 Q1 - -There is no clear timeline for the below, but community members have expressed interest in working on these topics. - -### DataFusion Core - -- Custom SQL support -- Split DataFusion into multiple crates -- Push based query execution and code generation - -### Ballista - -- Evolve architecture so that it can be deployed in a multi-tenant cloud native environment -- Ensure Ballista is scalable, elastic, and stable for production usage -- Develop distributed ML capabilities - -# Status - -## General - -- [x] SQL Parser -- [x] SQL Query Planner -- [x] Query Optimizer -- [x] Constant folding -- [x] Join Reordering -- [x] Limit Pushdown -- [x] Projection push down -- [x] Predicate push down -- [x] Type coercion -- [x] Parallel query execution - -## SQL Support - -- [x] Projection -- [x] Filter (WHERE) -- [x] Filter post-aggregate (HAVING) -- [x] Limit -- [x] Aggregate -- [x] Common math functions -- [x] cast -- [x] try_cast -- [x] [`VALUES` lists](https://www.postgresql.org/docs/current/queries-values.html) -- Postgres compatible String functions - - [x] ascii - - [x] bit_length - - [x] btrim - - [x] char_length - - [x] character_length - - [x] chr - - [x] concat - - [x] concat_ws - - [x] initcap - - [x] left - - [x] length - - [x] lpad - - [x] ltrim - - [x] octet_length - - [x] regexp_replace - - [x] repeat - - [x] replace - - [x] reverse - - [x] right - - [x] rpad - - [x] rtrim - - [x] split_part - - [x] starts_with - - [x] strpos - - [x] substr - - [x] to_hex - - [x] translate - - [x] trim -- Miscellaneous/Boolean functions - - [x] nullif -- Approximation functions - - [x] approx_distinct -- Common date/time functions - - [ ] Basic date functions - - [ ] Basic time functions - - [x] Basic timestamp functions - - [x] [to_timestamp](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp) - - [x] [to_timestamp_millis](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_millis) - - [x] [to_timestamp_micros](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_micros) - - [x] [to_timestamp_seconds](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_seconds) 
-- nested functions - - [x] Array of columns -- [x] Schema Queries - - [x] SHOW TABLES - - [x] SHOW COLUMNS - - [x] information_schema.{tables, columns} - - [ ] information_schema other views -- [x] Sorting -- [ ] Nested types -- [ ] Lists -- [x] Subqueries -- [x] Common table expressions -- [x] Set Operations - - [x] UNION ALL - - [x] UNION - - [x] INTERSECT - - [x] INTERSECT ALL - - [x] EXCEPT - - [x] EXCEPT ALL -- [x] Joins - - [x] INNER JOIN - - [x] LEFT JOIN - - [x] RIGHT JOIN - - [x] FULL JOIN - - [x] CROSS JOIN -- [ ] Window - - [x] Empty window - - [x] Common window functions - - [x] Window with PARTITION BY clause - - [x] Window with ORDER BY clause - - [ ] Window with FILTER clause - - [ ] [Window with custom WINDOW FRAME](https://github.com/apache/arrow-datafusion/issues/361) - - [ ] UDF and UDAF for window functions - -## Data Sources - -- [x] CSV -- [x] Parquet primitive types -- [ ] Parquet nested types - -## Extensibility - -DataFusion is designed to be extensible at all points. To that end, you can provide your own custom: - -- [x] User Defined Functions (UDFs) -- [x] User Defined Aggregate Functions (UDAFs) -- [x] User Defined Table Source (`TableProvider`) for tables -- [x] User Defined `Optimizer` passes (plan rewrites) -- [x] User Defined `LogicalPlan` nodes -- [x] User Defined `ExecutionPlan` nodes - -## Rust Version Compatbility - -This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. - -# Supported SQL - -This library currently supports many SQL constructs, including - -- `CREATE EXTERNAL TABLE X STORED AS PARQUET LOCATION '...';` to register a table's locations -- `SELECT ... FROM ...` together with any expression -- `ALIAS` to name an expression -- `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)` -- Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`. -- `WHERE` to filter -- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population) -- `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST` - -## Supported Functions - -DataFusion strives to implement a subset of the [PostgreSQL SQL dialect](https://www.postgresql.org/docs/current/functions.html) where possible. We explicitly choose a single dialect to maximize interoperability with other tools and allow reuse of the PostgreSQL documents and tutorials as much as possible. - -Currently, only a subset of the PostgreSQL dialect is implemented, and we will document any deviations. - -## Schema Metadata / Information Schema Support - -DataFusion supports the showing metadata about the tables available. This information can be accessed using the views of the ISO SQL `information_schema` schema or the DataFusion specific `SHOW TABLES` and `SHOW COLUMNS` commands. - -More information can be found in the [Postgres docs](https://www.postgresql.org/docs/13/infoschema-schema.html)). 
- -To show tables available for use in DataFusion, use the `SHOW TABLES` command or the `information_schema.tables` view: - -```sql -> show tables; -+---------------+--------------------+------------+------------+ -| table_catalog | table_schema | table_name | table_type | -+---------------+--------------------+------------+------------+ -| datafusion | public | t | BASE TABLE | -| datafusion | information_schema | tables | VIEW | -+---------------+--------------------+------------+------------+ - -> select * from information_schema.tables; - -+---------------+--------------------+------------+--------------+ -| table_catalog | table_schema | table_name | table_type | -+---------------+--------------------+------------+--------------+ -| datafusion | public | t | BASE TABLE | -| datafusion | information_schema | TABLES | SYSTEM TABLE | -+---------------+--------------------+------------+--------------+ -``` - -To show the schema of a table in DataFusion, use the `SHOW COLUMNS` command or the or `information_schema.columns` view: - -```sql -> show columns from t; -+---------------+--------------+------------+-------------+-----------+-------------+ -| table_catalog | table_schema | table_name | column_name | data_type | is_nullable | -+---------------+--------------+------------+-------------+-----------+-------------+ -| datafusion | public | t | a | Int32 | NO | -| datafusion | public | t | b | Utf8 | NO | -| datafusion | public | t | c | Float32 | NO | -+---------------+--------------+------------+-------------+-----------+-------------+ - -> select table_name, column_name, ordinal_position, is_nullable, data_type from information_schema.columns; -+------------+-------------+------------------+-------------+-----------+ -| table_name | column_name | ordinal_position | is_nullable | data_type | -+------------+-------------+------------------+-------------+-----------+ -| t | a | 0 | NO | Int32 | -| t | b | 1 | NO | Utf8 | -| t | c | 2 | NO | Float32 | -+------------+-------------+------------------+-------------+-----------+ -``` - -## Supported Data Types - -DataFusion uses Arrow, and thus the Arrow type system, for query -execution. The SQL types from -[sqlparser-rs](https://github.com/ballista-compute/sqlparser-rs/blob/main/src/ast/data_type.rs#L57) -are mapped to Arrow types according to the following table - -| SQL Data Type | Arrow DataType | -| ------------- | --------------------------------- | -| `CHAR` | `Utf8` | -| `VARCHAR` | `Utf8` | -| `UUID` | _Not yet supported_ | -| `CLOB` | _Not yet supported_ | -| `BINARY` | _Not yet supported_ | -| `VARBINARY` | _Not yet supported_ | -| `DECIMAL` | `Float64` | -| `FLOAT` | `Float32` | -| `SMALLINT` | `Int16` | -| `INT` | `Int32` | -| `BIGINT` | `Int64` | -| `REAL` | `Float32` | -| `DOUBLE` | `Float64` | -| `BOOLEAN` | `Boolean` | -| `DATE` | `Date32` | -| `TIME` | `Time64(TimeUnit::Millisecond)` | -| `TIMESTAMP` | `Timestamp(TimeUnit::Nanosecond)` | -| `INTERVAL` | _Not yet supported_ | -| `REGCLASS` | _Not yet supported_ | -| `TEXT` | _Not yet supported_ | -| `BYTEA` | _Not yet supported_ | -| `CUSTOM` | _Not yet supported_ | -| `ARRAY` | _Not yet supported_ | - -# Roadmap +## Roadmap Please see [Roadmap](docs/source/specification/roadmap.md) for information of where the project is headed. -# Architecture Overview +## Architecture Overview There is no formal document describing DataFusion's architecture yet, but the following presentations offer a good overview of its different components and how they interact together. 
- (March 2021): The DataFusion architecture is described in _Query Engine Design and the Rust-Based DataFusion in Apache Arrow_: [recording](https://www.youtube.com/watch?v=K6eCAVEk4kU) (DataFusion content starts [~ 15 minutes in](https://www.youtube.com/watch?v=K6eCAVEk4kU&t=875s)) and [slides](https://www.slideshare.net/influxdata/influxdb-iox-tech-talks-query-engine-design-and-the-rustbased-datafusion-in-apache-arrow-244161934) - (February 2021): How DataFusion is used within the Ballista Project is described in \*Ballista: Distributed Compute with Rust and Apache Arrow: [recording](https://www.youtube.com/watch?v=ZZHQaOap9pQ) -# Developer's guide +## User's guide + +Please see [User Guide](https://arrow.apache.org/datafusion/) for more information about DataFusion. + +## Developer's guide Please see [Developers Guide](DEVELOPERS.md) for information about developing DataFusion. diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml index 4ec1abe77654..dff5d1a5c584 100644 --- a/ballista/rust/client/Cargo.toml +++ b/ballista/rust/client/Cargo.toml @@ -35,7 +35,7 @@ log = "0.4" tokio = "1.0" tempfile = "3" sqlparser = "0.13" -parking_lot = "0.11" +parking_lot = "0.12" datafusion = { path = "../../../datafusion", version = "6.0.0" } diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index cdbbbf064371..dfc6a6c75494 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -50,7 +50,7 @@ arrow = { package = "arrow2", version="0.9", features = ["io_ipc", "io_flight"] datafusion = { path = "../../../datafusion", version = "6.0.0" } -parking_lot = "0.11" +parking_lot = "0.12" [dev-dependencies] tempfile = "3" diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index 310affdc01f8..ad48962b7d01 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -46,7 +46,7 @@ tokio-stream = { version = "0.1", features = ["net"] } tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } hyper = "0.14.4" -parking_lot = "0.11" +parking_lot = "0.12" [dev-dependencies] diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml index fdeb7e726d57..8acb13ba8963 100644 --- a/ballista/rust/scheduler/Cargo.toml +++ b/ballista/rust/scheduler/Cargo.toml @@ -53,7 +53,7 @@ tokio-stream = { version = "0.1", features = ["net"], optional = true } tonic = "0.6" tower = { version = "0.4" } warp = "0.3" -parking_lot = "0.11" +parking_lot = "0.12" [dev-dependencies] ballista-core = { path = "../core", version = "0.6.0" } diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 09df15b57bc6..26ccdaf25fab 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -17,7 +17,8 @@ [package] name = "datafusion-cli" -version = "5.1.0" +description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model. It supports executing SQL queries against CSV and Parquet files as well as querying directly against in-memory data." 
+version = "6.0.0" authors = ["Apache Arrow "] edition = "2021" keywords = [ "arrow", "datafusion", "ballista", "query", "sql" ] diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index fa37059039a2..f6bedc2148b9 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -20,7 +20,8 @@ use crate::context::Context; use crate::functions::{display_all_functions, Function}; use crate::print_format::PrintFormat; -use crate::print_options::{self, PrintOptions}; +use crate::print_options::PrintOptions; +use clap::ArgEnum; use datafusion::arrow::array::{ArrayRef, Utf8Array}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::error::{DataFusionError, Result}; @@ -209,10 +210,14 @@ impl OutputFormat { Self::ChangeFormat(format) => { if let Ok(format) = format.parse::() { print_options.format = format; - println!("Output format is {}.", print_options.format); + println!("Output format is {:?}.", print_options.format); Ok(()) } else { - Err(DataFusionError::Execution(format!("{} is not a valid format type [possible values: csv, tsv, table, json, ndjson]", format))) + Err(DataFusionError::Execution(format!( + "{:?} is not a valid format type [possible values: {:?}]", + format, + PrintFormat::value_variants() + ))) } } } diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index acc340db8222..17b329b86d9b 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -21,19 +21,14 @@ use crate::{ command::{Command, OutputFormat}, context::Context, helper::CliHelper, - print_format::{all_print_formats, PrintFormat}, print_options::PrintOptions, }; -use datafusion::error::{DataFusionError, Result}; -use datafusion::record_batch::RecordBatch; -use rustyline::config::Config; +use datafusion::error::Result; use rustyline::error::ReadlineError; use rustyline::Editor; use std::fs::File; use std::io::prelude::*; use std::io::BufReader; -use std::str::FromStr; -use std::sync::Arc; use std::time::Instant; /// run and execute SQL statements and commands from a file, against a context with the given print options @@ -108,7 +103,7 @@ pub async fn exec_from_repl(ctx: &mut Context, print_options: &mut PrintOptions) ); } } else { - println!("Output format is {}.", print_options.format); + println!("Output format is {:?}.", print_options.format); } } _ => { diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 7839d4f69bcb..224f990c440a 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -20,10 +20,8 @@ use arrow::array::{ArrayRef, Utf8Array}; use arrow::chunk::Chunk; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::io::print; -use datafusion::error::{DataFusionError, Result}; +use datafusion::error::Result; use datafusion::field_util::SchemaExt; -use datafusion::physical_plan::ColumnarValue::Array; -use datafusion::record_batch::RecordBatch; use std::fmt; use std::str::FromStr; use std::sync::Arc; diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs index b2bcdd3e48a6..b75be331259b 100644 --- a/datafusion-cli/src/lib.rs +++ b/datafusion-cli/src/lib.rs @@ -16,7 +16,6 @@ // under the License. 
#![doc = include_str!("../README.md")] -#![allow(unused_imports)] pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION"); pub mod command; diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 4cb9e9ddef14..788bb27f899a 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -15,14 +15,11 @@ // specific language governing permissions and limitations // under the License. -use clap::{crate_version, App, Arg}; +use clap::Parser; use datafusion::error::Result; use datafusion::execution::context::ExecutionConfig; use datafusion_cli::{ - context::Context, - exec, - print_format::{all_print_formats, PrintFormat}, - print_options::PrintOptions, + context::Context, exec, print_format::PrintFormat, print_options::PrintOptions, DATAFUSION_CLI_VERSION, }; use std::env; @@ -30,117 +27,84 @@ use std::fs::File; use std::io::BufReader; use std::path::Path; +#[derive(Debug, Parser, PartialEq)] +#[clap(author, version, about, long_about= None)] +struct Args { + #[clap( + short = 'p', + long, + help = "Path to your data, default to current directory", + validator(is_valid_data_dir) + )] + data_path: Option, + + #[clap( + short = 'c', + long, + help = "The batch size of each query, or use DataFusion default", + validator(is_valid_batch_size) + )] + batch_size: Option, + + #[clap( + short, + long, + multiple_values = true, + help = "Execute commands from file(s), then exit", + validator(is_valid_file) + )] + file: Vec, + + #[clap(long, arg_enum, default_value_t = PrintFormat::Table)] + format: PrintFormat, + + #[clap(long, help = "Ballista scheduler host")] + host: Option, + + #[clap(long, help = "Ballista scheduler port")] + port: Option, + + #[clap( + short, + long, + help = "Reduce printing other than the results and work quietly" + )] + quiet: bool, +} + #[tokio::main] pub async fn main() -> Result<()> { - let matches = App::new("DataFusion") - .version(crate_version!()) - .about( - "DataFusion is an in-memory query engine that uses Apache Arrow \ - as the memory model. 
It supports executing SQL queries against CSV and \ - Parquet files as well as querying directly against in-memory data.", - ) - .arg( - Arg::new("data-path") - .help("Path to your data, default to current directory") - .short('p') - .long("data-path") - .validator(is_valid_data_dir) - .takes_value(true), - ) - .arg( - Arg::new("batch-size") - .help("The batch size of each query, or use DataFusion default") - .short('c') - .long("batch-size") - .validator(is_valid_batch_size) - .takes_value(true), - ) - .arg( - Arg::new("file") - .help("Execute commands from file(s), then exit") - .short('f') - .long("file") - .multiple_occurrences(true) - .validator(is_valid_file) - .takes_value(true), - ) - .arg( - Arg::new("format") - .help("Output format") - .long("format") - .default_value("table") - .possible_values( - &all_print_formats() - .iter() - .map(|format| format.to_string()) - .collect::>() - .iter() - .map(|i| i.as_str()) - .collect::>(), - ) - .takes_value(true), - ) - .arg( - Arg::new("host") - .help("Ballista scheduler host") - .long("host") - .takes_value(true), - ) - .arg( - Arg::new("port") - .help("Ballista scheduler port") - .long("port") - .takes_value(true), - ) - .arg( - Arg::new("quiet") - .help("Reduce printing other than the results and work quietly") - .short('q') - .long("quiet") - .takes_value(false), - ) - .get_matches(); - - let quiet = matches.is_present("quiet"); - - if !quiet { - println!("DataFusion CLI v{}\n", DATAFUSION_CLI_VERSION); - } + let args = Args::parse(); - let host = matches.value_of("host"); - let port = matches - .value_of("port") - .and_then(|port| port.parse::().ok()); + if !args.quiet { + println!("DataFusion CLI v{}", DATAFUSION_CLI_VERSION); + } - if let Some(path) = matches.value_of("data-path") { + if let Some(ref path) = args.data_path { let p = Path::new(path); env::set_current_dir(&p).unwrap(); }; let mut execution_config = ExecutionConfig::new().with_information_schema(true); - if let Some(batch_size) = matches - .value_of("batch-size") - .and_then(|size| size.parse::().ok()) - { + if let Some(batch_size) = args.batch_size { execution_config = execution_config.with_batch_size(batch_size); }; - let mut ctx: Context = match (host, port) { - (Some(h), Some(p)) => Context::new_remote(h, p)?, + let mut ctx: Context = match (args.host, args.port) { + (Some(ref h), Some(p)) => Context::new_remote(h, p)?, _ => Context::new_local(&execution_config), }; - let format = matches - .value_of("format") - .expect("No format is specified") - .parse::() - .expect("Invalid format"); - - let mut print_options = PrintOptions { format, quiet }; + let mut print_options = PrintOptions { + format: args.format, + quiet: args.quiet, + }; - if let Some(file_paths) = matches.values_of("file") { - let files = file_paths + let files = args.file; + if !files.is_empty() { + let files = files + .into_iter() .map(|file_path| File::open(file_path).unwrap()) .collect::>(); for file in files { diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index fa8bf2384cf3..5a176e0a0928 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -17,15 +17,14 @@ //! 
Print format variants use arrow::io::json::write::{JsonArray, JsonFormat, LineDelimited}; -use datafusion::arrow::io::{csv::write, print}; +use datafusion::arrow::io::csv::write; use datafusion::error::{DataFusionError, Result}; use datafusion::field_util::SchemaExt; use datafusion::record_batch::RecordBatch; -use std::fmt; use std::str::FromStr; /// Allow records to be printed in different formats -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone)] pub enum PrintFormat { Csv, Tsv, @@ -34,40 +33,11 @@ pub enum PrintFormat { NdJson, } -/// returns all print formats -pub fn all_print_formats() -> Vec { - vec![ - PrintFormat::Csv, - PrintFormat::Tsv, - PrintFormat::Table, - PrintFormat::Json, - PrintFormat::NdJson, - ] -} - impl FromStr for PrintFormat { - type Err = (); - fn from_str(s: &str) -> std::result::Result { - match s.to_lowercase().as_str() { - "csv" => Ok(Self::Csv), - "tsv" => Ok(Self::Tsv), - "table" => Ok(Self::Table), - "json" => Ok(Self::Json), - "ndjson" => Ok(Self::NdJson), - _ => Err(()), - } - } -} + type Err = String; -impl fmt::Display for PrintFormat { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Self::Csv => write!(f, "csv"), - Self::Tsv => write!(f, "tsv"), - Self::Table => write!(f, "table"), - Self::Json => write!(f, "json"), - Self::NdJson => write!(f, "ndjson"), - } + fn from_str(s: &str) -> std::result::Result { + clap::ArgEnum::from_str(s, true) } } @@ -146,38 +116,6 @@ mod tests { use datafusion::arrow::datatypes::{DataType, Field, Schema}; use std::sync::Arc; - #[test] - fn test_from_str() { - let format = "csv".parse::().unwrap(); - assert_eq!(PrintFormat::Csv, format); - - let format = "tsv".parse::().unwrap(); - assert_eq!(PrintFormat::Tsv, format); - - let format = "json".parse::().unwrap(); - assert_eq!(PrintFormat::Json, format); - - let format = "ndjson".parse::().unwrap(); - assert_eq!(PrintFormat::NdJson, format); - - let format = "table".parse::().unwrap(); - assert_eq!(PrintFormat::Table, format); - } - - #[test] - fn test_to_str() { - assert_eq!("csv", PrintFormat::Csv.to_string()); - assert_eq!("table", PrintFormat::Table.to_string()); - assert_eq!("tsv", PrintFormat::Tsv.to_string()); - assert_eq!("json", PrintFormat::Json.to_string()); - assert_eq!("ndjson", PrintFormat::NdJson.to_string()); - } - - #[test] - fn test_from_str_failure() { - assert!("pretty".parse::().is_err()); - } - #[test] fn test_print_batches_with_sep() { let batches = vec![]; diff --git a/datafusion-common/Cargo.toml b/datafusion-common/Cargo.toml new file mode 100644 index 000000000000..08f228fc134f --- /dev/null +++ b/datafusion-common/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "datafusion-common" +description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" +version = "6.0.0" +homepage = "https://github.com/apache/arrow-datafusion" +repository = "https://github.com/apache/arrow-datafusion" +readme = "README.md" +authors = ["Apache Arrow "] +license = "Apache-2.0" +keywords = [ "arrow", "query", "sql" ] +edition = "2021" +rust-version = "1.58" + +[lib] +name = "datafusion_common" +path = "src/lib.rs" + +[features] +pyarrow = ["pyo3"] + +[dependencies] +arrow = { package = "arrow2", version = "0.9", default-features = false } +parquet = { package = "parquet2", version = "0.10", default_features = false, features = ["stream"] } + +pyo3 = { version = "0.15", optional = true } +sqlparser = "0.13" +ordered-float = "2.10" diff --git a/datafusion-common/README.md b/datafusion-common/README.md new file mode 100644 index 000000000000..8c44d78ef47f --- /dev/null +++ b/datafusion-common/README.md @@ -0,0 +1,24 @@ + + +# DataFusion Common + +This is an internal module for the most fundamental types of [DataFusion][df]. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion-common/src/column.rs b/datafusion-common/src/column.rs new file mode 100644 index 000000000000..02faa24b0346 --- /dev/null +++ b/datafusion-common/src/column.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Column + +use crate::{DFSchema, DataFusionError, Result}; +use std::collections::HashSet; +use std::convert::Infallible; +use std::fmt; +use std::str::FromStr; +use std::sync::Arc; + +/// A named reference to a qualified field in a schema. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct Column { + /// relation/table name. + pub relation: Option, + /// field/column name. + pub name: String, +} + +impl Column { + /// Create Column from unqualified name. 
+ pub fn from_name(name: impl Into) -> Self { + Self { + relation: None, + name: name.into(), + } + } + + /// Deserialize a fully qualified name string into a column + pub fn from_qualified_name(flat_name: &str) -> Self { + use sqlparser::tokenizer::Token; + + let dialect = sqlparser::dialect::GenericDialect {}; + let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name); + if let Ok(tokens) = tokenizer.tokenize() { + if let [Token::Word(relation), Token::Period, Token::Word(name)] = + tokens.as_slice() + { + return Column { + relation: Some(relation.value.clone()), + name: name.value.clone(), + }; + } + } + // any expression that's not in the form of `foo.bar` will be treated as unqualified column + // name + Column { + relation: None, + name: String::from(flat_name), + } + } + + /// Serialize column into a flat name string + pub fn flat_name(&self) -> String { + match &self.relation { + Some(r) => format!("{}.{}", r, self.name), + None => self.name.clone(), + } + } + + // Internal implementation of normalize + pub fn normalize_with_schemas( + self, + schemas: &[&Arc], + using_columns: &[HashSet], + ) -> Result { + if self.relation.is_some() { + return Ok(self); + } + + for schema in schemas { + let fields = schema.fields_with_unqualified_name(&self.name); + match fields.len() { + 0 => continue, + 1 => { + return Ok(fields[0].qualified_column()); + } + _ => { + // More than 1 fields in this schema have their names set to self.name. + // + // This should only happen when a JOIN query with USING constraint references + // join columns using unqualified column name. For example: + // + // ```sql + // SELECT id FROM t1 JOIN t2 USING(id) + // ``` + // + // In this case, both `t1.id` and `t2.id` will match unqualified column `id`. + // We will use the relation from the first matched field to normalize self. + + // Compare matched fields with one USING JOIN clause at a time + for using_col in using_columns { + let all_matched = fields + .iter() + .all(|f| using_col.contains(&f.qualified_column())); + // All matched fields belong to the same using column set, in orther words + // the same join clause. We simply pick the qualifer from the first match. + if all_matched { + return Ok(fields[0].qualified_column()); + } + } + } + } + } + + Err(DataFusionError::Plan(format!( + "Column {} not found in provided schemas", + self + ))) + } +} + +impl From<&str> for Column { + fn from(c: &str) -> Self { + Self::from_qualified_name(c) + } +} + +impl FromStr for Column { + type Err = Infallible; + + fn from_str(s: &str) -> std::result::Result { + Ok(s.into()) + } +} + +impl fmt::Display for Column { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.relation { + Some(r) => write!(f, "#{}.{}", r, self.name), + None => write!(f, "#{}", self.name), + } + } +} diff --git a/datafusion-common/src/dfschema.rs b/datafusion-common/src/dfschema.rs new file mode 100644 index 000000000000..55c5c4c7085c --- /dev/null +++ b/datafusion-common/src/dfschema.rs @@ -0,0 +1,720 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! DFSchema is an extended schema struct that DataFusion uses to provide support for +//! fields with optional relation names. + +use std::collections::HashSet; +use std::convert::TryFrom; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::Column; + +use crate::field_util::{FieldExt, SchemaExt}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use std::fmt::{Display, Formatter}; + +/// A reference-counted reference to a `DFSchema`. +pub type DFSchemaRef = Arc; + +/// DFSchema wraps an Arrow schema and adds relation names +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DFSchema { + /// Fields + fields: Vec, +} + +impl DFSchema { + /// Creates an empty `DFSchema` + pub fn empty() -> Self { + Self { fields: vec![] } + } + + /// Create a new `DFSchema` + pub fn new(fields: Vec) -> Result { + let mut qualified_names = HashSet::new(); + let mut unqualified_names = HashSet::new(); + + for field in &fields { + if let Some(qualifier) = field.qualifier() { + if !qualified_names.insert((qualifier, field.name())) { + return Err(DataFusionError::Plan(format!( + "Schema contains duplicate qualified field name '{}'", + field.qualified_name() + ))); + } + } else if !unqualified_names.insert(field.name()) { + return Err(DataFusionError::Plan(format!( + "Schema contains duplicate unqualified field name '{}'", + field.name() + ))); + } + } + + // check for mix of qualified and unqualified field with same unqualified name + // note that we need to sort the contents of the HashSet first so that errors are + // deterministic + let mut qualified_names = qualified_names + .iter() + .map(|(l, r)| (l.as_str(), r.to_owned())) + .collect::>(); + qualified_names.sort_by(|a, b| { + let a = format!("{}.{}", a.0, a.1); + let b = format!("{}.{}", b.0, b.1); + a.cmp(&b) + }); + for (qualifier, name) in &qualified_names { + if unqualified_names.contains(name) { + return Err(DataFusionError::Plan(format!( + "Schema contains qualified field name '{}.{}' \ + and unqualified field name '{}' which would be ambiguous", + qualifier, name, name + ))); + } + } + Ok(Self { fields }) + } + + /// Create a `DFSchema` from an Arrow schema + pub fn try_from_qualified_schema(qualifier: &str, schema: &Schema) -> Result { + Self::new( + schema + .fields() + .iter() + .map(|f| DFField::from_qualified(qualifier, f.clone())) + .collect(), + ) + } + + /// Combine two schemas + pub fn join(&self, schema: &DFSchema) -> Result { + let mut fields = self.fields.clone(); + fields.extend_from_slice(schema.fields().as_slice()); + Self::new(fields) + } + + /// Merge a schema into self + pub fn merge(&mut self, other_schema: &DFSchema) { + for field in other_schema.fields() { + // skip duplicate columns + let duplicated_field = match field.qualifier() { + Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(), + // for unqualifed columns, check as unqualified name + None => self.field_with_unqualified_name(field.name()).is_ok(), + }; + if !duplicated_field { + self.fields.push(field.clone()); + } + } + } + + /// Get a list of fields + pub fn fields(&self) -> &Vec { + 
&self.fields + } + + /// Returns an immutable reference of a specific `Field` instance selected using an + /// offset within the internal `fields` vector + pub fn field(&self, i: usize) -> &DFField { + &self.fields[i] + } + + /// Find the index of the column with the given unqualified name + pub fn index_of(&self, name: &str) -> Result { + for i in 0..self.fields.len() { + if self.fields[i].name() == name { + return Ok(i); + } + } + Err(DataFusionError::Plan(format!( + "No field named '{}'. Valid fields are {}.", + name, + self.get_field_names() + ))) + } + + fn index_of_column_by_name( + &self, + qualifier: Option<&str>, + name: &str, + ) -> Result { + let mut matches = self + .fields + .iter() + .enumerate() + .filter(|(_, field)| match (qualifier, &field.qualifier) { + // field to lookup is qualified. + // current field is qualified and not shared between relations, compare both + // qualifier and name. + (Some(q), Some(field_q)) => q == field_q && field.name() == name, + // field to lookup is qualified but current field is unqualified. + (Some(_), None) => false, + // field to lookup is unqualified, no need to compare qualifier + (None, Some(_)) | (None, None) => field.name() == name, + }) + .map(|(idx, _)| idx); + match matches.next() { + None => Err(DataFusionError::Plan(format!( + "No field named '{}.{}'. Valid fields are {}.", + qualifier.unwrap_or(""), + name, + self.get_field_names() + ))), + Some(idx) => match matches.next() { + None => Ok(idx), + // found more than one matches + Some(_) => Err(DataFusionError::Internal(format!( + "Ambiguous reference to qualified field named '{}.{}'", + qualifier.unwrap_or(""), + name + ))), + }, + } + } + + /// Find the index of the column with the given qualifier and name + pub fn index_of_column(&self, col: &Column) -> Result { + self.index_of_column_by_name(col.relation.as_deref(), &col.name) + } + + /// Find the field with the given name + pub fn field_with_name( + &self, + qualifier: Option<&str>, + name: &str, + ) -> Result<&DFField> { + if let Some(qualifier) = qualifier { + self.field_with_qualified_name(qualifier, name) + } else { + self.field_with_unqualified_name(name) + } + } + + /// Find all fields match the given name + pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { + self.fields + .iter() + .filter(|field| field.name() == name) + .collect() + } + + /// Find the field with the given name + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { + let matches = self.fields_with_unqualified_name(name); + match matches.len() { + 0 => Err(DataFusionError::Plan(format!( + "No field with unqualified name '{}'. 
Valid fields are {}.", + name, + self.get_field_names() + ))), + 1 => Ok(matches[0]), + _ => Err(DataFusionError::Plan(format!( + "Ambiguous reference to field named '{}'", + name + ))), + } + } + + /// Find the field with the given qualified name + pub fn field_with_qualified_name( + &self, + qualifier: &str, + name: &str, + ) -> Result<&DFField> { + let idx = self.index_of_column_by_name(Some(qualifier), name)?; + Ok(self.field(idx)) + } + + /// Find the field with the given qualified column + pub fn field_from_column(&self, column: &Column) -> Result<&DFField> { + match &column.relation { + Some(r) => self.field_with_qualified_name(r, &column.name), + None => self.field_with_unqualified_name(&column.name), + } + } + + /// Check to see if unqualified field names matches field names in Arrow schema + pub fn matches_arrow_schema(&self, arrow_schema: &Schema) -> bool { + self.fields + .iter() + .zip(arrow_schema.fields().iter()) + .all(|(dffield, arrowfield)| dffield.name() == arrowfield.name()) + } + + /// Strip all field qualifier in schema + pub fn strip_qualifiers(self) -> Self { + DFSchema { + fields: self + .fields + .into_iter() + .map(|f| f.strip_qualifier()) + .collect(), + } + } + + /// Replace all field qualifier with new value in schema + pub fn replace_qualifier(self, qualifier: &str) -> Self { + DFSchema { + fields: self + .fields + .into_iter() + .map(|f| { + DFField::new( + Some(qualifier), + f.name(), + f.data_type().to_owned(), + f.is_nullable(), + ) + }) + .collect(), + } + } + + /// Get comma-seperated list of field names for use in error messages + fn get_field_names(&self) -> String { + self.fields + .iter() + .map(|f| match f.qualifier() { + Some(qualifier) => format!("'{}.{}'", qualifier, f.name()), + None => format!("'{}'", f.name()), + }) + .collect::>() + .join(", ") + } +} + +impl From for Schema { + /// Convert DFSchema into a Schema + fn from(df_schema: DFSchema) -> Self { + Schema::new( + df_schema + .fields + .into_iter() + .map(|f| { + if f.qualifier().is_some() { + Field::new(f.name(), f.data_type().to_owned(), f.is_nullable()) + } else { + f.field + } + }) + .collect(), + ) + } +} + +impl From<&DFSchema> for Schema { + /// Convert DFSchema reference into a Schema + fn from(df_schema: &DFSchema) -> Self { + Schema::new(df_schema.fields.iter().map(|f| f.field.clone()).collect()) + } +} + +/// Create a `DFSchema` from an Arrow schema +impl TryFrom for DFSchema { + type Error = DataFusionError; + fn try_from(schema: Schema) -> std::result::Result { + Self::new( + schema + .fields() + .iter() + .map(|f| DFField::from(f.clone())) + .collect(), + ) + } +} + +impl From for SchemaRef { + fn from(df_schema: DFSchema) -> Self { + SchemaRef::new(df_schema.into()) + } +} + +/// Convenience trait to convert Schema like things to DFSchema and DFSchemaRef with fewer keystrokes +pub trait ToDFSchema +where + Self: Sized, +{ + /// Attempt to create a DSSchema + #[allow(clippy::wrong_self_convention)] + fn to_dfschema(self) -> Result; + + /// Attempt to create a DSSchemaRef + #[allow(clippy::wrong_self_convention)] + fn to_dfschema_ref(self) -> Result { + Ok(Arc::new(self.to_dfschema()?)) + } +} + +impl ToDFSchema for Schema { + #[allow(clippy::wrong_self_convention)] + fn to_dfschema(self) -> Result { + DFSchema::try_from(self) + } +} + +impl ToDFSchema for SchemaRef { + #[allow(clippy::wrong_self_convention)] + fn to_dfschema(self) -> Result { + // Attempt to use the Schema directly if there are no other + // references, otherwise clone + match Self::try_unwrap(self) { + 
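`field_with_unqualified_name` is deliberately strict: exactly one match resolves, zero or several are errors. A sketch of the ambiguous case, under the same crate assumption as above:

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::field_util::SchemaExt;
use datafusion_common::{DFSchema, Result};

fn main() -> Result<()> {
    let t = Schema::new(vec![Field::new("c0", DataType::Boolean, true)]);
    let joined = DFSchema::try_from_qualified_schema("t1", &t)?
        .join(&DFSchema::try_from_qualified_schema("t2", &t)?)?;
    // a qualified lookup is unambiguous
    assert!(joined.field_with_qualified_name("t1", "c0").is_ok());
    // both t1.c0 and t2.c0 match an unqualified "c0": "Ambiguous reference" error
    assert!(joined.field_with_unqualified_name("c0").is_err());
    Ok(())
}
```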
Ok(schema) => DFSchema::try_from(schema), + Err(schemaref) => DFSchema::try_from(schemaref.as_ref().clone()), + } + } +} + +impl ToDFSchema for Vec { + fn to_dfschema(self) -> Result { + DFSchema::new(self) + } +} + +impl Display for DFSchema { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "{}", + self.fields + .iter() + .map(|field| field.qualified_name()) + .collect::>() + .join(", ") + ) + } +} + +/// Provides schema information needed by [Expr] methods such as +/// [Expr::nullable] and [Expr::data_type]. +/// +/// Note that this trait is implemented for &[DFSchema] which is +/// widely used in the DataFusion codebase. +pub trait ExprSchema { + /// Is this column reference nullable? + fn nullable(&self, col: &Column) -> Result; + + /// What is the datatype of this column? + fn data_type(&self, col: &Column) -> Result<&DataType>; +} + +// Implement `ExprSchema` for `Arc` +impl> ExprSchema for P { + fn nullable(&self, col: &Column) -> Result { + self.as_ref().nullable(col) + } + + fn data_type(&self, col: &Column) -> Result<&DataType> { + self.as_ref().data_type(col) + } +} + +impl ExprSchema for DFSchema { + fn nullable(&self, col: &Column) -> Result { + Ok(self.field_from_column(col)?.is_nullable()) + } + + fn data_type(&self, col: &Column) -> Result<&DataType> { + Ok(self.field_from_column(col)?.data_type()) + } +} + +/// DFField wraps an Arrow field and adds an optional qualifier +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DFField { + /// Optional qualifier (usually a table or relation name) + qualifier: Option, + /// Arrow field definition + field: Field, +} + +impl DFField { + /// Creates a new `DFField` + pub fn new( + qualifier: Option<&str>, + name: &str, + data_type: DataType, + nullable: bool, + ) -> Self { + DFField { + qualifier: qualifier.map(|s| s.to_owned()), + field: Field::new(name, data_type, nullable), + } + } + + /// Create an unqualified field from an existing Arrow field + pub fn from(field: Field) -> Self { + Self { + qualifier: None, + field, + } + } + + /// Create a qualified field from an existing Arrow field + pub fn from_qualified(qualifier: &str, field: Field) -> Self { + Self { + qualifier: Some(qualifier.to_owned()), + field, + } + } + + /// Returns an immutable reference to the `DFField`'s unqualified name + pub fn name(&self) -> &str { + self.field.name() + } + + /// Returns an immutable reference to the `DFField`'s data-type + pub fn data_type(&self) -> &DataType { + self.field.data_type() + } + + /// Indicates whether this `DFField` supports null values + pub fn is_nullable(&self) -> bool { + self.field.is_nullable() + } + + /// Returns a string to the `DFField`'s qualified name + pub fn qualified_name(&self) -> String { + if let Some(qualifier) = &self.qualifier { + format!("{}.{}", qualifier, self.field.name()) + } else { + self.field.name().to_owned() + } + } + + /// Builds a qualified column based on self + pub fn qualified_column(&self) -> Column { + Column { + relation: self.qualifier.clone(), + name: self.field.name().to_string(), + } + } + + /// Builds an unqualified column based on self + pub fn unqualified_column(&self) -> Column { + Column { + relation: None, + name: self.field.name().to_string(), + } + } + + /// Get the optional qualifier + pub fn qualifier(&self) -> Option<&String> { + self.qualifier.as_ref() + } + + /// Get the arrow field + pub fn field(&self) -> &Field { + &self.field + } + + /// Return field with qualifier stripped + pub fn strip_qualifier(mut self) -> Self { + self.qualifier = None; 
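`ExprSchema` is the narrow interface that `Expr::nullable` and `Expr::data_type` rely on; a minimal lookup sketch, same crate assumption as above:

```rust
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, ExprSchema, Result};

fn main() -> Result<()> {
    let field = DFField::new(Some("t1"), "c0", DataType::Int64, true);
    let schema = DFSchema::new(vec![field.clone()])?;
    // resolve by qualified column; works the same through Arc<DFSchema>
    let col = field.qualified_column();
    assert!(schema.nullable(&col)?);
    assert_eq!(&DataType::Int64, schema.data_type(&col)?);
    Ok(())
}
```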
+ self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::DataType; + + #[test] + fn from_unqualified_field() { + let field = Field::new("c0", DataType::Boolean, true); + let field = DFField::from(field); + assert_eq!("c0", field.name()); + assert_eq!("c0", field.qualified_name()); + } + + #[test] + fn from_qualified_field() { + let field = Field::new("c0", DataType::Boolean, true); + let field = DFField::from_qualified("t1", field); + assert_eq!("c0", field.name()); + assert_eq!("t1.c0", field.qualified_name()); + } + + #[test] + fn from_unqualified_schema() -> Result<()> { + let schema = DFSchema::try_from(test_schema_1())?; + assert_eq!("c0, c1", schema.to_string()); + Ok(()) + } + + #[test] + fn from_qualified_schema() -> Result<()> { + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + assert_eq!("t1.c0, t1.c1", schema.to_string()); + Ok(()) + } + + #[test] + fn from_qualified_schema_into_arrow_schema() -> Result<()> { + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let arrow_schema: Schema = schema.into(); + let expected = + "[Field { name: \"c0\", data_type: Boolean, is_nullable: true, metadata: {} }, \ + Field { name: \"c1\", data_type: Boolean, is_nullable: true, metadata: {} }]"; + assert_eq!(expected, format!("{:?}", arrow_schema.fields)); + Ok(()) + } + + #[test] + fn join_qualified() -> Result<()> { + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from_qualified_schema("t2", &test_schema_1())?; + let join = left.join(&right)?; + assert_eq!("t1.c0, t1.c1, t2.c0, t2.c1", join.to_string()); + // test valid access + assert!(join.field_with_qualified_name("t1", "c0").is_ok()); + assert!(join.field_with_qualified_name("t2", "c0").is_ok()); + // test invalid access + assert!(join.field_with_unqualified_name("c0").is_err()); + assert!(join.field_with_unqualified_name("t1.c0").is_err()); + assert!(join.field_with_unqualified_name("t2.c0").is_err()); + Ok(()) + } + + #[test] + fn join_qualified_duplicate() -> Result<()> { + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let join = left.join(&right); + assert!(join.is_err()); + assert_eq!( + "Error during planning: Schema contains duplicate \ + qualified field name \'t1.c0\'", + &format!("{}", join.err().unwrap()) + ); + Ok(()) + } + + #[test] + fn join_unqualified_duplicate() -> Result<()> { + let left = DFSchema::try_from(test_schema_1())?; + let right = DFSchema::try_from(test_schema_1())?; + let join = left.join(&right); + assert!(join.is_err()); + assert_eq!( + "Error during planning: Schema contains duplicate \ + unqualified field name \'c0\'", + &format!("{}", join.err().unwrap()) + ); + Ok(()) + } + + #[test] + fn join_mixed() -> Result<()> { + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from(test_schema_2())?; + let join = left.join(&right)?; + assert_eq!("t1.c0, t1.c1, c100, c101", join.to_string()); + // test valid access + assert!(join.field_with_qualified_name("t1", "c0").is_ok()); + assert!(join.field_with_unqualified_name("c0").is_ok()); + assert!(join.field_with_unqualified_name("c100").is_ok()); + assert!(join.field_with_name(None, "c100").is_ok()); + // test invalid access + assert!(join.field_with_unqualified_name("t1.c0").is_err()); + assert!(join.field_with_unqualified_name("t1.c100").is_err()); + 
assert!(join.field_with_qualified_name("", "c100").is_err()); + Ok(()) + } + + #[test] + fn join_mixed_duplicate() -> Result<()> { + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from(test_schema_1())?; + let join = left.join(&right); + assert!(join.is_err()); + assert_eq!( + "Error during planning: Schema contains qualified \ + field name \'t1.c0\' and unqualified field name \'c0\' which would be ambiguous", + &format!("{}", join.err().unwrap()) + ); + Ok(()) + } + + #[test] + fn helpful_error_messages() -> Result<()> { + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let expected_help = "Valid fields are \'t1.c0\', \'t1.c1\'."; + assert!(schema + .field_with_qualified_name("x", "y") + .unwrap_err() + .to_string() + .contains(expected_help)); + assert!(schema + .field_with_unqualified_name("y") + .unwrap_err() + .to_string() + .contains(expected_help)); + assert!(schema + .index_of("y") + .unwrap_err() + .to_string() + .contains(expected_help)); + Ok(()) + } + + #[test] + fn into() { + // Demonstrate how to convert back and forth between Schema, SchemaRef, DFSchema, and DFSchemaRef + let arrow_schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]); + let arrow_schema_ref = Arc::new(arrow_schema.clone()); + + let df_schema = + DFSchema::new(vec![DFField::new(None, "c0", DataType::Int64, true)]).unwrap(); + let df_schema_ref = Arc::new(df_schema.clone()); + + { + let arrow_schema = arrow_schema.clone(); + let arrow_schema_ref = arrow_schema_ref.clone(); + + assert_eq!(df_schema, arrow_schema.to_dfschema().unwrap()); + assert_eq!(df_schema, arrow_schema_ref.to_dfschema().unwrap()); + } + + { + let arrow_schema = arrow_schema.clone(); + let arrow_schema_ref = arrow_schema_ref.clone(); + + assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); + assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); + } + + // Now, consume the refs + assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); + assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); + } + + fn test_schema_1() -> Schema { + Schema::new(vec![ + Field::new("c0", DataType::Boolean, true), + Field::new("c1", DataType::Boolean, true), + ]) + } + + fn test_schema_2() -> Schema { + Schema::new(vec![ + Field::new("c100", DataType::Boolean, true), + Field::new("c101", DataType::Boolean, true), + ]) + } +} diff --git a/datafusion-common/src/error.rs b/datafusion-common/src/error.rs new file mode 100644 index 000000000000..33c47688e6ba --- /dev/null +++ b/datafusion-common/src/error.rs @@ -0,0 +1,184 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
DataFusion error types
+
+use std::error;
+use std::fmt::{Display, Formatter};
+use std::io;
+use std::result;
+
+use arrow::error::ArrowError;
+use parquet::error::ParquetError;
+use sqlparser::parser::ParserError;
+
+/// Result type for operations that could result in a [DataFusionError]
+pub type Result<T> = result::Result<T, DataFusionError>;
+
+/// Error type for generic operations that could result in DataFusionError::External
+pub type GenericError = Box<dyn error::Error + Send + Sync>;
+
+/// DataFusion error
+#[derive(Debug)]
+pub enum DataFusionError {
+    /// Error returned by arrow.
+    ArrowError(ArrowError),
+    /// Wraps an error from the Parquet crate
+    ParquetError(ParquetError),
+    /// Error associated with I/O operations and associated traits.
+    IoError(io::Error),
+    /// Error returned when SQL is syntactically incorrect.
+    SQL(ParserError),
+    /// Error returned on a branch that we know is possible
+    /// but for which we have no implementation yet.
+    /// Often, these errors are tracked in our issue tracker.
+    NotImplemented(String),
+    /// Error returned as a consequence of an error in DataFusion.
+    /// This error should not happen in normal usage of DataFusion.
+    // DataFusion has internal invariants that we are unable to ask the compiler to check for us.
+    // This error is raised when one of those invariants is not verified during execution.
+    Internal(String),
+    /// This error happens whenever a plan is not valid. Examples include
+    /// impossible casts, schema inference not possible and non-unique column names.
+    Plan(String),
+    /// Error returned during execution of the query.
+    /// Examples include files not found, errors in parsing certain types.
+    Execution(String),
+    /// This error is thrown when a consumer cannot acquire memory from the Memory Manager;
+    /// we can just cancel the execution of the partition.
+    ResourcesExhausted(String),
+    /// Errors originating from outside DataFusion's core codebase.
+    /// For example, a custom S3Error from the crate datafusion-objectstore-s3
+    External(GenericError),
+}
+
+impl From<io::Error> for DataFusionError {
+    fn from(e: io::Error) -> Self {
+        DataFusionError::IoError(e)
+    }
+}
+
+impl From<ArrowError> for DataFusionError {
+    fn from(e: ArrowError) -> Self {
+        DataFusionError::ArrowError(e)
+    }
+}
+
+impl From<DataFusionError> for ArrowError {
+    fn from(e: DataFusionError) -> Self {
+        match e {
+            DataFusionError::ArrowError(e) => e,
+            DataFusionError::External(e) => ArrowError::External(String::new(), e),
+            other => ArrowError::External(String::new(), Box::new(other)),
+        }
+    }
+}
+
+impl From<ParquetError> for DataFusionError {
+    fn from(e: ParquetError) -> Self {
+        DataFusionError::ParquetError(e)
+    }
+}
+
+impl From<ParserError> for DataFusionError {
+    fn from(e: ParserError) -> Self {
+        DataFusionError::SQL(e)
+    }
+}
+
+impl From<GenericError> for DataFusionError {
+    fn from(err: GenericError) -> Self {
+        DataFusionError::External(err)
+    }
+}
+
+impl Display for DataFusionError {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        match *self {
+            DataFusionError::ArrowError(ref desc) => write!(f, "Arrow error: {}", desc),
+            DataFusionError::ParquetError(ref desc) => {
+                write!(f, "Parquet error: {}", desc)
+            }
+            DataFusionError::IoError(ref desc) => write!(f, "IO error: {}", desc),
+            DataFusionError::SQL(ref desc) => {
+                write!(f, "SQL error: {:?}", desc)
+            }
+            DataFusionError::NotImplemented(ref desc) => {
+                write!(f, "This feature is not implemented: {}", desc)
+            }
+            DataFusionError::Internal(ref desc) => {
+                write!(f, "Internal error: {}. 
This was likely caused by a bug in DataFusion's \ + code and we would welcome that you file an bug report in our issue tracker", desc) + } + DataFusionError::Plan(ref desc) => { + write!(f, "Error during planning: {}", desc) + } + DataFusionError::Execution(ref desc) => { + write!(f, "Execution error: {}", desc) + } + DataFusionError::ResourcesExhausted(ref desc) => { + write!(f, "Resources exhausted: {}", desc) + } + DataFusionError::External(ref desc) => { + write!(f, "External error: {}", desc) + } + } + } +} + +impl error::Error for DataFusionError {} + +#[cfg(test)] +mod test { + use crate::error::DataFusionError; + use arrow::error::ArrowError; + + #[test] + fn arrow_error_to_datafusion() { + let res = return_arrow_error().unwrap_err(); + assert_eq!( + res.to_string(), + "External error: Error during planning: foo" + ); + } + + #[test] + fn datafusion_error_to_arrow() { + let res = return_datafusion_error().unwrap_err(); + assert_eq!(res.to_string(), "Arrow error: Schema error: bar"); + } + + /// Model what happens when implementing SendableRecrordBatchStream: + /// DataFusion code needs to return an ArrowError + #[allow(clippy::try_err)] + fn return_arrow_error() -> arrow::error::Result<()> { + // Expect the '?' to work + let _foo = Err(DataFusionError::Plan("foo".to_string()))?; + Ok(()) + } + + /// Model what happens when using arrow kernels in DataFusion + /// code: need to turn an ArrowError into a DataFusionError + #[allow(clippy::try_err)] + fn return_datafusion_error() -> crate::error::Result<()> { + // Expect the '?' to work + let _bar = Err(ArrowError::InvalidArgumentError( + "bad schema bar".to_string(), + ))?; + Ok(()) + } +} diff --git a/datafusion-common/src/field_util.rs b/datafusion-common/src/field_util.rs new file mode 100644 index 000000000000..2dfccb73092d --- /dev/null +++ b/datafusion-common/src/field_util.rs @@ -0,0 +1,490 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
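Because `DataFusionError` carries `From` impls for the common error types, `?` converts them transparently, and the `Display` impl then produces the prefixed messages shown above. A small sketch, crate name assumed as before; the `read_config` helper is hypothetical:

```rust
use datafusion_common::{DataFusionError, Result};

// Hypothetical helper: io::Error converts via From into DataFusionError::IoError
fn read_config(path: &str) -> Result<String> {
    let bytes = std::fs::read(path)?;
    String::from_utf8(bytes)
        .map_err(|e| DataFusionError::Execution(format!("invalid utf-8: {}", e)))
}

fn main() {
    // a missing file surfaces as "IO error: ..." through Display
    let err = read_config("/no/such/file").unwrap_err();
    println!("{}", err);
}
```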
Utility functions for complex field access + +use arrow::array::{ArrayRef, StructArray}; +use arrow::datatypes::{DataType, Field, Metadata, Schema}; +use arrow::error::ArrowError; +use std::borrow::Borrow; +use std::collections::BTreeMap; + +use crate::error::{DataFusionError, Result}; +use crate::scalar::ScalarValue; + +/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`] +/// # Error +/// Errors if +/// * the `data_type` is not a Struct or, +/// * there is no field key is not of the required index type +pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { + match (data_type, key) { + (DataType::List(lt), ScalarValue::Int64(Some(i))) => { + if *i < 0 { + Err(DataFusionError::Plan(format!( + "List based indexed access requires a positive int, was {0}", + i + ))) + } else { + Ok(Field::new(&i.to_string(), lt.data_type().clone(), false)) + } + } + (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { + if s.is_empty() { + Err(DataFusionError::Plan( + "Struct based indexed access requires a non empty string".to_string(), + )) + } else { + let field = fields.iter().find(|f| f.name() == s); + match field { + None => Err(DataFusionError::Plan(format!( + "Field {} not found in struct", + s + ))), + Some(f) => Ok(f.clone()), + } + } + } + (DataType::Struct(_), _) => Err(DataFusionError::Plan( + "Only utf8 strings are valid as an indexed field in a struct".to_string(), + )), + (DataType::List(_), _) => Err(DataFusionError::Plan( + "Only ints are valid as an indexed field in a list".to_string(), + )), + _ => Err(DataFusionError::Plan( + "The expression to get an indexed field is only valid for `List` types" + .to_string(), + )), + } +} + +/// Imitate arrow-rs StructArray behavior by extending arrow2 StructArray +pub trait StructArrayExt { + /// Return field names in this struct array + fn column_names(&self) -> Vec<&str>; + /// Return child array whose field name equals to column_name + fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef>; + /// Return the number of fields in this struct array + fn num_columns(&self) -> usize; + /// Return the column at the position + fn column(&self, pos: usize) -> ArrayRef; +} + +impl StructArrayExt for StructArray { + fn column_names(&self) -> Vec<&str> { + self.fields().iter().map(|f| f.name.as_str()).collect() + } + + fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef> { + self.fields() + .iter() + .position(|c| c.name() == column_name) + .map(|pos| self.values()[pos].borrow()) + } + + fn num_columns(&self) -> usize { + self.fields().len() + } + + fn column(&self, pos: usize) -> ArrayRef { + self.values()[pos].clone() + } +} + +/// Converts a list of field / array pairs to a struct array +pub fn struct_array_from(pairs: Vec<(Field, ArrayRef)>) -> StructArray { + let fields: Vec = pairs.iter().map(|v| v.0.clone()).collect(); + let values = pairs.iter().map(|v| v.1.clone()).collect(); + StructArray::from_data(DataType::Struct(fields), values, None) +} + +/// Imitate arrow-rs Schema behavior by extending arrow2 Schema +pub trait SchemaExt { + /// Creates a new [`Schema`] from a sequence of [`Field`] values. 
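Before the trait definitions continue, a sketch of `get_indexed_field` above, which maps an access key to the child `Field` for both the struct and list paths (same crate assumption as before):

```rust
use arrow::datatypes::{DataType, Field};
use datafusion_common::field_util::{get_indexed_field, FieldExt};
use datafusion_common::{Result, ScalarValue};

fn main() -> Result<()> {
    // struct access is keyed by a non-empty utf8 name
    let strukt = DataType::Struct(vec![Field::new("a", DataType::Int64, false)]);
    let f = get_indexed_field(&strukt, &ScalarValue::Utf8(Some("a".to_string())))?;
    assert_eq!("a", f.name());

    // list access is keyed by a non-negative Int64 position
    let list = DataType::List(Box::new(Field::new("item", DataType::Utf8, true)));
    let item = get_indexed_field(&list, &ScalarValue::Int64(Some(0)))?;
    assert_eq!(&DataType::Utf8, item.data_type());
    Ok(())
}
```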
+ /// + /// # Example + /// + /// ``` + /// use arrow::datatypes::{Field, DataType, Schema}; + /// use datafusion::field_util::SchemaExt; + /// let field_a = Field::new("a", DataType::Int64, false); + /// let field_b = Field::new("b", DataType::Boolean, false); + /// + /// let schema = Schema::new(vec![field_a, field_b]); + /// ``` + fn new(fields: Vec) -> Self; + + /// Creates a new [`Schema`] from a sequence of [`Field`] values and [`arrow::datatypes::Metadata`] + /// + /// # Example + /// + /// ``` + /// use std::collections::BTreeMap; + /// use arrow::datatypes::{Field, DataType, Schema}; + /// use datafusion::field_util::SchemaExt; + /// + /// let field_a = Field::new("a", DataType::Int64, false); + /// let field_b = Field::new("b", DataType::Boolean, false); + /// + /// let schema_metadata: BTreeMap = + /// vec![("baz".to_string(), "barf".to_string())] + /// .into_iter() + /// .collect(); + /// let schema = Schema::new_with_metadata(vec![field_a, field_b], schema_metadata); + /// ``` + fn new_with_metadata(fields: Vec, metadata: Metadata) -> Self; + + /// Creates an empty [`Schema`]. + fn empty() -> Self; + + /// Look up a column by name and return a immutable reference to the column along with + /// its index. + fn column_with_name(&self, name: &str) -> Option<(usize, &Field)>; + + /// Returns the first [`Field`] named `name`. + fn field_with_name(&self, name: &str) -> Result<&Field>; + + /// Find the index of the column with the given name. + fn index_of(&self, name: &str) -> Result; + + /// Returns the [`Field`] at position `i`. + /// # Panics + /// Panics iff `i` is larger than the number of fields in this [`Schema`]. + fn field(&self, index: usize) -> &Field; + + /// Returns all [`Field`]s in this schema. + fn fields(&self) -> &[Field]; + + /// Returns an immutable reference to the Map of custom metadata key-value pairs. + fn metadata(&self) -> &BTreeMap; + + /// Merge schema into self if it is compatible. Struct fields will be merged recursively. 
+ /// + /// Example: + /// + /// ``` + /// use arrow::datatypes::*; + /// use datafusion::field_util::SchemaExt; + /// + /// let merged = Schema::try_merge(vec![ + /// Schema::new(vec![ + /// Field::new("c1", DataType::Int64, false), + /// Field::new("c2", DataType::Utf8, false), + /// ]), + /// Schema::new(vec![ + /// Field::new("c1", DataType::Int64, true), + /// Field::new("c2", DataType::Utf8, false), + /// Field::new("c3", DataType::Utf8, false), + /// ]), + /// ]).unwrap(); + /// + /// assert_eq!( + /// merged, + /// Schema::new(vec![ + /// Field::new("c1", DataType::Int64, true), + /// Field::new("c2", DataType::Utf8, false), + /// Field::new("c3", DataType::Utf8, false), + /// ]), + /// ); + /// ``` + fn try_merge(schemas: impl IntoIterator) -> Result + where + Self: Sized; + + /// Return the field names + fn field_names(&self) -> Vec; + + /// Returns a new schema with only the specified columns in the new schema + /// This carries metadata from the parent schema over as well + fn project(&self, indices: &[usize]) -> Result; +} + +impl SchemaExt for Schema { + fn new(fields: Vec) -> Self { + Self::from(fields) + } + + fn new_with_metadata(fields: Vec, metadata: Metadata) -> Self { + Self::new(fields).with_metadata(metadata) + } + + fn empty() -> Self { + Self::from(vec![]) + } + + fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> { + self.fields.iter().enumerate().find(|(_, f)| f.name == name) + } + + fn field_with_name(&self, name: &str) -> Result<&Field> { + Ok(&self.fields[self.index_of(name)?]) + } + + fn index_of(&self, name: &str) -> Result { + self.column_with_name(name).map(|(i, _f)| i).ok_or_else(|| { + DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!( + "Unable to get field named \"{}\". Valid fields: {:?}", + name, + self.field_names() + ))) + }) + } + + fn field(&self, index: usize) -> &Field { + &self.fields[index] + } + + #[inline] + fn fields(&self) -> &[Field] { + &self.fields + } + + #[inline] + fn metadata(&self) -> &BTreeMap { + &self.metadata + } + + fn try_merge(schemas: impl IntoIterator) -> Result { + schemas + .into_iter() + .try_fold(Self::empty(), |mut merged, schema| { + let Schema { metadata, fields } = schema; + for (key, value) in metadata.into_iter() { + // merge metadata + if let Some(old_val) = merged.metadata.get(&key) { + if old_val != &value { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError( + "Fail to merge schema due to conflicting metadata." + .to_string(), + ), + )); + } + } + merged.metadata.insert(key, value); + } + // merge fields + for field in fields.into_iter() { + let mut new_field = true; + for merged_field in &mut merged.fields { + if field.name() != merged_field.name() { + continue; + } + new_field = false; + merged_field.try_merge(&field)? 
+ } + // found a new field, add to field list + if new_field { + merged.fields.push(field); + } + } + Ok(merged) + }) + } + + fn field_names(&self) -> Vec { + self.fields.iter().map(|f| f.name.to_string()).collect() + } + + fn project(&self, indices: &[usize]) -> Result { + let new_fields = indices + .iter() + .map(|i| { + self.fields.get(*i).cloned().ok_or_else(|| { + DataFusionError::ArrowError(ArrowError::InvalidArgumentError( + format!( + "project index {} out of bounds, max field {}", + i, + self.fields().len() + ), + )) + }) + }) + .collect::>>()?; + Ok(Self::new_with_metadata(new_fields, self.metadata.clone())) + } +} + +/// Imitate arrow-rs Field behavior by extending arrow2 Field +pub trait FieldExt { + /// The field name + fn name(&self) -> &str; + + /// Whether the field is nullable + fn is_nullable(&self) -> bool; + + /// Returns the field metadata + fn metadata(&self) -> &BTreeMap; + + /// Merge field into self if it is compatible. Struct will be merged recursively. + /// NOTE: `self` may be updated to unexpected state in case of merge failure. + /// + /// Example: + /// + /// ``` + /// use arrow2::datatypes::*; + /// + /// let mut field = Field::new("c1", DataType::Int64, false); + /// assert!(field.try_merge(&Field::new("c1", DataType::Int64, true)).is_ok()); + /// assert!(field.is_nullable()); + /// ``` + fn try_merge(&mut self, from: &Field) -> Result<()>; + + /// Sets the `Field`'s optional custom metadata. + /// The metadata is set as `None` for empty map. + fn set_metadata(&mut self, metadata: Option>); +} + +impl FieldExt for Field { + #[inline] + fn name(&self) -> &str { + &self.name + } + + #[inline] + fn is_nullable(&self) -> bool { + self.is_nullable + } + + #[inline] + fn metadata(&self) -> &BTreeMap { + &self.metadata + } + + fn try_merge(&mut self, from: &Field) -> Result<()> { + // merge metadata + for (key, from_value) in from.metadata() { + if let Some(self_value) = self.metadata.get(key) { + if self_value != from_value { + return Err(DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!( + "Fail to merge field due to conflicting metadata data value for key {}", + key + )))); + } + } else { + self.metadata.insert(key.clone(), from_value.clone()); + } + } + + match &mut self.data_type { + DataType::Struct(nested_fields) => match &from.data_type { + DataType::Struct(from_nested_fields) => { + for from_field in from_nested_fields { + let mut is_new_field = true; + for self_field in nested_fields.iter_mut() { + if self_field.name != from_field.name { + continue; + } + is_new_field = false; + self_field.try_merge(from_field)?; + } + if is_new_field { + nested_fields.push(from_field.clone()); + } + } + } + _ => { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError( + "Fail to merge schema Field due to conflicting datatype" + .to_string(), + ), + )); + } + }, + DataType::Union(nested_fields, _, _) => match &from.data_type { + DataType::Union(from_nested_fields, _, _) => { + for from_field in from_nested_fields { + let mut is_new_field = true; + for self_field in nested_fields.iter_mut() { + if from_field == self_field { + is_new_field = false; + break; + } + } + if is_new_field { + nested_fields.push(from_field.clone()); + } + } + } + _ => { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError( + "Fail to merge schema Field due to conflicting datatype" + .to_string(), + ), + )); + } + }, + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 
+ | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Binary + | DataType::LargeBinary + | DataType::Interval(_) + | DataType::LargeList(_) + | DataType::List(_) + | DataType::Dictionary(_, _, _) + | DataType::FixedSizeList(_, _) + | DataType::FixedSizeBinary(_) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Extension(_, _, _) + | DataType::Map(_, _) + | DataType::Decimal(_, _) => { + if self.data_type != from.data_type { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError( + "Fail to merge schema Field due to conflicting datatype" + .to_string(), + ), + )); + } + } + } + if from.is_nullable { + self.is_nullable = from.is_nullable; + } + + Ok(()) + } + + #[inline] + fn set_metadata(&mut self, metadata: Option>) { + if let Some(v) = metadata { + if !v.is_empty() { + self.metadata = v; + } + } + } +} diff --git a/datafusion-common/src/lib.rs b/datafusion-common/src/lib.rs new file mode 100644 index 000000000000..cb06b4663432 --- /dev/null +++ b/datafusion-common/src/lib.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod column; +mod dfschema; +mod error; +pub mod field_util; +#[cfg(feature = "pyarrow")] +mod pyarrow; +pub mod record_batch; +mod scalar; + +pub use column::Column; +pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; +pub use error::{DataFusionError, Result}; +pub use scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128}; diff --git a/datafusion-common/src/pyarrow.rs b/datafusion-common/src/pyarrow.rs new file mode 100644 index 000000000000..405e568be246 --- /dev/null +++ b/datafusion-common/src/pyarrow.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
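Rounding out the `SchemaExt` surface defined in this file, a sketch of `index_of` and `project` (same crate assumption as before):

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::field_util::SchemaExt;
use datafusion_common::Result;

fn main() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Utf8, false),
        Field::new("c", DataType::Boolean, true),
    ]);
    assert_eq!(1, schema.index_of("b")?);
    // keep columns 0 and 2; schema metadata is carried over
    let projected = schema.project(&[0, 2])?;
    assert_eq!(vec!["a".to_string(), "c".to_string()], projected.field_names());
    Ok(())
}
```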
PyArrow + +use crate::{DataFusionError, ScalarValue}; +use arrow::array::ArrayData; +use arrow::pyarrow::PyArrowConvert; +use pyo3::exceptions::PyException; +use pyo3::prelude::PyErr; +use pyo3::types::PyList; +use pyo3::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python}; + +impl From for PyErr { + fn from(err: DataFusionError) -> PyErr { + PyException::new_err(err.to_string()) + } +} + +impl PyArrowConvert for ScalarValue { + fn from_pyarrow(value: &PyAny) -> PyResult { + let py = value.py(); + let typ = value.getattr("type")?; + let val = value.call_method0("as_py")?; + + // construct pyarrow array from the python value and pyarrow type + let factory = py.import("pyarrow")?.getattr("array")?; + let args = PyList::new(py, &[val]); + let array = factory.call1((args, typ))?; + + // convert the pyarrow array to rust array using C data interface + let array = array.extract::()?; + let scalar = ScalarValue::try_from_array(&array.into(), 0)?; + + Ok(scalar) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let array = self.to_array(); + // convert to pyarrow array using C data interface + let pyarray = array.data_ref().clone().into_py(py); + let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?; + + Ok(pyscalar) + } +} + +impl<'source> FromPyObject<'source> for ScalarValue { + fn extract(value: &'source PyAny) -> PyResult { + Self::from_pyarrow(value) + } +} + +impl<'a> IntoPy for ScalarValue { + fn into_py(self, py: Python) -> PyObject { + self.to_pyarrow(py).unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pyo3::prepare_freethreaded_python; + use pyo3::py_run; + use pyo3::types::PyDict; + + fn init_python() { + prepare_freethreaded_python(); + Python::with_gil(|py| { + if let Err(err) = py.run("import pyarrow", None, None) { + let locals = PyDict::new(py); + py.run( + "import sys; executable = sys.executable; python_path = sys.path", + None, + Some(locals), + ) + .expect("Couldn't get python info"); + let executable: String = + locals.get_item("executable").unwrap().extract().unwrap(); + let python_path: Vec<&str> = + locals.get_item("python_path").unwrap().extract().unwrap(); + + Err(err).expect( + format!( + "pyarrow not found\nExecutable: {}\nPython path: {:?}\n\ + HINT: try `pip install pyarrow`\n\ + NOTE: On Mac OS, you must compile against a Framework Python \ + (default in python.org installers and brew, but not pyenv)\n\ + NOTE: On Mac OS, PYO3 might point to incorrect Python library \ + path when using virtual environments. 
Try \ + `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n", + executable, python_path + ) + .as_ref(), + ) + } + }) + } + + #[test] + fn test_roundtrip() { + init_python(); + + let example_scalars = vec![ + ScalarValue::Boolean(Some(true)), + ScalarValue::Int32(Some(23)), + ScalarValue::Float64(Some(12.34)), + ScalarValue::Utf8(Some("Hello!".to_string())), + ScalarValue::Date32(Some(1234)), + ]; + + Python::with_gil(|py| { + for scalar in example_scalars.iter() { + let result = + ScalarValue::from_pyarrow(scalar.to_pyarrow(py).unwrap().as_ref(py)) + .unwrap(); + assert_eq!(scalar, &result); + } + }); + } + + #[test] + fn test_py_scalar() { + init_python(); + + Python::with_gil(|py| { + let scalar_float = ScalarValue::Float64(Some(12.34)); + let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap(); + py_run!(py, py_float, "assert py_float == 12.34"); + + let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string())); + let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap(); + py_run!(py, py_string, "assert py_string == 'Hello!'"); + }); + } +} + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Array; +use arrow::error::ArrowError; +use pyo3::exceptions::{PyException, PyNotImplementedError}; +use pyo3::ffi::Py_uintptr_t; +use pyo3::prelude::*; +use pyo3::types::PyList; +use std::sync::Arc; + +use crate::error::DataFusionError; +use crate::scalar::ScalarValue; + +impl From for PyErr { + fn from(err: DataFusionError) -> PyErr { + PyException::new_err(err.to_string()) + } +} + +impl From for PyErr { + fn from(err: PyO3ArrowError) -> PyErr { + PyException::new_err(format!("{:?}", err)) + } +} + +#[derive(Debug)] +enum PyO3ArrowError { + ArrowError(ArrowError), +} + +fn to_rust_array(ob: PyObject, py: Python) -> PyResult> { + // prepare a pointer to receive the Array struct + let array = Box::new(arrow::ffi::Ffi_ArrowArray::empty()); + let schema = Box::new(arrow::ffi::Ffi_ArrowSchema::empty()); + + let array_ptr = &*array as *const arrow::ffi::Ffi_ArrowArray; + let schema_ptr = &*schema as *const arrow::ffi::Ffi_ArrowSchema; + + // make the conversion through PyArrow's private API + // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds + ob.call_method1( + py, + "_export_to_c", + (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), + )?; + + let field = unsafe { + arrow::ffi::import_field_from_c(schema.as_ref()) + .map_err(PyO3ArrowError::ArrowError)? + }; + let array = unsafe { + arrow::ffi::import_array_from_c(array, &field) + .map_err(PyO3ArrowError::ArrowError)? 
+ }; + + Ok(array.into()) +} +impl<'source> FromPyObject<'source> for ScalarValue { + fn extract(value: &'source PyAny) -> PyResult { + let py = value.py(); + let typ = value.getattr("type")?; + let val = value.call_method0("as_py")?; + + // construct pyarrow array from the python value and pyarrow type + let factory = py.import("pyarrow")?.getattr("array")?; + let args = PyList::new(py, &[val]); + let array = factory.call1((args, typ))?; + + // convert the pyarrow array to rust array using C data interface] + let array = to_rust_array(array.to_object(py), py)?; + let scalar = ScalarValue::try_from_array(&array, 0)?; + + Ok(scalar) + } +} + +impl<'a> IntoPy for ScalarValue { + fn into_py(self, _py: Python) -> PyObject { + Err(PyNotImplementedError::new_err("Not implemented")).unwrap() + } +} diff --git a/datafusion-common/src/record_batch.rs b/datafusion-common/src/record_batch.rs new file mode 100644 index 000000000000..4d456870ec7d --- /dev/null +++ b/datafusion-common/src/record_batch.rs @@ -0,0 +1,449 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains [`RecordBatch`]. +use std::sync::Arc; + +use crate::field_util::SchemaExt; +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::compute::filter::{build_filter, filter}; +use arrow::datatypes::*; +use arrow::error::{ArrowError, Result}; + +/// A two-dimensional dataset with a number of +/// columns ([`Array`]) and rows and defined [`Schema`](crate::datatypes::Schema). +/// # Implementation +/// Cloning is `O(C)` where `C` is the number of columns. +#[derive(Clone, Debug, PartialEq)] +pub struct RecordBatch { + schema: Arc, + columns: Vec>, +} + +impl RecordBatch { + /// Creates a [`RecordBatch`] from a schema and columns. 
+ /// # Errors + /// This function errors iff + /// * `columns` is empty + /// * the schema and column data types do not match + /// * `columns` have a different length + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow2::array::PrimitiveArray; + /// # use arrow2::datatypes::{Schema, Field, DataType}; + /// # use arrow2::record_batch::RecordBatch; + /// # fn main() -> arrow2::error::Result<()> { + /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]); + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int32, false) + /// ])); + /// + /// let batch = RecordBatch::try_new( + /// schema, + /// vec![Arc::new(id_array)] + /// )?; + /// # Ok(()) + /// # } + /// ``` + pub fn try_new(schema: Arc, columns: Vec>) -> Result { + let options = RecordBatchOptions::default(); + Self::validate_new_batch(&schema, columns.as_slice(), &options)?; + Ok(RecordBatch { schema, columns }) + } + + /// Creates a [`RecordBatch`] from a schema and columns, with additional options, + /// such as whether to strictly validate field names. + /// + /// See [`Self::try_new()`] for the expected conditions. + pub fn try_new_with_options( + schema: Arc, + columns: Vec>, + options: &RecordBatchOptions, + ) -> Result { + Self::validate_new_batch(&schema, &columns, options)?; + Ok(RecordBatch { schema, columns }) + } + + /// Creates a new empty [`RecordBatch`]. + pub fn new_empty(schema: Arc) -> Self { + let columns = schema + .fields() + .iter() + .map(|field| new_empty_array(field.data_type().clone()).into()) + .collect(); + RecordBatch { schema, columns } + } + + /// Creates a new [`RecordBatch`] from a [`arrow::chunk::Chunk`] + pub fn new_with_chunk(schema: &Arc, chunk: Chunk) -> Self { + Self { + schema: schema.clone(), + columns: chunk.into_arrays(), + } + } + + /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error + /// if any validation check fails. 
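`try_new_with_options` threads a `RecordBatchOptions` through the same validation; per the field's doc, `match_field_names` controls whether struct and list field names must match. A construction sketch, assuming this patch's crate as `datafusion_common`:

```rust
use std::sync::Arc;
use arrow::array::{Array, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::field_util::SchemaExt;
use datafusion_common::record_batch::{RecordBatch, RecordBatchOptions};

fn main() -> arrow::error::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let id: Arc<dyn Array> = Arc::new(Int32Array::from_slice(&[1, 2, 3]));
    // relax the nested field-name comparison during validation
    let options = RecordBatchOptions { match_field_names: false };
    let batch = RecordBatch::try_new_with_options(schema, vec![id], &options)?;
    assert_eq!(3, batch.num_rows());
    Ok(())
}
```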
+ fn validate_new_batch( + schema: &Schema, + columns: &[Arc], + options: &RecordBatchOptions, + ) -> Result<()> { + // check that there are some columns + if columns.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "at least one column must be defined to create a record batch" + .to_string(), + )); + } + // check that number of fields in schema match column length + if schema.fields().len() != columns.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "number of columns({}) must match number of fields({}) in schema", + columns.len(), + schema.fields().len(), + ))); + } + // check that all columns have the same row count, and match the schema + let len = columns[0].len(); + + // This is a bit repetitive, but it is better to check the condition outside the loop + if options.match_field_names { + for (i, column) in columns.iter().enumerate() { + if column.len() != len { + return Err(ArrowError::InvalidArgumentError( + "all columns in a record batch must have the same length" + .to_string(), + )); + } + if column.data_type() != schema.field(i).data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "column types must match schema types, expected {:?} but found {:?} at column index {}", + schema.field(i).data_type(), + column.data_type(), + i))); + } + } + } else { + for (i, column) in columns.iter().enumerate() { + if column.len() != len { + return Err(ArrowError::InvalidArgumentError( + "all columns in a record batch must have the same length" + .to_string(), + )); + } + if !column.data_type().eq(schema.field(i).data_type()) { + return Err(ArrowError::InvalidArgumentError(format!( + "column types must match schema types, expected {:?} but found {:?} at column index {}", + schema.field(i).data_type(), + column.data_type(), + i))); + } + } + } + + Ok(()) + } + + /// Returns the [`Schema`](crate::datatypes::Schema) of the record batch. + pub fn schema(&self) -> &Arc { + &self.schema + } + + /// Returns the number of columns in the record batch. + /// + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow2::array::PrimitiveArray; + /// # use arrow2::datatypes::{Schema, Field, DataType}; + /// # use arrow2::record_batch::RecordBatch; + /// # fn main() -> arrow2::error::Result<()> { + /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]); + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int32, false) + /// ])); + /// + /// let batch = RecordBatch::try_new(schema, vec![Arc::new(id_array)])?; + /// + /// assert_eq!(batch.num_columns(), 1); + /// # Ok(()) + /// # } + /// ``` + pub fn num_columns(&self) -> usize { + self.columns.len() + } + + /// Returns the number of rows in each column. + /// + /// # Panics + /// + /// Panics if the `RecordBatch` contains no columns. + /// + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow2::array::PrimitiveArray; + /// # use arrow2::datatypes::{Schema, Field, DataType}; + /// # use arrow2::record_batch::RecordBatch; + /// # fn main() -> arrow2::error::Result<()> { + /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]); + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int32, false) + /// ])); + /// + /// let batch = RecordBatch::try_new(schema, vec![Arc::new(id_array)])?; + /// + /// assert_eq!(batch.num_rows(), 5); + /// # Ok(()) + /// # } + /// ``` + pub fn num_rows(&self) -> usize { + self.columns[0].len() + } + + /// Get a reference to a column's array by index. 
+ /// + /// # Panics + /// + /// Panics if `index` is outside of `0..num_columns`. + pub fn column(&self, index: usize) -> &Arc { + &self.columns[index] + } + + /// Get a reference to all columns in the record batch. + pub fn columns(&self) -> &[Arc] { + &self.columns[..] + } + + /// Create a `RecordBatch` from an iterable list of pairs of the + /// form `(field_name, array)`, with the same requirements on + /// fields and arrays as [`RecordBatch::try_new`]. This method is + /// often used to create a single `RecordBatch` from arrays, + /// e.g. for testing. + /// + /// The resulting schema is marked as nullable for each column if + /// the array for that column is has any nulls. To explicitly + /// specify nullibility, use [`RecordBatch::try_from_iter_with_nullable`] + /// + /// Example: + /// ``` + /// use std::sync::Arc; + /// use arrow::array::*; + /// use arrow::datatypes::DataType; + /// use datafusion::record_batch::RecordBatch; + /// + /// let a: Arc = Arc::new(Int32Array::from_slice(&[1, 2])); + /// let b: Arc = Arc::new(Utf8Array::::from_slice(&["a", "b"])); + /// + /// let record_batch = RecordBatch::try_from_iter(vec![ + /// ("a", a), + /// ("b", b), + /// ]); + /// ``` + pub fn try_from_iter(value: I) -> Result + where + I: IntoIterator)>, + F: AsRef, + { + // TODO: implement `TryFrom` trait, once + // https://github.com/rust-lang/rust/issues/50133 is no longer an + // issue + let iter = value.into_iter().map(|(field_name, array)| { + let nullable = array.null_count() > 0; + (field_name, array, nullable) + }); + + Self::try_from_iter_with_nullable(iter) + } + + /// Create a `RecordBatch` from an iterable list of tuples of the + /// form `(field_name, array, nullable)`, with the same requirements on + /// fields and arrays as [`RecordBatch::try_new`]. This method is often + /// used to create a single `RecordBatch` from arrays, e.g. for + /// testing. 
+    ///
+    /// Example:
+    /// ```
+    /// use std::sync::Arc;
+    /// use arrow::array::*;
+    /// use arrow::datatypes::DataType;
+    /// use datafusion::record_batch::RecordBatch;
+    ///
+    /// let a: Arc<dyn Array> = Arc::new(Int32Array::from_slice(&[1, 2]));
+    /// let b: Arc<dyn Array> = Arc::new(Utf8Array::<i32>::from_slice(&["a", "b"]));
+    ///
+    /// // Note neither `a` nor `b` has any actual nulls, but we mark
+    /// // b as nullable
+    /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![
+    ///   ("a", a, false),
+    ///   ("b", b, true),
+    /// ]);
+    /// ```
+    pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self>
+    where
+        I: IntoIterator<Item = (F, Arc<dyn Array>, bool)>,
+        F: AsRef<str>,
+    {
+        // TODO: implement `TryFrom` trait, once
+        // https://github.com/rust-lang/rust/issues/50133 is no longer an
+        // issue
+        let (fields, columns) = value
+            .into_iter()
+            .map(|(field_name, array, nullable)| {
+                let field_name = field_name.as_ref();
+                let field = Field::new(field_name, array.data_type().clone(), nullable);
+                (field, array)
+            })
+            .unzip();
+
+        let schema = Arc::new(Schema::new(fields));
+        RecordBatch::try_new(schema, columns)
+    }
+
+    /// Deconstructs itself into its internal components
+    pub fn into_inner(self) -> (Vec<Arc<dyn Array>>, Arc<Schema>) {
+        let Self { columns, schema } = self;
+        (columns, schema)
+    }
+
+    /// Projects the schema onto the specified columns
+    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch> {
+        let projected_schema = self.schema.project(indices)?;
+        let batch_fields = indices
+            .iter()
+            .map(|f| {
+                self.columns.get(*f).cloned().ok_or_else(|| {
+                    ArrowError::InvalidArgumentError(format!(
+                        "project index {} out of bounds, max field {}",
+                        f,
+                        self.columns.len()
+                    ))
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields)
+    }
+
+    /// Return a new RecordBatch where each column is sliced
+    /// according to `offset` and `length`
+    ///
+    /// # Panics
+    ///
+    /// Panics if `offset + length` is greater than the number of rows.
+    pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
+        if self.schema.fields().is_empty() {
+            assert!((offset + length) == 0);
+            return RecordBatch::new_empty(self.schema.clone());
+        }
+        assert!((offset + length) <= self.num_rows());
+
+        let columns = self
+            .columns()
+            .iter()
+            .map(|column| Arc::from(column.slice(offset, length)))
+            .collect();
+
+        Self {
+            schema: self.schema.clone(),
+            columns,
+        }
+    }
+}
+
+/// Options that control the behaviour used when creating a [`RecordBatch`].
+#[derive(Debug)]
+pub struct RecordBatchOptions {
+    /// Match field names of structs and lists. If set to `true`, the names must match.
+    pub match_field_names: bool,
+}
+
+impl Default for RecordBatchOptions {
+    fn default() -> Self {
+        Self {
+            match_field_names: true,
+        }
+    }
+}
+
+impl From<StructArray> for RecordBatch {
+    /// # Panics iff the null count of the array is not zero.
+ fn from(array: StructArray) -> Self { + assert!(array.null_count() == 0); + let (fields, values, _) = array.into_data(); + RecordBatch { + schema: Arc::new(Schema::new(fields)), + columns: values, + } + } +} + +impl From for StructArray { + fn from(batch: RecordBatch) -> Self { + let (fields, values) = batch + .schema + .fields + .iter() + .zip(batch.columns.iter()) + .map(|t| (t.0.clone(), t.1.clone())) + .unzip(); + StructArray::from_data(DataType::Struct(fields), values, None) + } +} + +impl From for Chunk { + fn from(rb: RecordBatch) -> Self { + Chunk::new(rb.columns) + } +} + +impl From<&RecordBatch> for Chunk { + fn from(rb: &RecordBatch) -> Self { + Chunk::new(rb.columns.clone()) + } +} + +/// Returns a new [RecordBatch] with arrays containing only values matching the filter. +/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered. +/// Therefore, it is considered undefined behavior to pass `filter` with null values. +pub fn filter_record_batch( + record_batch: &RecordBatch, + filter_values: &BooleanArray, +) -> Result { + let num_colums = record_batch.columns().len(); + + let filtered_arrays = match num_colums { + 1 => { + vec![filter(record_batch.columns()[0].as_ref(), filter_values)?.into()] + } + _ => { + let filter = build_filter(filter_values)?; + record_batch + .columns() + .iter() + .map(|a| filter(a.as_ref()).into()) + .collect() + } + }; + RecordBatch::try_new(record_batch.schema().clone(), filtered_arrays) +} diff --git a/datafusion-common/src/scalar.rs b/datafusion-common/src/scalar.rs new file mode 100644 index 000000000000..c87ea73f7ab7 --- /dev/null +++ b/datafusion-common/src/scalar.rs @@ -0,0 +1,2992 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
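Before moving on to the scalar module, a sketch of how the `project` and `slice` methods defined on `RecordBatch` above compose (crate assumption as before):

```rust
use std::sync::Arc;
use arrow::array::{Array, Int32Array, Utf8Array};
use datafusion_common::record_batch::RecordBatch;

fn main() -> arrow::error::Result<()> {
    let a: Arc<dyn Array> = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4]));
    let b: Arc<dyn Array> = Arc::new(Utf8Array::<i32>::from_slice(&["w", "x", "y", "z"]));
    let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;
    // keep only column 1 ("b"), then take rows 1..3 of it
    let view = batch.project(&[1])?.slice(1, 2);
    assert_eq!((2, 1), (view.num_rows(), view.num_columns()));
    Ok(())
}
```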
This module provides ScalarValue, an enum that can be used for storage of single elements + +use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; + +use crate::error::{DataFusionError, Result}; +use crate::field_util::{FieldExt, StructArrayExt}; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::compute::concatenate; +use arrow::datatypes::DataType::Decimal; +use arrow::{ + array::*, + datatypes::{DataType, Field, IntegerType, IntervalUnit, TimeUnit}, + scalar::{PrimitiveScalar, Scalar}, + types::{days_ms, NativeType}, +}; +use ordered_float::OrderedFloat; +use std::cmp::Ordering; +use std::convert::{Infallible, TryInto}; +use std::str::FromStr; + +type StringArray = Utf8Array; +type LargeStringArray = Utf8Array; +type SmallBinaryArray = BinaryArray; +type LargeBinaryArray = BinaryArray; +type MutableStringArray = MutableUtf8Array; +type MutableLargeStringArray = MutableUtf8Array; + +// TODO may need to be moved to arrow-rs +/// The max precision and scale for decimal128 +pub const MAX_PRECISION_FOR_DECIMAL128: usize = 38; +pub const MAX_SCALE_FOR_DECIMAL128: usize = 38; + +/// Represents a dynamically typed, nullable single value. +/// This is the single-valued counter-part of arrow’s `Array`. +#[derive(Clone)] +pub enum ScalarValue { + /// true or false value + Boolean(Option), + /// 32bit float + Float32(Option), + /// 64bit float + Float64(Option), + /// 128bit decimal, using the i128 to represent the decimal + Decimal128(Option, usize, usize), + /// signed 8bit int + Int8(Option), + /// signed 16bit int + Int16(Option), + /// signed 32bit int + Int32(Option), + /// signed 64bit int + Int64(Option), + /// unsigned 8bit int + UInt8(Option), + /// unsigned 16bit int + UInt16(Option), + /// unsigned 32bit int + UInt32(Option), + /// unsigned 64bit int + UInt64(Option), + /// utf-8 encoded string. + Utf8(Option), + /// utf-8 encoded string representing a LargeString's arrow type. 
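A quick sketch of what the manual `PartialEq` below buys (crate assumption as before): float scalars compare via `OrderedFloat`, so equality is total, and together with the `Hash` impl at the end of this file, `ScalarValue` can key hash maps:

```rust
use datafusion_common::ScalarValue;

fn main() {
    let x = ScalarValue::Float64(Some(f64::NAN));
    // OrderedFloat semantics: NaN == NaN, unlike raw f64
    assert_eq!(x, ScalarValue::Float64(Some(f64::NAN)));
    assert_ne!(x, ScalarValue::Float64(None));
}
```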
+ LargeUtf8(Option), + /// binary + Binary(Option>), + /// large binary + LargeBinary(Option>), + /// list of nested ScalarValue (boxed to reduce size_of(ScalarValue)) + #[allow(clippy::box_collection)] + List(Option>>, Box), + /// Date stored as a signed 32bit int + Date32(Option), + /// Date stored as a signed 64bit int + Date64(Option), + /// Timestamp Second + TimestampSecond(Option, Option), + /// Timestamp Milliseconds + TimestampMillisecond(Option, Option), + /// Timestamp Microseconds + TimestampMicrosecond(Option, Option), + /// Timestamp Nanoseconds + TimestampNanosecond(Option, Option), + /// Interval with YearMonth unit + IntervalYearMonth(Option), + /// Interval with DayTime unit + IntervalDayTime(Option), + /// Interval with MonthDayNano unit + IntervalMonthDayNano(Option), + /// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue)) + #[allow(clippy::box_collection)] + Struct(Option>>, Box>), +} + +// manual implementation of `PartialEq` that uses OrderedFloat to +// get defined behavior for floating point +impl PartialEq for ScalarValue { + fn eq(&self, other: &Self) -> bool { + use ScalarValue::*; + // This purposely doesn't have a catch-all "(_, _)" so that + // any newly added enum variant will require editing this list + // or else face a compile error + match (self, other) { + (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal128(_, _, _), _) => false, + (Boolean(v1), Boolean(v2)) => v1.eq(v2), + (Boolean(_), _) => false, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float32(_), _) => false, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float64(_), _) => false, + (Int8(v1), Int8(v2)) => v1.eq(v2), + (Int8(_), _) => false, + (Int16(v1), Int16(v2)) => v1.eq(v2), + (Int16(_), _) => false, + (Int32(v1), Int32(v2)) => v1.eq(v2), + (Int32(_), _) => false, + (Int64(v1), Int64(v2)) => v1.eq(v2), + (Int64(_), _) => false, + (UInt8(v1), UInt8(v2)) => v1.eq(v2), + (UInt8(_), _) => false, + (UInt16(v1), UInt16(v2)) => v1.eq(v2), + (UInt16(_), _) => false, + (UInt32(v1), UInt32(v2)) => v1.eq(v2), + (UInt32(_), _) => false, + (UInt64(v1), UInt64(v2)) => v1.eq(v2), + (UInt64(_), _) => false, + (Utf8(v1), Utf8(v2)) => v1.eq(v2), + (Utf8(_), _) => false, + (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2), + (LargeUtf8(_), _) => false, + (Binary(v1), Binary(v2)) => v1.eq(v2), + (Binary(_), _) => false, + (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), + (LargeBinary(_), _) => false, + (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2), + (List(_, _), _) => false, + (Date32(v1), Date32(v2)) => v1.eq(v2), + (Date32(_), _) => false, + (Date64(v1), Date64(v2)) => v1.eq(v2), + (Date64(_), _) => false, + (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.eq(v2), + (TimestampSecond(_, _), _) => false, + (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => v1.eq(v2), + (TimestampMillisecond(_, _), _) => false, + (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => v1.eq(v2), + (TimestampMicrosecond(_, _), _) => false, + (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), + (TimestampNanosecond(_, _), _) => false, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), + (IntervalYearMonth(_), _) => false, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), + (IntervalDayTime(_), _) => false, + (IntervalMonthDayNano(v1), 
IntervalMonthDayNano(v2)) => v1.eq(v2), + (IntervalMonthDayNano(_), _) => false, + (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), + (Struct(_, _), _) => false, + } + } +} + +// manual implementation of `PartialOrd` that uses OrderedFloat to +// get defined behavior for floating point +impl PartialOrd for ScalarValue { + fn partial_cmp(&self, other: &Self) -> Option { + use ScalarValue::*; + // This purposely doesn't have a catch-all "(_, _)" so that + // any newly added enum variant will require editing this list + // or else face a compile error + match (self, other) { + (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal128(_, _, _), _) => None, + (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), + (Boolean(_), _) => None, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float32(_), _) => None, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float64(_), _) => None, + (Int8(v1), Int8(v2)) => v1.partial_cmp(v2), + (Int8(_), _) => None, + (Int16(v1), Int16(v2)) => v1.partial_cmp(v2), + (Int16(_), _) => None, + (Int32(v1), Int32(v2)) => v1.partial_cmp(v2), + (Int32(_), _) => None, + (Int64(v1), Int64(v2)) => v1.partial_cmp(v2), + (Int64(_), _) => None, + (UInt8(v1), UInt8(v2)) => v1.partial_cmp(v2), + (UInt8(_), _) => None, + (UInt16(v1), UInt16(v2)) => v1.partial_cmp(v2), + (UInt16(_), _) => None, + (UInt32(v1), UInt32(v2)) => v1.partial_cmp(v2), + (UInt32(_), _) => None, + (UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2), + (UInt64(_), _) => None, + (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), + (Utf8(_), _) => None, + (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), + (LargeUtf8(_), _) => None, + (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), + (Binary(_), _) => None, + (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), + (LargeBinary(_), _) => None, + (List(v1, t1), List(v2, t2)) => { + if t1.eq(t2) { + v1.partial_cmp(v2) + } else { + None + } + } + (List(_, _), _) => None, + (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), + (Date32(_), _) => None, + (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), + (Date64(_), _) => None, + (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.partial_cmp(v2), + (TimestampSecond(_, _), _) => None, + (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMillisecond(_, _), _) => None, + (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMicrosecond(_, _), _) => None, + (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampNanosecond(_, _), _) => None, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), + (IntervalYearMonth(_), _) => None, + (_, IntervalDayTime(_)) => None, + (IntervalDayTime(_), _) => None, + (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), + (IntervalMonthDayNano(_), _) => None, + (Struct(v1, t1), Struct(v2, t2)) => { + if t1.eq(t2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Struct(_, _), _) => None, + } + } +} + +impl Eq for ScalarValue {} + +// manual implementation of `Hash` that uses OrderedFloat to +// get defined behavior for floating point +impl std::hash::Hash for ScalarValue { + fn 
hash(&self, state: &mut H) { + use ScalarValue::*; + match self { + Decimal128(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } + Boolean(v) => v.hash(state), + Float32(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Float64(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Int8(v) => v.hash(state), + Int16(v) => v.hash(state), + Int32(v) => v.hash(state), + Int64(v) => v.hash(state), + UInt8(v) => v.hash(state), + UInt16(v) => v.hash(state), + UInt32(v) => v.hash(state), + UInt64(v) => v.hash(state), + Utf8(v) => v.hash(state), + LargeUtf8(v) => v.hash(state), + Binary(v) => v.hash(state), + LargeBinary(v) => v.hash(state), + List(v, t) => { + v.hash(state); + t.hash(state); + } + Date32(v) => v.hash(state), + Date64(v) => v.hash(state), + TimestampSecond(v, _) => v.hash(state), + TimestampMillisecond(v, _) => v.hash(state), + TimestampMicrosecond(v, _) => v.hash(state), + TimestampNanosecond(v, _) => v.hash(state), + IntervalYearMonth(v) => v.hash(state), + IntervalDayTime(v) => v.hash(state), + IntervalMonthDayNano(v) => v.hash(state), + Struct(v, t) => { + v.hash(state); + t.hash(state); + } + } + } +} + +// return the index into the dictionary values for array@index as well +// as a reference to the dictionary values array. Returns None for the +// index if the array is NULL at index +#[inline] +fn get_dict_value( + array: &ArrayRef, + index: usize, +) -> Result<(&ArrayRef, Option)> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + + // look up the index in the values dictionary + let keys_col = dict_array.keys(); + if !keys_col.is_valid(index) { + return Ok((dict_array.values(), None)); + } + let values_index = keys_col.value(index).to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert index to usize in dictionary of type creating group by value {:?}", + keys_col.data_type() + )) + })?; + + Ok((dict_array.values(), Some(values_index))) +} + +macro_rules! typed_cast_tz { + ($array:expr, $index:expr, $SCALAR:ident, $TZ:expr) => {{ + let array = $array.as_any().downcast_ref::().unwrap(); + ScalarValue::$SCALAR( + match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }, + $TZ.clone(), + ) + }}; +} + +macro_rules! typed_cast { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + ScalarValue::$SCALAR(match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }) + }}; +} + +macro_rules! build_list { + ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ + let dt = DataType::List(Box::new(Field::new("item", DataType::$SCALAR_TY, true))); + match $VALUES { + // the return on the macro is necessary, to short-circuit and return ArrayRef + None => { + return Arc::from(new_null_array(dt, $SIZE)); + } + Some(values) => { + let mut array = MutableListArray::::new_from( + <$VALUE_BUILDER_TY>::default(), + dt, + $SIZE, + ); + build_values_list!(array, $SCALAR_TY, values.as_ref(), $SIZE) + } + } + }}; +} + +macro_rules! 
build_timestamp_list {
+    ($TIME_UNIT:expr, $VALUES:expr, $SIZE:expr, $TZ:expr) => {{
+        let child_dt = DataType::Timestamp($TIME_UNIT, $TZ.clone());
+        match $VALUES {
+            // the return on the macro is necessary, to short-circuit and return ArrayRef
+            None => {
+                let null_array: ArrayRef = new_null_array(
+                    DataType::List(Box::new(Field::new("item", child_dt, true))),
+                    $SIZE,
+                )
+                .into();
+                null_array
+            }
+            Some(values) => {
+                let values = values.as_ref();
+                let empty_arr = <Int64Vec>::default().to(child_dt.clone());
+                let mut array = MutableListArray::<i32, Int64Vec>::new_from(
+                    empty_arr,
+                    DataType::List(Box::new(Field::new("item", child_dt, true))),
+                    $SIZE,
+                );
+
+                // Each time unit must pair with the scalar variant of the same
+                // unit (Millisecond was previously paired with Microsecond and
+                // vice versa).
+                match $TIME_UNIT {
+                    TimeUnit::Second => {
+                        build_values_list_tz!(array, TimestampSecond, values, $SIZE)
+                    }
+                    TimeUnit::Millisecond => {
+                        build_values_list_tz!(array, TimestampMillisecond, values, $SIZE)
+                    }
+                    TimeUnit::Microsecond => {
+                        build_values_list_tz!(array, TimestampMicrosecond, values, $SIZE)
+                    }
+                    TimeUnit::Nanosecond => {
+                        build_values_list_tz!(array, TimestampNanosecond, values, $SIZE)
+                    }
+                }
+            }
+        }
+    }};
+}
+
+macro_rules! build_values_list {
+    ($MUTABLE_ARR:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{
+        for _ in 0..$SIZE {
+            let mut vec = vec![];
+            for scalar_value in $VALUES {
+                match scalar_value {
+                    ScalarValue::$SCALAR_TY(v) => {
+                        vec.push(v.clone());
+                    }
+                    _ => panic!("Incompatible ScalarValue for list"),
+                };
+            }
+            $MUTABLE_ARR.try_push(Some(vec)).unwrap();
+        }
+
+        let array: ListArray<i32> = $MUTABLE_ARR.into();
+        Arc::new(array)
+    }};
+}
+
+macro_rules! dyn_to_array {
+    ($self:expr, $value:expr, $size:expr, $ty:ty) => {{
+        Arc::new(PrimitiveArray::<$ty>::from_data(
+            $self.get_datatype(),
+            Buffer::<$ty>::from_iter(repeat(*$value).take($size)),
+            None,
+        ))
+    }};
+}
+
+macro_rules! build_values_list_tz {
+    ($MUTABLE_ARR:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{
+        for _ in 0..$SIZE {
+            let mut vec = vec![];
+            for scalar_value in $VALUES {
+                match scalar_value {
+                    ScalarValue::$SCALAR_TY(v, _) => {
+                        vec.push(v.clone());
+                    }
+                    _ => panic!("Incompatible ScalarValue for list"),
+                };
+            }
+            $MUTABLE_ARR.try_push(Some(vec)).unwrap();
+        }
+
+        let array: ListArray<i32> = $MUTABLE_ARR.into();
+        Arc::new(array)
+    }};
+}
+
+macro_rules! eq_array_primitive {
+    ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{
+        let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
+        let is_valid = array.is_valid($index);
+        match $VALUE {
+            Some(val) => is_valid && &array.value($index) == val,
+            None => !is_valid,
+        }
+    }};
+}
+
+impl ScalarValue {
+    /// Create a decimal Scalar from value/precision and scale.
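+    // Editorial sketch (not part of this patch): the expected behavior of the
+    // constructor below, assuming the `datafusion::scalar` re-export used by
+    // this module's doc examples:
+    // ```
+    // use datafusion::scalar::ScalarValue;
+    //
+    // // 12.3 stored as 123 with precision 10 and scale 1
+    // let ok = ScalarValue::try_new_decimal128(123, 10, 1);
+    // assert!(ok.is_ok());
+    // // a scale larger than the precision is rejected
+    // assert!(ScalarValue::try_new_decimal128(123, 5, 10).is_err());
+    // ```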
+ pub fn try_new_decimal128( + value: i128, + precision: usize, + scale: usize, + ) -> Result { + // make sure the precision and scale is valid + if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision { + return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); + } + return Err(DataFusionError::Internal(format!( + "Can not new a decimal type ScalarValue for precision {} and scale {}", + precision, scale + ))); + } + + /// Getter for the `DataType` of the value + pub fn get_datatype(&self) -> DataType { + match self { + ScalarValue::Boolean(_) => DataType::Boolean, + ScalarValue::UInt8(_) => DataType::UInt8, + ScalarValue::UInt16(_) => DataType::UInt16, + ScalarValue::UInt32(_) => DataType::UInt32, + ScalarValue::UInt64(_) => DataType::UInt64, + ScalarValue::Int8(_) => DataType::Int8, + ScalarValue::Int16(_) => DataType::Int16, + ScalarValue::Int32(_) => DataType::Int32, + ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Decimal128(_, precision, scale) => { + DataType::Decimal(*precision, *scale) + } + ScalarValue::TimestampSecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Second, tz_opt.clone()) + } + ScalarValue::TimestampMillisecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Millisecond, tz_opt.clone()) + } + ScalarValue::TimestampMicrosecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Microsecond, tz_opt.clone()) + } + ScalarValue::TimestampNanosecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) + } + ScalarValue::Float32(_) => DataType::Float32, + ScalarValue::Float64(_) => DataType::Float64, + ScalarValue::Utf8(_) => DataType::Utf8, + ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, + ScalarValue::Binary(_) => DataType::Binary, + ScalarValue::LargeBinary(_) => DataType::LargeBinary, + ScalarValue::List(_, data_type) => DataType::List(Box::new(Field::new( + "item", + data_type.as_ref().clone(), + true, + ))), + ScalarValue::Date32(_) => DataType::Date32, + ScalarValue::Date64(_) => DataType::Date64, + ScalarValue::IntervalYearMonth(_) => { + DataType::Interval(IntervalUnit::YearMonth) + } + ScalarValue::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime), + ScalarValue::IntervalMonthDayNano(_) => { + DataType::Interval(IntervalUnit::MonthDayNano) + } + ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), + } + } + + /// Calculate arithmetic negation for a scalar value + pub fn arithmetic_negate(&self) -> Self { + match self { + ScalarValue::Boolean(None) + | ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Float32(None) => self.clone(), + ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(-v)), + ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(-v)), + ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(-v)), + ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(-v)), + ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(-v)), + ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(-v)), + ScalarValue::Decimal128(Some(v), precision, scale) => { + ScalarValue::Decimal128(Some(-v), *precision, *scale) + } + _ => panic!("Cannot run arithmetic negate on scalar value: {:?}", self), + } + } + + /// whether this value is null or not. 
+ pub fn is_null(&self) -> bool { + matches!( + *self, + ScalarValue::Boolean(None) + | ScalarValue::UInt8(None) + | ScalarValue::UInt16(None) + | ScalarValue::UInt32(None) + | ScalarValue::UInt64(None) + | ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Float32(None) + | ScalarValue::Float64(None) + | ScalarValue::Date32(None) + | ScalarValue::Date64(None) + | ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::List(None, _) + | ScalarValue::TimestampSecond(None, _) + | ScalarValue::TimestampMillisecond(None, _) + | ScalarValue::TimestampMicrosecond(None, _) + | ScalarValue::TimestampNanosecond(None, _) + | ScalarValue::Struct(None, _) + | ScalarValue::Decimal128(None, _, _) // For decimal type, the value is null means ScalarValue::Decimal128 is null. + ) + } + + /// Converts a scalar value into an 1-row array. + pub fn to_array(&self) -> ArrayRef { + self.to_array_of_size(1) + } + + /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] + /// corresponding to those values. For example, + /// + /// Returns an error if the iterator is empty or if the + /// [`ScalarValue`]s are not all the same type + /// + /// Example + /// ``` + /// use datafusion::scalar::ScalarValue; + /// use arrow::array::{BooleanArray, Array}; + /// + /// let scalars = vec![ + /// ScalarValue::Boolean(Some(true)), + /// ScalarValue::Boolean(None), + /// ScalarValue::Boolean(Some(false)), + /// ]; + /// + /// // Build an Array from the list of ScalarValues + /// let array = ScalarValue::iter_to_array(scalars.into_iter()) + /// .unwrap(); + /// + /// let expected: Box = Box::new( + /// BooleanArray::from(vec![ + /// Some(true), + /// None, + /// Some(false) + /// ] + /// )); + /// + /// assert_eq!(&array, &expected); + /// ``` + pub fn iter_to_array( + scalars: impl IntoIterator, + ) -> Result { + let mut scalars = scalars.into_iter().peekable(); + + // figure out the type based on the first element + let data_type = match scalars.peek() { + None => { + return Err(DataFusionError::Internal( + "Empty iterator passed to ScalarValue::iter_to_array".to_string(), + )); + } + Some(sv) => sv.get_datatype(), + }; + + /// Creates an array of $ARRAY_TY by unpacking values of + /// SCALAR_TY for primitive types + macro_rules! build_array_primitive { + ($TY:ty, $SCALAR_TY:ident, $DT:ident) => {{ + { + Arc::new(scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }).collect::>>()?.to($DT) + ) as Arc + } + }}; + } + + macro_rules! build_array_primitive_tz { + ($SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v, _) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?; + + Arc::new(array) + } + }}; + } + + /// Creates an array of $ARRAY_TY by unpacking values of + /// SCALAR_TY for "string-like" types. + macro_rules! build_array_string { + ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. 
\ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?; + Arc::new(array) + } + }}; + } + + macro_rules! build_array_list { + ($MUTABLE_TY:ty, $SCALAR_TY:ident) => {{ + let mut array = MutableListArray::::new(); + for scalar in scalars.into_iter() { + match scalar { + ScalarValue::List(Some(xs), _) => { + let xs = *xs; + let mut vec = vec![]; + for s in xs { + match s { + ScalarValue::$SCALAR_TY(o) => { vec.push(o) } + sv => return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected Utf8, got {:?}", + sv + ))), + } + } + array.try_push(Some(vec))?; + } + ScalarValue::List(None, _) => { + array.push_null(); + } + sv => { + return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected List, got {:?}", + sv + ))) + } + } + } + + let array: ListArray = array.into(); + Arc::new(array) + }} + } + + use DataType::*; + let array: Arc = match &data_type { + DataType::Decimal(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; + Arc::new(decimal_array) + } + DataType::Boolean => Arc::new( + scalars + .map(|sv| { + if let ScalarValue::Boolean(v) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?, + ), + Float32 => { + build_array_primitive!(f32, Float32, Float32) + } + Float64 => { + build_array_primitive!(f64, Float64, Float64) + } + Int8 => build_array_primitive!(i8, Int8, Int8), + Int16 => build_array_primitive!(i16, Int16, Int16), + Int32 => build_array_primitive!(i32, Int32, Int32), + Int64 => build_array_primitive!(i64, Int64, Int64), + UInt8 => build_array_primitive!(u8, UInt8, UInt8), + UInt16 => build_array_primitive!(u16, UInt16, UInt16), + UInt32 => build_array_primitive!(u32, UInt32, UInt32), + UInt64 => build_array_primitive!(u64, UInt64, UInt64), + Utf8 => build_array_string!(StringArray, Utf8), + LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8), + Binary => build_array_string!(SmallBinaryArray, Binary), + LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), + Date32 => build_array_primitive!(i32, Date32, Date32), + Date64 => build_array_primitive!(i64, Date64, Date64), + Timestamp(TimeUnit::Second, _) => { + build_array_primitive_tz!(TimestampSecond) + } + Timestamp(TimeUnit::Millisecond, _) => { + build_array_primitive_tz!(TimestampMillisecond) + } + Timestamp(TimeUnit::Microsecond, _) => { + build_array_primitive_tz!(TimestampMicrosecond) + } + Timestamp(TimeUnit::Nanosecond, _) => { + build_array_primitive_tz!(TimestampNanosecond) + } + Interval(IntervalUnit::DayTime) => { + build_array_primitive!(days_ms, IntervalDayTime, data_type) + } + Interval(IntervalUnit::YearMonth) => { + build_array_primitive!(i32, IntervalYearMonth, data_type) + } + DataType::List(fields) if fields.data_type() == &DataType::Int8 => { + build_array_list!(Int8Vec, Int8) + } + DataType::List(fields) if fields.data_type() == &DataType::Int16 => { + build_array_list!(Int16Vec, Int16) + } + DataType::List(fields) if fields.data_type() == &DataType::Int32 => { + build_array_list!(Int32Vec, Int32) + } + DataType::List(fields) if fields.data_type() == &DataType::Int64 => { + build_array_list!(Int64Vec, Int64) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { + build_array_list!(UInt8Vec, UInt8) + } + DataType::List(fields) if 
fields.data_type() == &DataType::UInt16 => { + build_array_list!(UInt16Vec, UInt16) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { + build_array_list!(UInt32Vec, UInt32) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { + build_array_list!(UInt64Vec, UInt64) + } + DataType::List(fields) if fields.data_type() == &DataType::Float32 => { + build_array_list!(Float32Vec, Float32) + } + DataType::List(fields) if fields.data_type() == &DataType::Float64 => { + build_array_list!(Float64Vec, Float64) + } + DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { + build_array_list!(MutableStringArray, Utf8) + } + DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { + build_array_list!(MutableLargeStringArray, LargeUtf8) + } + DataType::List(_) => { + // Fallback case handling homogeneous lists with any ScalarValue element type + let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?; + Arc::new(list_array) + } + DataType::Struct(fields) => { + // Initialize a Vector to store the ScalarValues for each column + let mut columns: Vec> = + (0..fields.len()).map(|_| Vec::new()).collect(); + + // Iterate over scalars to populate the column scalars for each row + for scalar in scalars { + if let ScalarValue::Struct(values, fields) = scalar { + match values { + Some(values) => { + // Push value for each field + for c in 0..columns.len() { + let column = columns.get_mut(c).unwrap(); + column.push(values[c].clone()); + } + } + None => { + // Push NULL of the appropriate type for each field + for c in 0..columns.len() { + let dtype = fields[c].data_type(); + let column = columns.get_mut(c).unwrap(); + column.push(ScalarValue::try_from(dtype)?); + } + } + }; + } else { + return Err(DataFusionError::Internal(format!( + "Expected Struct but found: {}", + scalar + ))); + }; + } + + // Call iter_to_array recursively to convert the scalars for each column into Arrow arrays + let field_values = columns + .iter() + .map(|c| Self::iter_to_array(c.clone()).map(Arc::from)) + .collect::>>()?; + + Arc::new(StructArray::from_data(data_type, field_values, None)) + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported creation of {:?} array from ScalarValue {:?}", + data_type, + scalars.peek() + ))); + } + }; + + Ok(array) + } + + fn iter_to_decimal_array( + scalars: impl IntoIterator, + precision: &usize, + scale: &usize, + ) -> Result { + // collect the value as Option + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal128(v1, _, _) => v1, + _ => unreachable!(), + }) + .collect::>>(); + + // build the decimal array using the Decimal Builder + Ok(Int128Vec::from(array) + .to(Decimal(*precision, *scale)) + .into()) + } + + fn iter_to_array_list( + scalars: impl IntoIterator, + data_type: &DataType, + ) -> Result> { + let mut offsets: Vec = vec![0]; + + let mut elements: Vec = Vec::new(); + let mut valid: Vec = vec![]; + + let mut flat_len = 0i32; + for scalar in scalars { + if let ScalarValue::List(values, _) = scalar { + match values { + Some(values) => { + let element_array = ScalarValue::iter_to_array(*values)?; + + // Add new offset index + flat_len += element_array.len() as i32; + offsets.push(flat_len); + + elements.push(element_array); + + // Element is valid + valid.push(true); + } + None => { + // Repeat previous offset index + offsets.push(flat_len); + + // Element is null + valid.push(false); + } + } + } else { + return 
Err(DataFusionError::Internal(format!( + "Expected ScalarValue::List element. Received {:?}", + scalar + ))); + } + } + + // Concatenate element arrays to create single flat array + let element_arrays: Vec<&dyn Array> = + elements.iter().map(|a| a.as_ref()).collect(); + let flat_array = match concatenate::concatenate(&element_arrays) { + Ok(flat_array) => flat_array, + Err(err) => return Err(DataFusionError::ArrowError(err)), + }; + + let list_array = ListArray::::from_data( + data_type.clone(), + Buffer::from(offsets), + flat_array.into(), + Some(Bitmap::from(valid)), + ); + + Ok(list_array) + } + + /// Converts a scalar value into an array of `size` rows. + pub fn to_array_of_size(&self, size: usize) -> ArrayRef { + match self { + ScalarValue::Decimal128(e, precision, scale) => { + Int128Vec::from_iter(repeat(e).take(size)) + .to(Decimal(*precision, *scale)) + .into_arc() + } + ScalarValue::Boolean(e) => { + Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef + } + ScalarValue::Float64(e) => match e { + Some(value) => { + dyn_to_array!(self, value, size, f64) + } + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Float32(e) => match e { + Some(value) => dyn_to_array!(self, value, size, f32), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Int8(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i8), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Int16(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i16), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Int32(e) + | ScalarValue::Date32(e) + | ScalarValue::IntervalYearMonth(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i32), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::IntervalMonthDayNano(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i128), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Int64(e) | ScalarValue::Date64(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::UInt8(e) => match e { + Some(value) => dyn_to_array!(self, value, size, u8), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::UInt16(e) => match e { + Some(value) => dyn_to_array!(self, value, size, u16), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::UInt32(e) => match e { + Some(value) => dyn_to_array!(self, value, size, u32), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::UInt64(e) => match e { + Some(value) => dyn_to_array!(self, value, size, u64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::TimestampSecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::TimestampMillisecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + + ScalarValue::TimestampMicrosecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::TimestampNanosecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + 
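+            // Editorial sketch (not part of this patch): round-tripping a
+            // scalar broadcast into a constant column, assuming the
+            // `datafusion::scalar` re-export used by the doc examples:
+            // ```
+            // use datafusion::scalar::ScalarValue;
+            //
+            // let v = ScalarValue::Int64(Some(7));
+            // let arr = v.to_array_of_size(4);
+            // assert_eq!(arr.len(), 4);
+            // assert_eq!(ScalarValue::try_from_array(&arr, 3).unwrap(), v);
+            // ```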
ScalarValue::Utf8(e) => match e { + Some(value) => Arc::new(Utf8Array::::from_trusted_len_values_iter( + repeat(&value).take(size), + )), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::LargeUtf8(e) => match e { + Some(value) => Arc::new(Utf8Array::::from_trusted_len_values_iter( + repeat(&value).take(size), + )), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Binary(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::>(), + ), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::LargeBinary(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::>(), + ), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::List(values, data_type) => match data_type.as_ref() { + DataType::Boolean => { + build_list!(MutableBooleanArray, Boolean, values, size) + } + DataType::Int8 => build_list!(Int8Vec, Int8, values, size), + DataType::Int16 => build_list!(Int16Vec, Int16, values, size), + DataType::Int32 => build_list!(Int32Vec, Int32, values, size), + DataType::Int64 => build_list!(Int64Vec, Int64, values, size), + DataType::UInt8 => build_list!(UInt8Vec, UInt8, values, size), + DataType::UInt16 => build_list!(UInt16Vec, UInt16, values, size), + DataType::UInt32 => build_list!(UInt32Vec, UInt32, values, size), + DataType::UInt64 => build_list!(UInt64Vec, UInt64, values, size), + DataType::Float32 => build_list!(Float32Vec, Float32, values, size), + DataType::Float64 => build_list!(Float64Vec, Float64, values, size), + DataType::Timestamp(unit, tz) => { + build_timestamp_list!(*unit, values, size, tz.clone()) + } + DataType::Utf8 => build_list!(MutableStringArray, Utf8, values, size), + DataType::LargeUtf8 => { + build_list!(MutableLargeStringArray, LargeUtf8, values, size) + } + dt => panic!("Unexpected DataType for list {:?}", dt), + }, + ScalarValue::IntervalDayTime(e) => match e { + Some(value) => { + Arc::new(PrimitiveArray::::from_trusted_len_values_iter( + std::iter::repeat(*value).take(size), + )) + } + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Struct(values, _) => match values { + Some(values) => { + let field_values = + values.iter().map(|v| v.to_array_of_size(size)).collect(); + Arc::new(StructArray::from_data( + self.get_datatype(), + field_values, + None, + )) + } + None => Arc::new(StructArray::new_null(self.get_datatype(), size)), + }, + } + } + + fn get_decimal_value_from_array( + array: &ArrayRef, + index: usize, + precision: &usize, + scale: &usize, + ) -> ScalarValue { + let array = array.as_any().downcast_ref::().unwrap(); + if array.is_null(index) { + ScalarValue::Decimal128(None, *precision, *scale) + } else { + ScalarValue::Decimal128(Some(array.value(index)), *precision, *scale) + } + } + + /// Converts a value in `array` at `index` into a ScalarValue + pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { + // handle NULL value + if !array.is_valid(index) { + return array.data_type().try_into(); + } + + Ok(match array.data_type() { + DataType::Decimal(precision, scale) => { + ScalarValue::get_decimal_value_from_array(array, index, precision, scale) + } + DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), + DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), + DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), + DataType::UInt64 => typed_cast!(array, index, 
UInt64Array, UInt64), + DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), + DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), + DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), + DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), + DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), + DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), + DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), + DataType::Binary => typed_cast!(array, index, SmallBinaryArray, Binary), + DataType::LargeBinary => { + typed_cast!(array, index, LargeBinaryArray, LargeBinary) + } + DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), + DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), + DataType::List(nested_type) => { + let list_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "Failed to downcast ListArray".to_string(), + ) + })?; + let value = match list_array.is_null(index) { + true => None, + false => { + let nested_array = ArrayRef::from(list_array.value(index)); + let scalar_vec = (0..nested_array.len()) + .map(|i| ScalarValue::try_from_array(&nested_array, i)) + .collect::>>()?; + Some(scalar_vec) + } + }; + let value = value.map(Box::new); + let data_type = Box::new(nested_type.data_type().clone()); + ScalarValue::List(value, data_type) + } + DataType::Date32 => { + typed_cast!(array, index, Int32Array, Date32) + } + DataType::Date64 => { + typed_cast!(array, index, Int64Array, Date64) + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + typed_cast_tz!(array, index, TimestampSecond, tz_opt) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + typed_cast_tz!(array, index, TimestampMillisecond, tz_opt) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + typed_cast_tz!(array, index, TimestampMicrosecond, tz_opt) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + typed_cast_tz!(array, index, TimestampNanosecond, tz_opt) + } + DataType::Dictionary(index_type, _, _) => { + let (values, values_index) = match index_type { + IntegerType::Int8 => get_dict_value::(array, index)?, + IntegerType::Int16 => get_dict_value::(array, index)?, + IntegerType::Int32 => get_dict_value::(array, index)?, + IntegerType::Int64 => get_dict_value::(array, index)?, + IntegerType::UInt8 => get_dict_value::(array, index)?, + IntegerType::UInt16 => get_dict_value::(array, index)?, + IntegerType::UInt32 => get_dict_value::(array, index)?, + IntegerType::UInt64 => get_dict_value::(array, index)?, + }; + + match values_index { + Some(values_index) => Self::try_from_array(values, values_index)?, + // was null + None => values.data_type().try_into()?, + } + } + DataType::Struct(fields) => { + let array = + array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "Failed to downcast ArrayRef to StructArray".to_string(), + ) + })?; + let mut field_values: Vec = Vec::new(); + for col_index in 0..array.num_columns() { + let col_array = &array.values()[col_index]; + let col_scalar = ScalarValue::try_from_array(col_array, index)?; + field_values.push(col_scalar); + } + Self::Struct(Some(Box::new(field_values)), Box::new(fields.clone())) + } + other => { + return Err(DataFusionError::NotImplemented(format!( + "Can't create a scalar from array of type \"{:?}\"", + other + ))); + } + }) + } + + fn eq_array_decimal( + array: &ArrayRef, + index: usize, + value: &Option, + precision: usize, + scale: 
usize, + ) -> bool { + let array = array.as_any().downcast_ref::().unwrap(); + match array.data_type() { + Decimal(pre, sca) => { + if *pre != precision || *sca != scale { + return false; + } + } + _ => return false, + } + match value { + None => array.is_null(index), + Some(v) => !array.is_null(index) && array.value(index) == *v, + } + } + + /// Compares a single row of array @ index for equality with self, + /// in an optimized fashion. + /// + /// This method implements an optimized version of: + /// + /// ```text + /// let arr_scalar = Self::try_from_array(array, index).unwrap(); + /// arr_scalar.eq(self) + /// ``` + /// + /// *Performance note*: the arrow compute kernels should be + /// preferred over this function if at all possible as they can be + /// vectorized and are generally much faster. + /// + /// This function has a few narrow usescases such as hash table key + /// comparisons where comparing a single row at a time is necessary. + #[inline] + pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { + if let DataType::Dictionary(key_type, _, _) = array.data_type() { + return self.eq_array_dictionary(array, index, key_type); + } + + match self { + ScalarValue::Decimal128(v, precision, scale) => { + ScalarValue::eq_array_decimal(array, index, v, *precision, *scale) + } + ScalarValue::Boolean(val) => { + eq_array_primitive!(array, index, BooleanArray, val) + } + ScalarValue::Float32(val) => { + eq_array_primitive!(array, index, Float32Array, val) + } + ScalarValue::Float64(val) => { + eq_array_primitive!(array, index, Float64Array, val) + } + ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), + ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), + ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), + ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), + ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), + ScalarValue::UInt16(val) => { + eq_array_primitive!(array, index, UInt16Array, val) + } + ScalarValue::UInt32(val) => { + eq_array_primitive!(array, index, UInt32Array, val) + } + ScalarValue::UInt64(val) => { + eq_array_primitive!(array, index, UInt64Array, val) + } + ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), + ScalarValue::LargeUtf8(val) => { + eq_array_primitive!(array, index, LargeStringArray, val) + } + ScalarValue::Binary(val) => { + eq_array_primitive!(array, index, SmallBinaryArray, val) + } + ScalarValue::LargeBinary(val) => { + eq_array_primitive!(array, index, LargeBinaryArray, val) + } + ScalarValue::List(_, _) => unimplemented!(), + ScalarValue::Date32(val) => { + eq_array_primitive!(array, index, Int32Array, val) + } + ScalarValue::Date64(val) => { + eq_array_primitive!(array, index, Int64Array, val) + } + ScalarValue::TimestampSecond(val, _) => { + eq_array_primitive!(array, index, Int64Array, val) + } + ScalarValue::TimestampMillisecond(val, _) => { + eq_array_primitive!(array, index, Int64Array, val) + } + ScalarValue::TimestampMicrosecond(val, _) => { + eq_array_primitive!(array, index, Int64Array, val) + } + ScalarValue::TimestampNanosecond(val, _) => { + eq_array_primitive!(array, index, Int64Array, val) + } + ScalarValue::IntervalYearMonth(val) => { + eq_array_primitive!(array, index, Int32Array, val) + } + ScalarValue::IntervalDayTime(val) => { + eq_array_primitive!(array, index, DaysMsArray, val) + } + ScalarValue::IntervalMonthDayNano(val) => { + eq_array_primitive!(array, 
index, Int128Array, val) + } + ScalarValue::Struct(_, _) => unimplemented!(), + } + } + + /// Compares a dictionary array with indexes of type `key_type` + /// with the array @ index for equality with self + fn eq_array_dictionary( + &self, + array: &ArrayRef, + index: usize, + key_type: &IntegerType, + ) -> bool { + let (values, values_index) = match key_type { + IntegerType::Int8 => get_dict_value::(array, index).unwrap(), + IntegerType::Int16 => get_dict_value::(array, index).unwrap(), + IntegerType::Int32 => get_dict_value::(array, index).unwrap(), + IntegerType::Int64 => get_dict_value::(array, index).unwrap(), + IntegerType::UInt8 => get_dict_value::(array, index).unwrap(), + IntegerType::UInt16 => get_dict_value::(array, index).unwrap(), + IntegerType::UInt32 => get_dict_value::(array, index).unwrap(), + IntegerType::UInt64 => get_dict_value::(array, index).unwrap(), + }; + + match values_index { + Some(values_index) => self.eq_array(values, values_index), + None => self.is_null(), + } + } +} + +macro_rules! impl_scalar { + ($ty:ty, $scalar:tt) => { + impl From<$ty> for ScalarValue { + fn from(value: $ty) -> Self { + ScalarValue::$scalar(Some(value)) + } + } + + impl From> for ScalarValue { + fn from(value: Option<$ty>) -> Self { + ScalarValue::$scalar(value) + } + } + }; +} + +impl_scalar!(f64, Float64); +impl_scalar!(f32, Float32); +impl_scalar!(i8, Int8); +impl_scalar!(i16, Int16); +impl_scalar!(i32, Int32); +impl_scalar!(i64, Int64); +impl_scalar!(bool, Boolean); +impl_scalar!(u8, UInt8); +impl_scalar!(u16, UInt16); +impl_scalar!(u32, UInt32); +impl_scalar!(u64, UInt64); + +impl From<&str> for ScalarValue { + fn from(value: &str) -> Self { + Some(value).into() + } +} + +impl From> for ScalarValue { + fn from(value: Option<&str>) -> Self { + let value = value.map(|s| s.to_string()); + ScalarValue::Utf8(value) + } +} + +impl FromStr for ScalarValue { + type Err = Infallible; + + fn from_str(s: &str) -> std::result::Result { + Ok(s.into()) + } +} + +impl From> for ScalarValue { + fn from(value: Vec<(&str, ScalarValue)>) -> Self { + let (fields, scalars): (Vec<_>, Vec<_>) = value + .into_iter() + .map(|(name, scalar)| { + (Field::new(name, scalar.get_datatype(), false), scalar) + }) + .unzip(); + + Self::Struct(Some(Box::new(scalars)), Box::new(fields)) + } +} + +macro_rules! 
impl_try_from { + ($SCALAR:ident, $NATIVE:ident) => { + impl TryFrom for $NATIVE { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::$SCALAR(Some(inner_value)) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } + } + }; +} + +impl_try_from!(Int8, i8); +impl_try_from!(Int16, i16); + +// special implementation for i32 because of Date32 +impl TryFrom for i32 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Int32(Some(inner_value)) + | ScalarValue::Date32(Some(inner_value)) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + +// special implementation for i64 because of TimeNanosecond +impl TryFrom for i64 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Int64(Some(inner_value)) + | ScalarValue::Date64(Some(inner_value)) + | ScalarValue::TimestampNanosecond(Some(inner_value), _) + | ScalarValue::TimestampMicrosecond(Some(inner_value), _) + | ScalarValue::TimestampMillisecond(Some(inner_value), _) + | ScalarValue::TimestampSecond(Some(inner_value), _) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + +// special implementation for i128 because of Decimal128 +impl TryFrom for i128 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Decimal128(Some(inner_value), _, _) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + +impl_try_from!(UInt8, u8); +impl_try_from!(UInt16, u16); +impl_try_from!(UInt32, u32); +impl_try_from!(UInt64, u64); +impl_try_from!(Float32, f32); +impl_try_from!(Float64, f64); +impl_try_from!(Boolean, bool); + +impl TryInto> for &ScalarValue { + type Error = DataFusionError; + + fn try_into(self) -> Result> { + use arrow::scalar::*; + match self { + ScalarValue::Boolean(b) => Ok(Box::new(BooleanScalar::new(*b))), + ScalarValue::Float32(f) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::Float32, *f))) + } + ScalarValue::Float64(f) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::Float64, *f))) + } + ScalarValue::Int8(i) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::Int8, *i))) + } + ScalarValue::Int16(i) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::Int16, *i))) + } + ScalarValue::Int32(i) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::Int32, *i))) + } + ScalarValue::Int64(i) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::Int64, *i))) + } + ScalarValue::UInt8(u) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::UInt8, *u))) + } + ScalarValue::UInt16(u) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::UInt16, *u))) + } + ScalarValue::UInt32(u) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::UInt32, *u))) + } + ScalarValue::UInt64(u) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::UInt64, *u))) + } + ScalarValue::Utf8(s) => Ok(Box::new(Utf8Scalar::::new(s.clone()))), + ScalarValue::LargeUtf8(s) => Ok(Box::new(Utf8Scalar::::new(s.clone()))), + ScalarValue::Binary(b) => Ok(Box::new(BinaryScalar::::new(b.clone()))), + ScalarValue::LargeBinary(b) => { + Ok(Box::new(BinaryScalar::::new(b.clone()))) + } + 
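+            // Editorial sketch (not part of this patch): bridging a
+            // ScalarValue into an arrow2 `Scalar` trait object, assuming the
+            // re-exports used by this module's doc examples:
+            // ```
+            // use std::convert::TryInto;
+            // use arrow::scalar::Scalar;
+            // use datafusion::scalar::ScalarValue;
+            //
+            // let v = ScalarValue::Int32(Some(5));
+            // let s: Box<dyn Scalar> = (&v).try_into().unwrap();
+            // assert_eq!(s.data_type(), &arrow::datatypes::DataType::Int32);
+            // ```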
ScalarValue::Date32(i) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::Date32, *i))) + } + ScalarValue::Date64(i) => { + Ok(Box::new(PrimitiveScalar::::new(DataType::Date64, *i))) + } + ScalarValue::TimestampSecond(i, tz) => { + Ok(Box::new(PrimitiveScalar::::new( + DataType::Timestamp(TimeUnit::Second, tz.clone()), + *i, + ))) + } + ScalarValue::TimestampMillisecond(i, tz) => { + Ok(Box::new(PrimitiveScalar::::new( + DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + *i, + ))) + } + ScalarValue::TimestampMicrosecond(i, tz) => { + Ok(Box::new(PrimitiveScalar::::new( + DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + *i, + ))) + } + ScalarValue::TimestampNanosecond(i, tz) => { + Ok(Box::new(PrimitiveScalar::::new( + DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + *i, + ))) + } + ScalarValue::IntervalYearMonth(i) => { + Ok(Box::new(PrimitiveScalar::::new( + DataType::Interval(IntervalUnit::YearMonth), + *i, + ))) + } + + // List and IntervalDayTime comparison not possible in arrow2 + _ => Err(DataFusionError::Internal( + "Conversion not possible in arrow2".to_owned(), + )), + } + } +} + +impl TryFrom> for ScalarValue { + type Error = DataFusionError; + + fn try_from(s: PrimitiveScalar) -> Result { + match s.data_type() { + DataType::Timestamp(TimeUnit::Second, tz) => { + let s = s.as_any().downcast_ref::>().unwrap(); + Ok(ScalarValue::TimestampSecond(s.value(), tz.clone())) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + let s = s.as_any().downcast_ref::>().unwrap(); + Ok(ScalarValue::TimestampMicrosecond(s.value(), tz.clone())) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + let s = s.as_any().downcast_ref::>().unwrap(); + Ok(ScalarValue::TimestampMillisecond(s.value(), tz.clone())) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + let s = s.as_any().downcast_ref::>().unwrap(); + Ok(ScalarValue::TimestampNanosecond(s.value(), tz.clone())) + } + _ => Err(DataFusionError::Internal( + format!( + "Conversion from arrow Scalar to Datafusion ScalarValue not implemented for: {:?}", s)) + ), + } + } +} + +impl TryFrom<&DataType> for ScalarValue { + type Error = DataFusionError; + + /// Create a Null instance of ScalarValue for this datatype + fn try_from(datatype: &DataType) -> Result { + Ok(match datatype { + DataType::Boolean => ScalarValue::Boolean(None), + DataType::Float64 => ScalarValue::Float64(None), + DataType::Float32 => ScalarValue::Float32(None), + DataType::Int8 => ScalarValue::Int8(None), + DataType::Int16 => ScalarValue::Int16(None), + DataType::Int32 => ScalarValue::Int32(None), + DataType::Int64 => ScalarValue::Int64(None), + DataType::UInt8 => ScalarValue::UInt8(None), + DataType::UInt16 => ScalarValue::UInt16(None), + DataType::UInt32 => ScalarValue::UInt32(None), + DataType::UInt64 => ScalarValue::UInt64(None), + DataType::Decimal(precision, scale) => { + ScalarValue::Decimal128(None, *precision, *scale) + } + DataType::Utf8 => ScalarValue::Utf8(None), + DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), + DataType::Date32 => ScalarValue::Date32(None), + DataType::Date64 => ScalarValue::Date64(None), + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + ScalarValue::TimestampSecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + ScalarValue::TimestampMillisecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + 
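+                // Editorial note (not part of this patch): each timestamp unit
+                // yields a typed NULL that preserves the timezone, e.g.
+                // `ScalarValue::try_from(&DataType::Timestamp(TimeUnit::Nanosecond, None))`
+                // is expected to produce `TimestampNanosecond(None, None)`.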
ScalarValue::TimestampNanosecond(None, tz_opt.clone()) + } + DataType::Dictionary(_index_type, value_type, _) => { + value_type.as_ref().try_into()? + } + DataType::List(ref nested_type) => { + ScalarValue::List(None, Box::new(nested_type.data_type().clone())) + } + DataType::Struct(fields) => { + ScalarValue::Struct(None, Box::new(fields.clone())) + } + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Can't create a scalar from data_type \"{:?}\"", + datatype + ))); + } + }) + } +} + +macro_rules! format_option { + ($F:expr, $EXPR:expr) => {{ + match $EXPR { + Some(e) => write!($F, "{}", e), + None => write!($F, "NULL"), + } + }}; +} + +impl fmt::Display for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ScalarValue::Decimal128(v, p, s) => { + write!(f, "{}", format_args!("{:?},{:?},{:?}", v, p, s))?; + } + ScalarValue::Boolean(e) => format_option!(f, e)?, + ScalarValue::Float32(e) => format_option!(f, e)?, + ScalarValue::Float64(e) => format_option!(f, e)?, + ScalarValue::Int8(e) => format_option!(f, e)?, + ScalarValue::Int16(e) => format_option!(f, e)?, + ScalarValue::Int32(e) => format_option!(f, e)?, + ScalarValue::Int64(e) => format_option!(f, e)?, + ScalarValue::UInt8(e) => format_option!(f, e)?, + ScalarValue::UInt16(e) => format_option!(f, e)?, + ScalarValue::UInt32(e) => format_option!(f, e)?, + ScalarValue::UInt64(e) => format_option!(f, e)?, + ScalarValue::TimestampSecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, + ScalarValue::Utf8(e) => format_option!(f, e)?, + ScalarValue::LargeUtf8(e) => format_option!(f, e)?, + ScalarValue::Binary(e) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::LargeBinary(e) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::List(e, _) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::Date32(e) => format_option!(f, e)?, + ScalarValue::Date64(e) => format_option!(f, e)?, + ScalarValue::IntervalDayTime(e) => format_option!(f, e)?, + ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?, + ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?, + ScalarValue::Struct(e, fields) => match e { + Some(l) => write!( + f, + "{{{}}}", + l.iter() + .zip(fields.iter()) + .map(|(value, field)| format!("{}:{}", field.name(), value)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + }; + Ok(()) + } +} + +impl fmt::Debug for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self), + ScalarValue::Boolean(_) => write!(f, "Boolean({})", self), + ScalarValue::Float32(_) => write!(f, "Float32({})", self), + ScalarValue::Float64(_) => write!(f, "Float64({})", self), + ScalarValue::Int8(_) => write!(f, "Int8({})", self), + ScalarValue::Int16(_) => write!(f, "Int16({})", self), + ScalarValue::Int32(_) => write!(f, "Int32({})", self), + ScalarValue::Int64(_) => write!(f, "Int64({})", self), + ScalarValue::UInt8(_) => write!(f, "UInt8({})", 
self), + ScalarValue::UInt16(_) => write!(f, "UInt16({})", self), + ScalarValue::UInt32(_) => write!(f, "UInt32({})", self), + ScalarValue::UInt64(_) => write!(f, "UInt64({})", self), + ScalarValue::TimestampSecond(_, tz_opt) => { + write!(f, "TimestampSecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampMillisecond(_, tz_opt) => { + write!(f, "TimestampMillisecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampMicrosecond(_, tz_opt) => { + write!(f, "TimestampMicrosecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampNanosecond(_, tz_opt) => { + write!(f, "TimestampNanosecond({}, {:?})", self, tz_opt) + } + ScalarValue::Utf8(None) => write!(f, "Utf8({})", self), + ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self), + ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({})", self), + ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{}\")", self), + ScalarValue::Binary(None) => write!(f, "Binary({})", self), + ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{}\")", self), + ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({})", self), + ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{}\")", self), + ScalarValue::List(_, dt) => write!(f, "List[{:?}]([{}])", dt, self), + ScalarValue::Date32(_) => write!(f, "Date32(\"{}\")", self), + ScalarValue::Date64(_) => write!(f, "Date64(\"{}\")", self), + ScalarValue::IntervalDayTime(_) => { + write!(f, "IntervalDayTime(\"{}\")", self) + } + ScalarValue::IntervalYearMonth(_) => { + write!(f, "IntervalYearMonth(\"{}\")", self) + } + ScalarValue::IntervalMonthDayNano(_) => { + write!(f, "IntervalMonthDayNano(\"{}\")", self) + } + ScalarValue::Struct(e, fields) => { + // Use Debug representation of field values + match e { + Some(l) => write!( + f, + "Struct({{{}}})", + l.iter() + .zip(fields.iter()) + .map(|(value, field)| format!("{}:{:?}", field.name(), value)) + .collect::>() + .join(",") + ), + None => write!(f, "Struct(NULL)"), + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field_util::struct_array_from; + + #[test] + fn scalar_decimal_test() { + let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1); + assert_eq!(DataType::Decimal(10, 1), decimal_value.get_datatype()); + let try_into_value: i128 = decimal_value.clone().try_into().unwrap(); + assert_eq!(123_i128, try_into_value); + assert!(!decimal_value.is_null()); + let neg_decimal_value = decimal_value.arithmetic_negate(); + match neg_decimal_value { + ScalarValue::Decimal128(v, _, _) => { + assert_eq!(-123, v.unwrap()); + } + _ => { + unreachable!(); + } + } + + // decimal scalar to array + let array = decimal_value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(1, array.len()); + assert_eq!(DataType::Decimal(10, 1), array.data_type().clone()); + assert_eq!(123i128, array.value(0)); + + // decimal scalar to array with size + let array = decimal_value.to_array_of_size(10); + let array_decimal = array.as_any().downcast_ref::().unwrap(); + assert_eq!(10, array.len()); + assert_eq!(DataType::Decimal(10, 1), array.data_type().clone()); + assert_eq!(123i128, array_decimal.value(0)); + assert_eq!(123i128, array_decimal.value(9)); + // test eq array + assert!(decimal_value.eq_array(&array, 1)); + assert!(decimal_value.eq_array(&array, 5)); + // test try from array + assert_eq!( + decimal_value, + ScalarValue::try_from_array(&array, 5).unwrap() + ); + + assert_eq!( + decimal_value, + ScalarValue::try_new_decimal128(123, 10, 1).unwrap() + ); + + // test compare + let 
left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + assert!(!left.eq(&right)); + let result = left < right; + assert!(result); + let result = left <= right; + assert!(result); + let right = ScalarValue::Decimal128(Some(124), 10, 3); + // make sure that two decimals with diff datatype can't be compared. + let result = left.partial_cmp(&right); + assert_eq!(None, result); + + let decimal_vec = vec![ + ScalarValue::Decimal128(Some(1), 10, 2), + ScalarValue::Decimal128(Some(2), 10, 2), + ScalarValue::Decimal128(Some(3), 10, 2), + ]; + // convert the vec to decimal array and check the result + let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap(); + assert_eq!(3, array.len()); + assert_eq!(DataType::Decimal(10, 2), array.data_type().clone()); + + let decimal_vec = vec![ + ScalarValue::Decimal128(Some(1), 10, 2), + ScalarValue::Decimal128(Some(2), 10, 2), + ScalarValue::Decimal128(Some(3), 10, 2), + ScalarValue::Decimal128(None, 10, 2), + ]; + let array: ArrayRef = + ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap(); + assert_eq!(4, array.len()); + assert_eq!(DataType::Decimal(10, 2), array.data_type().clone()); + + assert!(ScalarValue::try_new_decimal128(1, 10, 2) + .unwrap() + .eq_array(&array, 0)); + assert!(ScalarValue::try_new_decimal128(2, 10, 2) + .unwrap() + .eq_array(&array, 1)); + assert!(ScalarValue::try_new_decimal128(3, 10, 2) + .unwrap() + .eq_array(&array, 2)); + assert_eq!( + ScalarValue::Decimal128(None, 10, 2), + ScalarValue::try_from_array(&array, 3).unwrap() + ); + assert_eq!( + ScalarValue::Decimal128(None, 10, 2), + ScalarValue::try_from_array(&array, 4).unwrap() + ); + } + + #[test] + fn scalar_value_to_array_u64() { + let value = ScalarValue::UInt64(Some(13u64)); + let array = value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(!array.is_null(0)); + assert_eq!(array.value(0), 13); + + let value = ScalarValue::UInt64(None); + let array = value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + } + + #[test] + fn scalar_value_to_array_u32() { + let value = ScalarValue::UInt32(Some(13u32)); + let array = value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(!array.is_null(0)); + assert_eq!(array.value(0), 13); + + let value = ScalarValue::UInt32(None); + let array = value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + } + + #[test] + fn scalar_list_null_to_array() { + let list_array_ref = + ScalarValue::List(None, Box::new(DataType::UInt64)).to_array(); + let list_array = list_array_ref + .as_any() + .downcast_ref::>() + .unwrap(); + + assert!(list_array.is_null(0)); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 0); + } + + #[test] + fn scalar_list_to_array() { + let list_array_ref = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ])), + Box::new(DataType::UInt64), + ) + .to_array(); + + let list_array = list_array_ref + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 3); + + let prim_array_ref = list_array.value(0); + let prim_array = prim_array_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(prim_array.len(), 
3); + assert_eq!(prim_array.value(0), 100); + assert!(prim_array.is_null(1)); + assert_eq!(prim_array.value(2), 101); + } + + /// Creates array directly and via ScalarValue and ensures they are the same + macro_rules! check_scalar_iter { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = + $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected = $ARRAYTYPE::from($INPUT).as_arc(); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they are the same + /// but for variants that carry a timezone field. + macro_rules! check_scalar_iter_tz { + ($SCALAR_T:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(*v, None)) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: Arc = Arc::new(Int64Array::from($INPUT)); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they + /// are the same, for string arrays + macro_rules! check_scalar_iter_string { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string()))) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: Arc = Arc::new($ARRAYTYPE::from($INPUT)); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they + /// are the same, for binary arrays + macro_rules! check_scalar_iter_binary { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec()))) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: $ARRAYTYPE = + $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect(); + + let expected: Arc = Arc::new(expected); + + assert_eq!(&array, &expected); + }}; + } + + #[test] + fn scalar_iter_to_array_boolean() { + check_scalar_iter!( + Boolean, + MutableBooleanArray, + vec![Some(true), None, Some(false)] + ); + check_scalar_iter!(Float32, Float32Vec, vec![Some(1.9), None, Some(-2.1)]); + check_scalar_iter!(Float64, Float64Vec, vec![Some(1.9), None, Some(-2.1)]); + + check_scalar_iter!(Int8, Int8Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int16, Int16Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int32, Int32Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int64, Int64Vec, vec![Some(1), None, Some(3)]); + + check_scalar_iter!(UInt8, UInt8Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt16, UInt16Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt32, UInt32Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt64, UInt64Vec, vec![Some(1), None, Some(3)]); + + check_scalar_iter_tz!(TimestampSecond, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampMillisecond, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampMicrosecond, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampNanosecond, vec![Some(1), None, Some(3)]); + + check_scalar_iter_string!( + Utf8, + StringArray, + vec![Some("foo"), None, Some("bar")] + ); + check_scalar_iter_string!( + LargeUtf8, + LargeStringArray, + vec![Some("foo"), None, Some("bar")] + ); + check_scalar_iter_binary!( + Binary, + SmallBinaryArray, + 
vec![Some(b"foo"), None, Some(b"bar")] + ); + check_scalar_iter_binary!( + LargeBinary, + LargeBinaryArray, + vec![Some(b"foo"), None, Some(b"bar")] + ); + } + + #[test] + fn scalar_iter_to_array_empty() { + let scalars = vec![] as Vec; + + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); + assert!( + result + .to_string() + .contains("Empty iterator passed to ScalarValue::iter_to_array"), + "{}", + result + ); + } + + #[test] + fn scalar_iter_to_array_mismatched_types() { + use ScalarValue::*; + // If the scalar values are not all the correct type, error here + let scalars: Vec = vec![Boolean(Some(true)), Int32(Some(5))]; + + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); + assert!(result.to_string().contains("Inconsistent types in ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"), + "{}", result); + } + + #[test] + fn scalar_try_from_array_null() { + let array = vec![Some(33), None].into_iter().collect::(); + let array: ArrayRef = Arc::new(array); + + assert_eq!( + ScalarValue::Int64(Some(33)), + ScalarValue::try_from_array(&array, 0).unwrap() + ); + assert_eq!( + ScalarValue::Int64(None), + ScalarValue::try_from_array(&array, 1).unwrap() + ); + } + + #[test] + fn scalar_try_from_dict_datatype() { + let data_type = + DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8), false); + let data_type = &data_type; + assert_eq!(ScalarValue::Utf8(None), data_type.try_into().unwrap()) + } + + #[test] + fn size_of_scalar() { + // Since ScalarValues are used in a non trivial number of places, + // making it larger means significant more memory consumption + // per distinct value. + #[cfg(target_arch = "aarch64")] + assert_eq!(std::mem::size_of::(), 64); + + #[cfg(target_arch = "amd64")] + assert_eq!(std::mem::size_of::(), 48); + } + + #[test] + fn scalar_eq_array() { + // Validate that eq_array has the same semantics as ScalarValue::eq + macro_rules! make_typed_vec { + ($INPUT:expr, $TYPE:ident) => {{ + $INPUT + .iter() + .map(|v| v.map(|v| v as $TYPE)) + .collect::>() + }}; + } + + let bool_vals = vec![Some(true), None, Some(false)]; + let f32_vals = vec![Some(-1.0), None, Some(1.0)]; + let f64_vals = make_typed_vec!(f32_vals, f64); + + let i8_vals = vec![Some(-1), None, Some(1)]; + let i16_vals = make_typed_vec!(i8_vals, i16); + let i32_vals = make_typed_vec!(i8_vals, i32); + let i64_vals = make_typed_vec!(i8_vals, i64); + let days_ms_vals = &[Some(days_ms::new(1, 2)), None, Some(days_ms::new(10, 0))]; + + let u8_vals = vec![Some(0), None, Some(1)]; + let u16_vals = make_typed_vec!(u8_vals, u16); + let u32_vals = make_typed_vec!(u8_vals, u32); + let u64_vals = make_typed_vec!(u8_vals, u64); + + let str_vals = &[Some("foo"), None, Some("bar")]; + + /// Test each value in `scalar` with the corresponding element + /// at `array`. Assumes each element is unique (aka not equal + /// with all other indexes) + struct TestCase { + array: ArrayRef, + scalars: Vec, + } + + /// Create a test case for casing the input to the specified array type + macro_rules! 
make_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()), + scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), + } + }}; + + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident, $TZ:expr) => {{ + let tz = $TZ; + TestCase { + array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(*v, tz.clone())) + .collect(), + } + }}; + } + + macro_rules! make_date_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($ARRAY_TY::from($INPUT).to(DataType::$SCALAR_TY)), + scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), + } + }}; + } + + macro_rules! make_ts_test_case { + ($INPUT:expr, $ARROW_TU:ident, $SCALAR_TY:ident, $TZ:expr) => {{ + TestCase { + array: Arc::new( + Int64Array::from($INPUT) + .to(DataType::Timestamp(TimeUnit::$ARROW_TU, $TZ)), + ), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(*v, $TZ)) + .collect(), + } + }}; + } + + macro_rules! make_temporal_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $ARROW_TU:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new( + $ARRAY_TY::from($INPUT) + .to(DataType::Interval(IntervalUnit::$ARROW_TU)), + ), + scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), + } + }}; + } + + macro_rules! make_str_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string()))) + .collect(), + } + }}; + } + + macro_rules! make_binary_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| { + ScalarValue::$SCALAR_TY(v.map(|v| v.as_bytes().to_vec())) + }) + .collect(), + } + }}; + } + + /// create a test case for DictionaryArray<$INDEX_TY> + macro_rules! 
make_str_dict_test_case { + ($INPUT:expr, $INDEX_TY:ty, $SCALAR_TY:ident) => {{ + TestCase { + array: { + let mut array = MutableDictionaryArray::< + $INDEX_TY, + MutableUtf8Array, + >::new(); + array.try_extend(*($INPUT)).unwrap(); + let array: DictionaryArray<$INDEX_TY> = array.into(); + Arc::new(array) + }, + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string()))) + .collect(), + } + }}; + } + let utc_tz = Some("UTC".to_owned()); + let cases = vec![ + make_test_case!(bool_vals, BooleanArray, Boolean), + make_test_case!(f32_vals, Float32Array, Float32), + make_test_case!(f64_vals, Float64Array, Float64), + make_test_case!(i8_vals, Int8Array, Int8), + make_test_case!(i16_vals, Int16Array, Int16), + make_test_case!(i32_vals, Int32Array, Int32), + make_test_case!(i64_vals, Int64Array, Int64), + make_test_case!(u8_vals, UInt8Array, UInt8), + make_test_case!(u16_vals, UInt16Array, UInt16), + make_test_case!(u32_vals, UInt32Array, UInt32), + make_test_case!(u64_vals, UInt64Array, UInt64), + make_str_test_case!(str_vals, StringArray, Utf8), + make_str_test_case!(str_vals, LargeStringArray, LargeUtf8), + make_binary_test_case!(str_vals, SmallBinaryArray, Binary), + make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary), + make_date_test_case!(&i32_vals, Int32Array, Date32), + make_date_test_case!(&i64_vals, Int64Array, Date64), + make_ts_test_case!(&i64_vals, Second, TimestampSecond, utc_tz.clone()), + make_ts_test_case!( + &i64_vals, + Millisecond, + TimestampMillisecond, + utc_tz.clone() + ), + make_ts_test_case!( + &i64_vals, + Microsecond, + TimestampMicrosecond, + utc_tz.clone() + ), + make_ts_test_case!( + &i64_vals, + Nanosecond, + TimestampNanosecond, + utc_tz.clone() + ), + make_ts_test_case!(&i64_vals, Second, TimestampSecond, None), + make_ts_test_case!(&i64_vals, Millisecond, TimestampMillisecond, None), + make_ts_test_case!(&i64_vals, Microsecond, TimestampMicrosecond, None), + make_ts_test_case!(&i64_vals, Nanosecond, TimestampNanosecond, None), + make_temporal_test_case!(&i32_vals, Int32Array, YearMonth, IntervalYearMonth), + make_temporal_test_case!(days_ms_vals, DaysMsArray, DayTime, IntervalDayTime), + make_str_dict_test_case!(str_vals, i8, Utf8), + make_str_dict_test_case!(str_vals, i16, Utf8), + make_str_dict_test_case!(str_vals, i32, Utf8), + make_str_dict_test_case!(str_vals, i64, Utf8), + make_str_dict_test_case!(str_vals, u8, Utf8), + make_str_dict_test_case!(str_vals, u16, Utf8), + make_str_dict_test_case!(str_vals, u32, Utf8), + make_str_dict_test_case!(str_vals, u64, Utf8), + ]; + + for case in cases { + let TestCase { array, scalars } = case; + assert_eq!(array.len(), scalars.len()); + + for (index, scalar) in scalars.into_iter().enumerate() { + assert!( + scalar.eq_array(&array, index), + "Expected {:?} to be equal to {:?} at index {}", + scalar, + array, + index + ); + + // test that all other elements are *not* equal + for other_index in 0..array.len() { + if index != other_index { + assert!( + !scalar.eq_array(&array, other_index), + "Expected {:?} to be NOT equal to {:?} at index {}", + scalar, + array, + other_index + ); + } + } + } + } + } + + #[test] + fn scalar_partial_ordering() { + use ScalarValue::*; + + assert_eq!( + Int64(Some(33)).partial_cmp(&Int64(Some(0))), + Some(Ordering::Greater) + ); + assert_eq!( + Int64(Some(0)).partial_cmp(&Int64(Some(33))), + Some(Ordering::Less) + ); + assert_eq!( + Int64(Some(33)).partial_cmp(&Int64(Some(33))), + Some(Ordering::Equal) + ); + // For different data type, `partial_cmp` 
returns None. + assert_eq!(Int64(Some(33)).partial_cmp(&Int32(Some(33))), None); + assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None); + + assert_eq!( + List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + )), + Some(Ordering::Equal) + ); + + assert_eq!( + List( + Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + )), + Some(Ordering::Greater) + ); + + assert_eq!( + List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), + Box::new(DataType::Int32), + )), + Some(Ordering::Less) + ); + + // For different data type, `partial_cmp` returns None. + assert_eq!( + List( + Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])), + Box::new(DataType::Int64), + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + )), + None + ); + + assert_eq!( + ScalarValue::from(vec![ + ("A", ScalarValue::from(1.0)), + ("B", ScalarValue::from("Z")), + ]) + .partial_cmp(&ScalarValue::from(vec![ + ("A", ScalarValue::from(2.0)), + ("B", ScalarValue::from("A")), + ])), + Some(Ordering::Less) + ); + + // For different struct fields, `partial_cmp` returns None. + assert_eq!( + ScalarValue::from(vec![ + ("A", ScalarValue::from(1.0)), + ("B", ScalarValue::from("Z")), + ]) + .partial_cmp(&ScalarValue::from(vec![ + ("a", ScalarValue::from(2.0)), + ("b", ScalarValue::from("A")), + ])), + None + ); + } + + #[test] + fn test_scalar_struct() { + let field_a = Field::new("A", DataType::Int32, false); + let field_b = Field::new("B", DataType::Boolean, false); + let field_c = Field::new("C", DataType::Utf8, false); + + let field_e = Field::new("e", DataType::Int16, false); + let field_f = Field::new("f", DataType::Int64, false); + let field_d = Field::new( + "D", + DataType::Struct(vec![field_e.clone(), field_f.clone()]), + false, + ); + + let scalar = ScalarValue::Struct( + Some(Box::new(vec![ + ScalarValue::Int32(Some(23)), + ScalarValue::Boolean(Some(false)), + ScalarValue::Utf8(Some("Hello".to_string())), + ScalarValue::from(vec![ + ("e", ScalarValue::from(2i16)), + ("f", ScalarValue::from(3i64)), + ]), + ])), + Box::new(vec![ + field_a.clone(), + field_b.clone(), + field_c.clone(), + field_d.clone(), + ]), + ); + let _dt = scalar.get_datatype(); + let _sub_dt = field_d.data_type.clone(); + + // Check Display + assert_eq!( + format!("{}", scalar), + String::from("{A:23,B:false,C:Hello,D:{e:2,f:3}}") + ); + + // Check Debug + assert_eq!( + format!("{:?}", scalar), + String::from( + r#"Struct({A:Int32(23),B:Boolean(false),C:Utf8("Hello"),D:Struct({e:Int16(2),f:Int64(3)})})"# + ) + ); + + // Convert to length-2 array + let array = scalar.to_array_of_size(2); + let expected_vals = vec![ + (field_a.clone(), Int32Vec::from_slice(&[23, 23]).as_arc()), + ( + field_b.clone(), + Arc::new(BooleanArray::from_slice(&vec![false, false])) as ArrayRef, + ), + ( + field_c.clone(), + Arc::new(StringArray::from_slice(&vec!["Hello", "Hello"])) as ArrayRef, + ), + ( + field_d.clone(), + Arc::new(StructArray::from_data( + DataType::Struct(vec![field_e.clone(), field_f.clone()]), + vec![ + Int16Vec::from_slice(&[2, 2]).as_arc(), + Int64Vec::from_slice(&[3, 3]).as_arc(), + ], 
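+                    // third argument: optional validity bitmap (None = no nulls)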
+ None, + )) as ArrayRef, + ), + ]; + + let expected = Arc::new(struct_array_from(expected_vals)) as ArrayRef; + assert_eq!(&array, &expected); + + // Construct from second element of ArrayRef + let constructed = ScalarValue::try_from_array(&expected, 1).unwrap(); + assert_eq!(constructed, scalar); + + // None version + let none_scalar = ScalarValue::try_from(array.data_type()).unwrap(); + assert!(none_scalar.is_null()); + assert_eq!(format!("{:?}", none_scalar), String::from("Struct(NULL)")); + + // Construct with convenience From> + let constructed = ScalarValue::from(vec![ + ("A", ScalarValue::from(23i32)), + ("B", ScalarValue::from(false)), + ("C", ScalarValue::from("Hello")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(2i16)), + ("f", ScalarValue::from(3i64)), + ]), + ), + ]); + assert_eq!(constructed, scalar); + + // Build Array from Vec of structs + let scalars = vec![ + ScalarValue::from(vec![ + ("A", ScalarValue::from(23i32)), + ("B", ScalarValue::from(false)), + ("C", ScalarValue::from("Hello")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(2i16)), + ("f", ScalarValue::from(3i64)), + ]), + ), + ]), + ScalarValue::from(vec![ + ("A", ScalarValue::from(7i32)), + ("B", ScalarValue::from(true)), + ("C", ScalarValue::from("World")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(4i16)), + ("f", ScalarValue::from(5i64)), + ]), + ), + ]), + ScalarValue::from(vec![ + ("A", ScalarValue::from(-1000i32)), + ("B", ScalarValue::from(true)), + ("C", ScalarValue::from("!!!!!")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(6i16)), + ("f", ScalarValue::from(7i64)), + ]), + ), + ]), + ]; + let array: ArrayRef = ScalarValue::iter_to_array(scalars).unwrap(); + + let expected = Arc::new(struct_array_from(vec![ + (field_a, Int32Vec::from_slice(&[23, 7, -1000]).as_arc()), + ( + field_b, + Arc::new(BooleanArray::from_slice(&vec![false, true, true])) as ArrayRef, + ), + ( + field_c, + Arc::new(StringArray::from_slice(&vec!["Hello", "World", "!!!!!"])) + as ArrayRef, + ), + ( + field_d, + Arc::new(StructArray::from_data( + DataType::Struct(vec![field_e, field_f]), + vec![ + Int16Vec::from_slice(&[2, 4, 6]).as_arc(), + Int64Vec::from_slice(&[3, 5, 7]).as_arc(), + ], + None, + )) as ArrayRef, + ), + ])) as ArrayRef; + + assert_eq!(&array, &expected); + } + + #[test] + fn test_lists_in_struct() { + let field_a = Field::new("A", DataType::Utf8, false); + let field_primitive_list = Field::new( + "primitive_list", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + false, + ); + + // Define primitive list scalars + let l0 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ])), + Box::new(DataType::Int32), + ); + + let l1 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(4i32), + ScalarValue::from(5i32), + ])), + Box::new(DataType::Int32), + ); + + let l2 = ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(6i32)])), + Box::new(DataType::Int32), + ); + + // Define struct scalars + let s0 = ScalarValue::from(vec![ + ("A", ScalarValue::Utf8(Some(String::from("First")))), + ("primitive_list", l0), + ]); + + let s1 = ScalarValue::from(vec![ + ("A", ScalarValue::Utf8(Some(String::from("Second")))), + ("primitive_list", l1), + ]); + + let s2 = ScalarValue::from(vec![ + ("A", ScalarValue::Utf8(Some(String::from("Third")))), + ("primitive_list", l2), + ]); + + // iter_to_array for struct scalars + let array = + 
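+        // iter_to_array recursively converts each struct field's scalars into
+        // a child column and assembles the columns into a StructArray.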
ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap(); + let array = array.as_any().downcast_ref::().unwrap(); + + let mut list_array = + MutableListArray::::new_with_capacity(Int32Vec::new(), 5); + list_array + .try_extend(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6)]), + ]) + .unwrap(); + let expected = struct_array_from(vec![ + ( + field_a.clone(), + Arc::new(StringArray::from_slice(&vec!["First", "Second", "Third"])) + as ArrayRef, + ), + (field_primitive_list.clone(), list_array.as_arc()), + ]); + + assert_eq!(array, &expected); + + // Define list-of-structs scalars + let nl0 = ScalarValue::List( + Some(Box::new(vec![s0.clone(), s1.clone()])), + Box::new(s0.get_datatype()), + ); + + let nl1 = + ScalarValue::List(Some(Box::new(vec![s2])), Box::new(s0.get_datatype())); + + let nl2 = + ScalarValue::List(Some(Box::new(vec![s1])), Box::new(s0.get_datatype())); + + // iter_to_array for list-of-struct + let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); + let array = array.as_any().downcast_ref::>().unwrap(); + + // Construct expected array with array builders + let field_a_builder = + Utf8Array::::from_slice(&vec!["First", "Second", "Third", "Second"]); + let primitive_value_builder = Int32Vec::with_capacity(5); + let mut field_primitive_list_builder = + MutableListArray::::new_with_capacity( + primitive_value_builder, + 0, + ); + field_primitive_list_builder + .try_push(Some(vec![1, 2, 3].into_iter().map(Option::Some))) + .unwrap(); + field_primitive_list_builder + .try_push(Some(vec![4, 5].into_iter().map(Option::Some))) + .unwrap(); + field_primitive_list_builder + .try_push(Some(vec![6].into_iter().map(Option::Some))) + .unwrap(); + field_primitive_list_builder + .try_push(Some(vec![4, 5].into_iter().map(Option::Some))) + .unwrap(); + let _element_builder = StructArray::from_data( + DataType::Struct(vec![field_a, field_primitive_list]), + vec![ + Arc::new(field_a_builder), + field_primitive_list_builder.as_arc(), + ], + None, + ); + //let expected = ListArray::(element_builder, 5); + eprintln!("array = {:?}", array); + //assert_eq!(array, &expected); + } + + #[test] + fn test_nested_lists() { + // Define inner list scalars + let l1 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ])), + Box::new(DataType::Int32), + ), + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(4i32), + ScalarValue::from(5i32), + ])), + Box::new(DataType::Int32), + ), + ])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + let l2 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(6i32)])), + Box::new(DataType::Int32), + ), + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(7i32), + ScalarValue::from(8i32), + ])), + Box::new(DataType::Int32), + ), + ])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + let l3 = ScalarValue::List( + Some(Box::new(vec![ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(9i32)])), + Box::new(DataType::Int32), + )])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); + + // Construct expected array with array builders + let inner_builder = Int32Vec::with_capacity(8); + let 
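+        // Three levels of nesting: outer list -> middle list -> Int32 values,
+        // mirroring the List(List(Int32)) scalars built above.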
middle_builder =
+            MutableListArray::<i32, Int32Vec>::new_with_capacity(inner_builder, 0);
+        let mut outer_builder =
+            MutableListArray::<i32, MutableListArray<i32, Int32Vec>>::new_with_capacity(
+                middle_builder,
+                0,
+            );
+        outer_builder
+            .try_push(Some(vec![
+                Some(vec![Some(1), Some(2), Some(3)]),
+                Some(vec![Some(4), Some(5)]),
+            ]))
+            .unwrap();
+        outer_builder
+            .try_push(Some(vec![
+                Some(vec![Some(6)]),
+                Some(vec![Some(7), Some(8)]),
+            ]))
+            .unwrap();
+        outer_builder
+            .try_push(Some(vec![Some(vec![Some(9)])]))
+            .unwrap();
+
+        let expected = outer_builder.as_arc();
+
+        assert_eq!(&array, &expected);
+    }
+
+    #[test]
+    fn scalar_timestamp_ns_utc_timezone() {
+        let scalar = ScalarValue::TimestampNanosecond(
+            Some(1599566400000000000),
+            Some("UTC".to_owned()),
+        );
+
+        assert_eq!(
+            scalar.get_datatype(),
+            DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned()))
+        );
+
+        let array = scalar.to_array();
+        assert_eq!(array.len(), 1);
+        assert_eq!(
+            array.data_type(),
+            &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned()))
+        );
+
+        let newscalar = ScalarValue::try_from_array(&array, 0).unwrap();
+        assert_eq!(
+            newscalar.get_datatype(),
+            DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned()))
+        );
+    }
+}
diff --git a/datafusion-common/src/scalar_tmp.rs b/datafusion-common/src/scalar_tmp.rs
new file mode 100644
index 000000000000..847a9ddd65fd
--- /dev/null
+++ b/datafusion-common/src/scalar_tmp.rs
@@ -0,0 +1,2992 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This module provides ScalarValue, an enum that can be used for storage of single elements
+
+use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
+
+use crate::error::{DataFusionError, Result};
+use crate::field_util::{FieldExt, StructArrayExt};
+use arrow::bitmap::Bitmap;
+use arrow::buffer::Buffer;
+use arrow::compute::concatenate;
+use arrow::datatypes::DataType::Decimal;
+use arrow::{
+    array::*,
+    datatypes::{DataType, Field, IntegerType, IntervalUnit, TimeUnit},
+    scalar::{PrimitiveScalar, Scalar},
+    types::{days_ms, NativeType},
+};
+use ordered_float::OrderedFloat;
+use std::cmp::Ordering;
+use std::convert::{Infallible, TryInto};
+use std::str::FromStr;
+
+type StringArray = Utf8Array<i32>;
+type LargeStringArray = Utf8Array<i64>;
+type SmallBinaryArray = BinaryArray<i32>;
+type LargeBinaryArray = BinaryArray<i64>;
+type MutableStringArray = MutableUtf8Array<i32>;
+type MutableLargeStringArray = MutableUtf8Array<i64>;
+
+// TODO may need to be moved to arrow-rs
+/// The max precision and scale for decimal128
+pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38;
+pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38;
+
+/// Represents a dynamically typed, nullable single value.
+/// This is the single-valued counter-part of arrow’s `Array`.
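+///
+/// For example (a minimal sketch, using the same import path as the
+/// `iter_to_array` example below):
+///
+/// ```
+/// use datafusion::scalar::ScalarValue;
+///
+/// let value = ScalarValue::Int32(Some(5)); // a non-null 32-bit integer
+/// let null = ScalarValue::Int32(None);     // a NULL of the same type
+/// assert!(!value.is_null());
+/// assert!(null.is_null());
+/// ```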
+#[derive(Clone)] +pub enum ScalarValue { + /// true or false value + Boolean(Option), + /// 32bit float + Float32(Option), + /// 64bit float + Float64(Option), + /// 128bit decimal, using the i128 to represent the decimal + Decimal128(Option, usize, usize), + /// signed 8bit int + Int8(Option), + /// signed 16bit int + Int16(Option), + /// signed 32bit int + Int32(Option), + /// signed 64bit int + Int64(Option), + /// unsigned 8bit int + UInt8(Option), + /// unsigned 16bit int + UInt16(Option), + /// unsigned 32bit int + UInt32(Option), + /// unsigned 64bit int + UInt64(Option), + /// utf-8 encoded string. + Utf8(Option), + /// utf-8 encoded string representing a LargeString's arrow type. + LargeUtf8(Option), + /// binary + Binary(Option>), + /// large binary + LargeBinary(Option>), + /// list of nested ScalarValue (boxed to reduce size_of(ScalarValue)) + #[allow(clippy::box_collection)] + List(Option>>, Box), + /// Date stored as a signed 32bit int + Date32(Option), + /// Date stored as a signed 64bit int + Date64(Option), + /// Timestamp Second + TimestampSecond(Option, Option), + /// Timestamp Milliseconds + TimestampMillisecond(Option, Option), + /// Timestamp Microseconds + TimestampMicrosecond(Option, Option), + /// Timestamp Nanoseconds + TimestampNanosecond(Option, Option), + /// Interval with YearMonth unit + IntervalYearMonth(Option), + /// Interval with DayTime unit + IntervalDayTime(Option), + /// Interval with MonthDayNano unit + IntervalMonthDayNano(Option), + /// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue)) + #[allow(clippy::box_collection)] + Struct(Option>>, Box>), +} + +// manual implementation of `PartialEq` that uses OrderedFloat to +// get defined behavior for floating point +impl PartialEq for ScalarValue { + fn eq(&self, other: &Self) -> bool { + use ScalarValue::*; + // This purposely doesn't have a catch-all "(_, _)" so that + // any newly added enum variant will require editing this list + // or else face a compile error + match (self, other) { + (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal128(_, _, _), _) => false, + (Boolean(v1), Boolean(v2)) => v1.eq(v2), + (Boolean(_), _) => false, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float32(_), _) => false, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float64(_), _) => false, + (Int8(v1), Int8(v2)) => v1.eq(v2), + (Int8(_), _) => false, + (Int16(v1), Int16(v2)) => v1.eq(v2), + (Int16(_), _) => false, + (Int32(v1), Int32(v2)) => v1.eq(v2), + (Int32(_), _) => false, + (Int64(v1), Int64(v2)) => v1.eq(v2), + (Int64(_), _) => false, + (UInt8(v1), UInt8(v2)) => v1.eq(v2), + (UInt8(_), _) => false, + (UInt16(v1), UInt16(v2)) => v1.eq(v2), + (UInt16(_), _) => false, + (UInt32(v1), UInt32(v2)) => v1.eq(v2), + (UInt32(_), _) => false, + (UInt64(v1), UInt64(v2)) => v1.eq(v2), + (UInt64(_), _) => false, + (Utf8(v1), Utf8(v2)) => v1.eq(v2), + (Utf8(_), _) => false, + (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2), + (LargeUtf8(_), _) => false, + (Binary(v1), Binary(v2)) => v1.eq(v2), + (Binary(_), _) => false, + (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), + (LargeBinary(_), _) => false, + (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2), + (List(_, _), _) => false, + (Date32(v1), Date32(v2)) => v1.eq(v2), + (Date32(_), _) => false, + (Date64(v1), Date64(v2)) => v1.eq(v2), + (Date64(_), _) 
=> false, + (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.eq(v2), + (TimestampSecond(_, _), _) => false, + (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => v1.eq(v2), + (TimestampMillisecond(_, _), _) => false, + (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => v1.eq(v2), + (TimestampMicrosecond(_, _), _) => false, + (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), + (TimestampNanosecond(_, _), _) => false, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), + (IntervalYearMonth(_), _) => false, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), + (IntervalDayTime(_), _) => false, + (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), + (IntervalMonthDayNano(_), _) => false, + (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), + (Struct(_, _), _) => false, + } + } +} + +// manual implementation of `PartialOrd` that uses OrderedFloat to +// get defined behavior for floating point +impl PartialOrd for ScalarValue { + fn partial_cmp(&self, other: &Self) -> Option { + use ScalarValue::*; + // This purposely doesn't have a catch-all "(_, _)" so that + // any newly added enum variant will require editing this list + // or else face a compile error + match (self, other) { + (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal128(_, _, _), _) => None, + (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), + (Boolean(_), _) => None, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float32(_), _) => None, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float64(_), _) => None, + (Int8(v1), Int8(v2)) => v1.partial_cmp(v2), + (Int8(_), _) => None, + (Int16(v1), Int16(v2)) => v1.partial_cmp(v2), + (Int16(_), _) => None, + (Int32(v1), Int32(v2)) => v1.partial_cmp(v2), + (Int32(_), _) => None, + (Int64(v1), Int64(v2)) => v1.partial_cmp(v2), + (Int64(_), _) => None, + (UInt8(v1), UInt8(v2)) => v1.partial_cmp(v2), + (UInt8(_), _) => None, + (UInt16(v1), UInt16(v2)) => v1.partial_cmp(v2), + (UInt16(_), _) => None, + (UInt32(v1), UInt32(v2)) => v1.partial_cmp(v2), + (UInt32(_), _) => None, + (UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2), + (UInt64(_), _) => None, + (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), + (Utf8(_), _) => None, + (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), + (LargeUtf8(_), _) => None, + (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), + (Binary(_), _) => None, + (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), + (LargeBinary(_), _) => None, + (List(v1, t1), List(v2, t2)) => { + if t1.eq(t2) { + v1.partial_cmp(v2) + } else { + None + } + } + (List(_, _), _) => None, + (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), + (Date32(_), _) => None, + (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), + (Date64(_), _) => None, + (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.partial_cmp(v2), + (TimestampSecond(_, _), _) => None, + (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMillisecond(_, _), _) => None, + (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMicrosecond(_, _), _) => None, + (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => { 
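+                // Only the raw i64 values are compared; the timezone field is
+                // ignored here, consistent with the PartialEq impl above.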
+ v1.partial_cmp(v2) + } + (TimestampNanosecond(_, _), _) => None, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), + (IntervalYearMonth(_), _) => None, + (_, IntervalDayTime(_)) => None, + (IntervalDayTime(_), _) => None, + (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), + (IntervalMonthDayNano(_), _) => None, + (Struct(v1, t1), Struct(v2, t2)) => { + if t1.eq(t2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Struct(_, _), _) => None, + } + } +} + +impl Eq for ScalarValue {} + +// manual implementation of `Hash` that uses OrderedFloat to +// get defined behavior for floating point +impl std::hash::Hash for ScalarValue { + fn hash(&self, state: &mut H) { + use ScalarValue::*; + match self { + Decimal128(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } + Boolean(v) => v.hash(state), + Float32(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Float64(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Int8(v) => v.hash(state), + Int16(v) => v.hash(state), + Int32(v) => v.hash(state), + Int64(v) => v.hash(state), + UInt8(v) => v.hash(state), + UInt16(v) => v.hash(state), + UInt32(v) => v.hash(state), + UInt64(v) => v.hash(state), + Utf8(v) => v.hash(state), + LargeUtf8(v) => v.hash(state), + Binary(v) => v.hash(state), + LargeBinary(v) => v.hash(state), + List(v, t) => { + v.hash(state); + t.hash(state); + } + Date32(v) => v.hash(state), + Date64(v) => v.hash(state), + TimestampSecond(v, _) => v.hash(state), + TimestampMillisecond(v, _) => v.hash(state), + TimestampMicrosecond(v, _) => v.hash(state), + TimestampNanosecond(v, _) => v.hash(state), + IntervalYearMonth(v) => v.hash(state), + IntervalDayTime(v) => v.hash(state), + IntervalMonthDayNano(v) => v.hash(state), + Struct(v, t) => { + v.hash(state); + t.hash(state); + } + } + } +} + +// return the index into the dictionary values for array@index as well +// as a reference to the dictionary values array. Returns None for the +// index if the array is NULL at index +#[inline] +fn get_dict_value( + array: &ArrayRef, + index: usize, +) -> Result<(&ArrayRef, Option)> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + + // look up the index in the values dictionary + let keys_col = dict_array.keys(); + if !keys_col.is_valid(index) { + return Ok((dict_array.values(), None)); + } + let values_index = keys_col.value(index).to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert index to usize in dictionary of type creating group by value {:?}", + keys_col.data_type() + )) + })?; + + Ok((dict_array.values(), Some(values_index))) +} + +macro_rules! typed_cast_tz { + ($array:expr, $index:expr, $SCALAR:ident, $TZ:expr) => {{ + let array = $array.as_any().downcast_ref::().unwrap(); + ScalarValue::$SCALAR( + match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }, + $TZ.clone(), + ) + }}; +} + +macro_rules! typed_cast { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + ScalarValue::$SCALAR(match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }) + }}; +} + +macro_rules! 
build_list { + ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ + let dt = DataType::List(Box::new(Field::new("item", DataType::$SCALAR_TY, true))); + match $VALUES { + // the return on the macro is necessary, to short-circuit and return ArrayRef + None => { + return Arc::from(new_null_array(dt, $SIZE)); + } + Some(values) => { + let mut array = MutableListArray::::new_from( + <$VALUE_BUILDER_TY>::default(), + dt, + $SIZE, + ); + build_values_list!(array, $SCALAR_TY, values.as_ref(), $SIZE) + } + } + }}; +} + +macro_rules! build_timestamp_list { + ($TIME_UNIT:expr, $VALUES:expr, $SIZE:expr, $TZ:expr) => {{ + let child_dt = DataType::Timestamp($TIME_UNIT, $TZ.clone()); + match $VALUES { + // the return on the macro is necessary, to short-circuit and return ArrayRef + None => { + let null_array: ArrayRef = new_null_array( + DataType::List(Box::new(Field::new("item", child_dt, true))), + $SIZE, + ) + .into(); + null_array + } + Some(values) => { + let values = values.as_ref(); + let empty_arr = ::default().to(child_dt.clone()); + let mut array = MutableListArray::::new_from( + empty_arr, + DataType::List(Box::new(Field::new("item", child_dt, true))), + $SIZE, + ); + + match $TIME_UNIT { + TimeUnit::Second => { + build_values_list_tz!(array, TimestampSecond, values, $SIZE) + } + TimeUnit::Microsecond => { + build_values_list_tz!(array, TimestampMillisecond, values, $SIZE) + } + TimeUnit::Millisecond => { + build_values_list_tz!(array, TimestampMicrosecond, values, $SIZE) + } + TimeUnit::Nanosecond => { + build_values_list_tz!(array, TimestampNanosecond, values, $SIZE) + } + } + } + } + }}; +} + +macro_rules! build_values_list { + ($MUTABLE_ARR:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ + for _ in 0..$SIZE { + let mut vec = vec![]; + for scalar_value in $VALUES { + match scalar_value { + ScalarValue::$SCALAR_TY(v) => { + vec.push(v.clone()); + } + _ => panic!("Incompatible ScalarValue for list"), + }; + } + $MUTABLE_ARR.try_push(Some(vec)).unwrap(); + } + + let array: ListArray = $MUTABLE_ARR.into(); + Arc::new(array) + }}; +} + +macro_rules! dyn_to_array { + ($self:expr, $value:expr, $size:expr, $ty:ty) => {{ + Arc::new(PrimitiveArray::<$ty>::from_data( + $self.get_datatype(), + Buffer::<$ty>::from_iter(repeat(*$value).take($size)), + None, + )) + }}; +} + +macro_rules! build_values_list_tz { + ($MUTABLE_ARR:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ + for _ in 0..$SIZE { + let mut vec = vec![]; + for scalar_value in $VALUES { + match scalar_value { + ScalarValue::$SCALAR_TY(v, _) => { + vec.push(v.clone()); + } + _ => panic!("Incompatible ScalarValue for list"), + }; + } + $MUTABLE_ARR.try_push(Some(vec)).unwrap(); + } + + let array: ListArray = $MUTABLE_ARR.into(); + Arc::new(array) + }}; +} + +macro_rules! eq_array_primitive { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + let is_valid = array.is_valid($index); + match $VALUE { + Some(val) => is_valid && &array.value($index) == val, + None => !is_valid, + } + }}; +} + +impl ScalarValue { + /// Create a decimal Scalar from value/precision and scale. 
+ pub fn try_new_decimal128( + value: i128, + precision: usize, + scale: usize, + ) -> Result { + // make sure the precision and scale is valid + if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision { + return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); + } + return Err(DataFusionError::Internal(format!( + "Can not new a decimal type ScalarValue for precision {} and scale {}", + precision, scale + ))); + } + + /// Getter for the `DataType` of the value + pub fn get_datatype(&self) -> DataType { + match self { + ScalarValue::Boolean(_) => DataType::Boolean, + ScalarValue::UInt8(_) => DataType::UInt8, + ScalarValue::UInt16(_) => DataType::UInt16, + ScalarValue::UInt32(_) => DataType::UInt32, + ScalarValue::UInt64(_) => DataType::UInt64, + ScalarValue::Int8(_) => DataType::Int8, + ScalarValue::Int16(_) => DataType::Int16, + ScalarValue::Int32(_) => DataType::Int32, + ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Decimal128(_, precision, scale) => { + DataType::Decimal(*precision, *scale) + } + ScalarValue::TimestampSecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Second, tz_opt.clone()) + } + ScalarValue::TimestampMillisecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Millisecond, tz_opt.clone()) + } + ScalarValue::TimestampMicrosecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Microsecond, tz_opt.clone()) + } + ScalarValue::TimestampNanosecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) + } + ScalarValue::Float32(_) => DataType::Float32, + ScalarValue::Float64(_) => DataType::Float64, + ScalarValue::Utf8(_) => DataType::Utf8, + ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, + ScalarValue::Binary(_) => DataType::Binary, + ScalarValue::LargeBinary(_) => DataType::LargeBinary, + ScalarValue::List(_, data_type) => DataType::List(Box::new(Field::new( + "item", + data_type.as_ref().clone(), + true, + ))), + ScalarValue::Date32(_) => DataType::Date32, + ScalarValue::Date64(_) => DataType::Date64, + ScalarValue::IntervalYearMonth(_) => { + DataType::Interval(IntervalUnit::YearMonth) + } + ScalarValue::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime), + ScalarValue::IntervalMonthDayNano(_) => { + DataType::Interval(IntervalUnit::MonthDayNano) + } + ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), + } + } + + /// Calculate arithmetic negation for a scalar value + pub fn arithmetic_negate(&self) -> Self { + match self { + ScalarValue::Boolean(None) + | ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Float32(None) => self.clone(), + ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(-v)), + ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(-v)), + ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(-v)), + ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(-v)), + ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(-v)), + ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(-v)), + ScalarValue::Decimal128(Some(v), precision, scale) => { + ScalarValue::Decimal128(Some(-v), *precision, *scale) + } + _ => panic!("Cannot run arithmetic negate on scalar value: {:?}", self), + } + } + + /// whether this value is null or not. 
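+    ///
+    /// For example:
+    ///
+    /// ```
+    /// use datafusion::scalar::ScalarValue;
+    ///
+    /// assert!(ScalarValue::Utf8(None).is_null());
+    /// assert!(!ScalarValue::Utf8(Some("foo".to_string())).is_null());
+    /// ```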
+ pub fn is_null(&self) -> bool { + matches!( + *self, + ScalarValue::Boolean(None) + | ScalarValue::UInt8(None) + | ScalarValue::UInt16(None) + | ScalarValue::UInt32(None) + | ScalarValue::UInt64(None) + | ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Float32(None) + | ScalarValue::Float64(None) + | ScalarValue::Date32(None) + | ScalarValue::Date64(None) + | ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::List(None, _) + | ScalarValue::TimestampSecond(None, _) + | ScalarValue::TimestampMillisecond(None, _) + | ScalarValue::TimestampMicrosecond(None, _) + | ScalarValue::TimestampNanosecond(None, _) + | ScalarValue::Struct(None, _) + | ScalarValue::Decimal128(None, _, _) // For decimal type, the value is null means ScalarValue::Decimal128 is null. + ) + } + + /// Converts a scalar value into an 1-row array. + pub fn to_array(&self) -> ArrayRef { + self.to_array_of_size(1) + } + + /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] + /// corresponding to those values. For example, + /// + /// Returns an error if the iterator is empty or if the + /// [`ScalarValue`]s are not all the same type + /// + /// Example + /// ``` + /// use datafusion::scalar::ScalarValue; + /// use arrow::array::{BooleanArray, Array}; + /// + /// let scalars = vec![ + /// ScalarValue::Boolean(Some(true)), + /// ScalarValue::Boolean(None), + /// ScalarValue::Boolean(Some(false)), + /// ]; + /// + /// // Build an Array from the list of ScalarValues + /// let array = ScalarValue::iter_to_array(scalars.into_iter()) + /// .unwrap(); + /// + /// let expected: Box = Box::new( + /// BooleanArray::from(vec![ + /// Some(true), + /// None, + /// Some(false) + /// ] + /// )); + /// + /// assert_eq!(&array, &expected); + /// ``` + pub fn iter_to_array( + scalars: impl IntoIterator, + ) -> Result { + let mut scalars = scalars.into_iter().peekable(); + + // figure out the type based on the first element + let data_type = match scalars.peek() { + None => { + return Err(DataFusionError::Internal( + "Empty iterator passed to ScalarValue::iter_to_array".to_string(), + )); + } + Some(sv) => sv.get_datatype(), + }; + + /// Creates an array of $ARRAY_TY by unpacking values of + /// SCALAR_TY for primitive types + macro_rules! build_array_primitive { + ($TY:ty, $SCALAR_TY:ident, $DT:ident) => {{ + { + Arc::new(scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }).collect::>>()?.to($DT) + ) as Arc + } + }}; + } + + macro_rules! build_array_primitive_tz { + ($SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v, _) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?; + + Arc::new(array) + } + }}; + } + + /// Creates an array of $ARRAY_TY by unpacking values of + /// SCALAR_TY for "string-like" types. + macro_rules! build_array_string { + ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. 
\ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?; + Arc::new(array) + } + }}; + } + + macro_rules! build_array_list { + ($MUTABLE_TY:ty, $SCALAR_TY:ident) => {{ + let mut array = MutableListArray::::new(); + for scalar in scalars.into_iter() { + match scalar { + ScalarValue::List(Some(xs), _) => { + let xs = *xs; + let mut vec = vec![]; + for s in xs { + match s { + ScalarValue::$SCALAR_TY(o) => { vec.push(o) } + sv => return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected Utf8, got {:?}", + sv + ))), + } + } + array.try_push(Some(vec))?; + } + ScalarValue::List(None, _) => { + array.push_null(); + } + sv => { + return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected List, got {:?}", + sv + ))) + } + } + } + + let array: ListArray = array.into(); + Arc::new(array) + }} + } + + use DataType::*; + let array: Arc = match &data_type { + DataType::Decimal(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; + Arc::new(decimal_array) + } + DataType::Boolean => Arc::new( + scalars + .map(|sv| { + if let ScalarValue::Boolean(v) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?, + ), + Float32 => { + build_array_primitive!(f32, Float32, Float32) + } + Float64 => { + build_array_primitive!(f64, Float64, Float64) + } + Int8 => build_array_primitive!(i8, Int8, Int8), + Int16 => build_array_primitive!(i16, Int16, Int16), + Int32 => build_array_primitive!(i32, Int32, Int32), + Int64 => build_array_primitive!(i64, Int64, Int64), + UInt8 => build_array_primitive!(u8, UInt8, UInt8), + UInt16 => build_array_primitive!(u16, UInt16, UInt16), + UInt32 => build_array_primitive!(u32, UInt32, UInt32), + UInt64 => build_array_primitive!(u64, UInt64, UInt64), + Utf8 => build_array_string!(StringArray, Utf8), + LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8), + Binary => build_array_string!(SmallBinaryArray, Binary), + LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), + Date32 => build_array_primitive!(i32, Date32, Date32), + Date64 => build_array_primitive!(i64, Date64, Date64), + Timestamp(TimeUnit::Second, _) => { + build_array_primitive_tz!(TimestampSecond) + } + Timestamp(TimeUnit::Millisecond, _) => { + build_array_primitive_tz!(TimestampMillisecond) + } + Timestamp(TimeUnit::Microsecond, _) => { + build_array_primitive_tz!(TimestampMicrosecond) + } + Timestamp(TimeUnit::Nanosecond, _) => { + build_array_primitive_tz!(TimestampNanosecond) + } + Interval(IntervalUnit::DayTime) => { + build_array_primitive!(days_ms, IntervalDayTime, data_type) + } + Interval(IntervalUnit::YearMonth) => { + build_array_primitive!(i32, IntervalYearMonth, data_type) + } + DataType::List(fields) if fields.data_type() == &DataType::Int8 => { + build_array_list!(Int8Vec, Int8) + } + DataType::List(fields) if fields.data_type() == &DataType::Int16 => { + build_array_list!(Int16Vec, Int16) + } + DataType::List(fields) if fields.data_type() == &DataType::Int32 => { + build_array_list!(Int32Vec, Int32) + } + DataType::List(fields) if fields.data_type() == &DataType::Int64 => { + build_array_list!(Int64Vec, Int64) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { + build_array_list!(UInt8Vec, UInt8) + } + DataType::List(fields) if 
fields.data_type() == &DataType::UInt16 => { + build_array_list!(UInt16Vec, UInt16) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { + build_array_list!(UInt32Vec, UInt32) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { + build_array_list!(UInt64Vec, UInt64) + } + DataType::List(fields) if fields.data_type() == &DataType::Float32 => { + build_array_list!(Float32Vec, Float32) + } + DataType::List(fields) if fields.data_type() == &DataType::Float64 => { + build_array_list!(Float64Vec, Float64) + } + DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { + build_array_list!(MutableStringArray, Utf8) + } + DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { + build_array_list!(MutableLargeStringArray, LargeUtf8) + } + DataType::List(_) => { + // Fallback case handling homogeneous lists with any ScalarValue element type + let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?; + Arc::new(list_array) + } + DataType::Struct(fields) => { + // Initialize a Vector to store the ScalarValues for each column + let mut columns: Vec> = + (0..fields.len()).map(|_| Vec::new()).collect(); + + // Iterate over scalars to populate the column scalars for each row + for scalar in scalars { + if let ScalarValue::Struct(values, fields) = scalar { + match values { + Some(values) => { + // Push value for each field + for c in 0..columns.len() { + let column = columns.get_mut(c).unwrap(); + column.push(values[c].clone()); + } + } + None => { + // Push NULL of the appropriate type for each field + for c in 0..columns.len() { + let dtype = fields[c].data_type(); + let column = columns.get_mut(c).unwrap(); + column.push(ScalarValue::try_from(dtype)?); + } + } + }; + } else { + return Err(DataFusionError::Internal(format!( + "Expected Struct but found: {}", + scalar + ))); + }; + } + + // Call iter_to_array recursively to convert the scalars for each column into Arrow arrays + let field_values = columns + .iter() + .map(|c| Self::iter_to_array(c.clone()).map(Arc::from)) + .collect::>>()?; + + Arc::new(StructArray::from_data(data_type, field_values, None)) + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported creation of {:?} array from ScalarValue {:?}", + data_type, + scalars.peek() + ))); + } + }; + + Ok(array) + } + + fn iter_to_decimal_array( + scalars: impl IntoIterator, + precision: &usize, + scale: &usize, + ) -> Result { + // collect the value as Option + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal128(v1, _, _) => v1, + _ => unreachable!(), + }) + .collect::>>(); + + // build the decimal array using the Decimal Builder + Ok(Int128Vec::from(array) + .to(Decimal(*precision, *scale)) + .into()) + } + + fn iter_to_array_list( + scalars: impl IntoIterator, + data_type: &DataType, + ) -> Result> { + let mut offsets: Vec = vec![0]; + + let mut elements: Vec = Vec::new(); + let mut valid: Vec = vec![]; + + let mut flat_len = 0i32; + for scalar in scalars { + if let ScalarValue::List(values, _) = scalar { + match values { + Some(values) => { + let element_array = ScalarValue::iter_to_array(*values)?; + + // Add new offset index + flat_len += element_array.len() as i32; + offsets.push(flat_len); + + elements.push(element_array); + + // Element is valid + valid.push(true); + } + None => { + // Repeat previous offset index + offsets.push(flat_len); + + // Element is null + valid.push(false); + } + } + } else { + return 
Err(DataFusionError::Internal(format!( + "Expected ScalarValue::List element. Received {:?}", + scalar + ))); + } + } + + // Concatenate element arrays to create single flat array + let element_arrays: Vec<&dyn Array> = + elements.iter().map(|a| a.as_ref()).collect(); + let flat_array = match concatenate::concatenate(&element_arrays) { + Ok(flat_array) => flat_array, + Err(err) => return Err(DataFusionError::ArrowError(err)), + }; + + let list_array = ListArray::::from_data( + data_type.clone(), + Buffer::from(offsets), + flat_array.into(), + Some(Bitmap::from(valid)), + ); + + Ok(list_array) + } + + /// Converts a scalar value into an array of `size` rows. + pub fn to_array_of_size(&self, size: usize) -> ArrayRef { + match self { + ScalarValue::Decimal128(e, precision, scale) => { + Int128Vec::from_iter(repeat(e).take(size)) + .to(Decimal(*precision, *scale)) + .into_arc() + } + ScalarValue::Boolean(e) => { + Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef + } + ScalarValue::Float64(e) => match e { + Some(value) => { + dyn_to_array!(self, value, size, f64) + } + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Float32(e) => match e { + Some(value) => dyn_to_array!(self, value, size, f32), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Int8(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i8), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Int16(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i16), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Int32(e) + | ScalarValue::Date32(e) + | ScalarValue::IntervalYearMonth(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i32), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::IntervalMonthDayNano(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i128), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Int64(e) | ScalarValue::Date64(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::UInt8(e) => match e { + Some(value) => dyn_to_array!(self, value, size, u8), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::UInt16(e) => match e { + Some(value) => dyn_to_array!(self, value, size, u16), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::UInt32(e) => match e { + Some(value) => dyn_to_array!(self, value, size, u32), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::UInt64(e) => match e { + Some(value) => dyn_to_array!(self, value, size, u64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::TimestampSecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::TimestampMillisecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + + ScalarValue::TimestampMicrosecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::TimestampNanosecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + 
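// As an illustration of these arms: + // `ScalarValue::Int32(Some(5)).to_array_of_size(3)` yields an Int32Array of `[5, 5, 5]`, + // while `ScalarValue::Int32(None)` expands to a 3-row null array of the same data type. +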
ScalarValue::Utf8(e) => match e { + Some(value) => Arc::new(Utf8Array::::from_trusted_len_values_iter( + repeat(&value).take(size), + )), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::LargeUtf8(e) => match e { + Some(value) => Arc::new(Utf8Array::::from_trusted_len_values_iter( + repeat(&value).take(size), + )), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Binary(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::>(), + ), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::LargeBinary(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::>(), + ), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::List(values, data_type) => match data_type.as_ref() { + DataType::Boolean => { + build_list!(MutableBooleanArray, Boolean, values, size) + } + DataType::Int8 => build_list!(Int8Vec, Int8, values, size), + DataType::Int16 => build_list!(Int16Vec, Int16, values, size), + DataType::Int32 => build_list!(Int32Vec, Int32, values, size), + DataType::Int64 => build_list!(Int64Vec, Int64, values, size), + DataType::UInt8 => build_list!(UInt8Vec, UInt8, values, size), + DataType::UInt16 => build_list!(UInt16Vec, UInt16, values, size), + DataType::UInt32 => build_list!(UInt32Vec, UInt32, values, size), + DataType::UInt64 => build_list!(UInt64Vec, UInt64, values, size), + DataType::Float32 => build_list!(Float32Vec, Float32, values, size), + DataType::Float64 => build_list!(Float64Vec, Float64, values, size), + DataType::Timestamp(unit, tz) => { + build_timestamp_list!(*unit, values, size, tz.clone()) + } + DataType::Utf8 => build_list!(MutableStringArray, Utf8, values, size), + DataType::LargeUtf8 => { + build_list!(MutableLargeStringArray, LargeUtf8, values, size) + } + dt => panic!("Unexpected DataType for list {:?}", dt), + }, + ScalarValue::IntervalDayTime(e) => match e { + Some(value) => { + Arc::new(PrimitiveArray::::from_trusted_len_values_iter( + std::iter::repeat(*value).take(size), + )) + } + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::Struct(values, _) => match values { + Some(values) => { + let field_values = + values.iter().map(|v| v.to_array_of_size(size)).collect(); + Arc::new(StructArray::from_data( + self.get_datatype(), + field_values, + None, + )) + } + None => Arc::new(StructArray::new_null(self.get_datatype(), size)), + }, + } + } + + fn get_decimal_value_from_array( + array: &ArrayRef, + index: usize, + precision: &usize, + scale: &usize, + ) -> ScalarValue { + let array = array.as_any().downcast_ref::().unwrap(); + if array.is_null(index) { + ScalarValue::Decimal128(None, *precision, *scale) + } else { + ScalarValue::Decimal128(Some(array.value(index)), *precision, *scale) + } + } + + /// Converts a value in `array` at `index` into a ScalarValue + pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { + // handle NULL value + if !array.is_valid(index) { + return array.data_type().try_into(); + } + + Ok(match array.data_type() { + DataType::Decimal(precision, scale) => { + ScalarValue::get_decimal_value_from_array(array, index, precision, scale) + } + DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), + DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), + DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), + DataType::UInt64 => typed_cast!(array, index, 
UInt64Array, UInt64),
+            DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32),
+            DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16),
+            DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8),
+            DataType::Int64 => typed_cast!(array, index, Int64Array, Int64),
+            DataType::Int32 => typed_cast!(array, index, Int32Array, Int32),
+            DataType::Int16 => typed_cast!(array, index, Int16Array, Int16),
+            DataType::Int8 => typed_cast!(array, index, Int8Array, Int8),
+            DataType::Binary => typed_cast!(array, index, SmallBinaryArray, Binary),
+            DataType::LargeBinary => {
+                typed_cast!(array, index, LargeBinaryArray, LargeBinary)
+            }
+            DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
+            DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
+            DataType::List(nested_type) => {
+                let list_array = array
+                    .as_any()
+                    .downcast_ref::<ListArray<i32>>()
+                    .ok_or_else(|| {
+                        DataFusionError::Internal(
+                            "Failed to downcast ListArray".to_string(),
+                        )
+                    })?;
+                let value = match list_array.is_null(index) {
+                    true => None,
+                    false => {
+                        let nested_array = ArrayRef::from(list_array.value(index));
+                        let scalar_vec = (0..nested_array.len())
+                            .map(|i| ScalarValue::try_from_array(&nested_array, i))
+                            .collect::<Result<Vec<_>>>()?;
+                        Some(scalar_vec)
+                    }
+                };
+                let value = value.map(Box::new);
+                let data_type = Box::new(nested_type.data_type().clone());
+                ScalarValue::List(value, data_type)
+            }
+            DataType::Date32 => {
+                typed_cast!(array, index, Int32Array, Date32)
+            }
+            DataType::Date64 => {
+                typed_cast!(array, index, Int64Array, Date64)
+            }
+            DataType::Timestamp(TimeUnit::Second, tz_opt) => {
+                typed_cast_tz!(array, index, TimestampSecond, tz_opt)
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
+                typed_cast_tz!(array, index, TimestampMillisecond, tz_opt)
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
+                typed_cast_tz!(array, index, TimestampMicrosecond, tz_opt)
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
+                typed_cast_tz!(array, index, TimestampNanosecond, tz_opt)
+            }
+            DataType::Dictionary(index_type, _, _) => {
+                let (values, values_index) = match index_type {
+                    IntegerType::Int8 => get_dict_value::<i8>(array, index)?,
+                    IntegerType::Int16 => get_dict_value::<i16>(array, index)?,
+                    IntegerType::Int32 => get_dict_value::<i32>(array, index)?,
+                    IntegerType::Int64 => get_dict_value::<i64>(array, index)?,
+                    IntegerType::UInt8 => get_dict_value::<u8>(array, index)?,
+                    IntegerType::UInt16 => get_dict_value::<u16>(array, index)?,
+                    IntegerType::UInt32 => get_dict_value::<u32>(array, index)?,
+                    IntegerType::UInt64 => get_dict_value::<u64>(array, index)?,
+                };
+
+                match values_index {
+                    Some(values_index) => Self::try_from_array(values, values_index)?,
+                    // was null
+                    None => values.data_type().try_into()?,
+                }
+            }
+            DataType::Struct(fields) => {
+                let array = array
+                    .as_any()
+                    .downcast_ref::<StructArray>()
+                    .ok_or_else(|| {
+                        DataFusionError::Internal(
+                            "Failed to downcast ArrayRef to StructArray".to_string(),
+                        )
+                    })?;
+                let mut field_values: Vec<ScalarValue> = Vec::new();
+                for col_index in 0..array.num_columns() {
+                    let col_array = &array.values()[col_index];
+                    let col_scalar = ScalarValue::try_from_array(col_array, index)?;
+                    field_values.push(col_scalar);
+                }
+                Self::Struct(Some(Box::new(field_values)), Box::new(fields.clone()))
+            }
+            other => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Can't create a scalar from array of type \"{:?}\"",
+                    other
+                )));
+            }
+        })
+    }
+
+    fn eq_array_decimal(
+        array: &ArrayRef,
+        index: usize,
+        value: &Option<i128>,
+        precision: usize,
+        scale: usize,
+    ) -> bool {
+        let array = array.as_any().downcast_ref::<Int128Array>().unwrap();
+        match array.data_type() {
+            Decimal(pre, sca) => {
+                if *pre != precision || *sca != scale {
+                    return false;
+                }
+            }
+            _ => return false,
+        }
+        match value {
+            None => array.is_null(index),
+            Some(v) => !array.is_null(index) && array.value(index) == *v,
+        }
+    }
+
+    /// Compares a single row of array @ index for equality with self,
+    /// in an optimized fashion.
+    ///
+    /// This method implements an optimized version of:
+    ///
+    /// ```text
+    ///     let arr_scalar = Self::try_from_array(array, index).unwrap();
+    ///     arr_scalar.eq(self)
+    /// ```
+    ///
+    /// *Performance note*: the arrow compute kernels should be
+    /// preferred over this function if at all possible as they can be
+    /// vectorized and are generally much faster.
+    ///
+    /// This function has a few narrow use cases, such as hash-table key
+    /// comparisons, where comparing a single row at a time is necessary.
+    #[inline]
+    pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool {
+        if let DataType::Dictionary(key_type, _, _) = array.data_type() {
+            return self.eq_array_dictionary(array, index, key_type);
+        }
+
+        match self {
+            ScalarValue::Decimal128(v, precision, scale) => {
+                ScalarValue::eq_array_decimal(array, index, v, *precision, *scale)
+            }
+            ScalarValue::Boolean(val) => {
+                eq_array_primitive!(array, index, BooleanArray, val)
+            }
+            ScalarValue::Float32(val) => {
+                eq_array_primitive!(array, index, Float32Array, val)
+            }
+            ScalarValue::Float64(val) => {
+                eq_array_primitive!(array, index, Float64Array, val)
+            }
+            ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val),
+            ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val),
+            ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val),
+            ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val),
+            ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val),
+            ScalarValue::UInt16(val) => {
+                eq_array_primitive!(array, index, UInt16Array, val)
+            }
+            ScalarValue::UInt32(val) => {
+                eq_array_primitive!(array, index, UInt32Array, val)
+            }
+            ScalarValue::UInt64(val) => {
+                eq_array_primitive!(array, index, UInt64Array, val)
+            }
+            ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val),
+            ScalarValue::LargeUtf8(val) => {
+                eq_array_primitive!(array, index, LargeStringArray, val)
+            }
+            ScalarValue::Binary(val) => {
+                eq_array_primitive!(array, index, SmallBinaryArray, val)
+            }
+            ScalarValue::LargeBinary(val) => {
+                eq_array_primitive!(array, index, LargeBinaryArray, val)
+            }
+            ScalarValue::List(_, _) => unimplemented!(),
+            ScalarValue::Date32(val) => {
+                eq_array_primitive!(array, index, Int32Array, val)
+            }
+            ScalarValue::Date64(val) => {
+                eq_array_primitive!(array, index, Int64Array, val)
+            }
+            ScalarValue::TimestampSecond(val, _) => {
+                eq_array_primitive!(array, index, Int64Array, val)
+            }
+            ScalarValue::TimestampMillisecond(val, _) => {
+                eq_array_primitive!(array, index, Int64Array, val)
+            }
+            ScalarValue::TimestampMicrosecond(val, _) => {
+                eq_array_primitive!(array, index, Int64Array, val)
+            }
+            ScalarValue::TimestampNanosecond(val, _) => {
+                eq_array_primitive!(array, index, Int64Array, val)
+            }
+            ScalarValue::IntervalYearMonth(val) => {
+                eq_array_primitive!(array, index, Int32Array, val)
+            }
+            ScalarValue::IntervalDayTime(val) => {
+                eq_array_primitive!(array, index, DaysMsArray, val)
+            }
+            ScalarValue::IntervalMonthDayNano(val) => {
+                eq_array_primitive!(array, index, Int128Array, val)
+            }
+            ScalarValue::Struct(_, _) => unimplemented!(),
+        }
+    }
+
+    /// Compares a dictionary array with indexes of type `key_type`
+    /// with the array @ index for equality with self
+    fn eq_array_dictionary(
+        &self,
+        array: &ArrayRef,
+        index: usize,
+        key_type: &IntegerType,
+    ) -> bool {
+        let (values, values_index) = match key_type {
+            IntegerType::Int8 => get_dict_value::<i8>(array, index).unwrap(),
+            IntegerType::Int16 => get_dict_value::<i16>(array, index).unwrap(),
+            IntegerType::Int32 => get_dict_value::<i32>(array, index).unwrap(),
+            IntegerType::Int64 => get_dict_value::<i64>(array, index).unwrap(),
+            IntegerType::UInt8 => get_dict_value::<u8>(array, index).unwrap(),
+            IntegerType::UInt16 => get_dict_value::<u16>(array, index).unwrap(),
+            IntegerType::UInt32 => get_dict_value::<u32>(array, index).unwrap(),
+            IntegerType::UInt64 => get_dict_value::<u64>(array, index).unwrap(),
+        };
+
+        match values_index {
+            Some(values_index) => self.eq_array(values, values_index),
+            None => self.is_null(),
+        }
+    }
+}
+
+macro_rules! impl_scalar {
+    ($ty:ty, $scalar:tt) => {
+        impl From<$ty> for ScalarValue {
+            fn from(value: $ty) -> Self {
+                ScalarValue::$scalar(Some(value))
+            }
+        }
+
+        impl From<Option<$ty>> for ScalarValue {
+            fn from(value: Option<$ty>) -> Self {
+                ScalarValue::$scalar(value)
+            }
+        }
+    };
+}
+
+impl_scalar!(f64, Float64);
+impl_scalar!(f32, Float32);
+impl_scalar!(i8, Int8);
+impl_scalar!(i16, Int16);
+impl_scalar!(i32, Int32);
+impl_scalar!(i64, Int64);
+impl_scalar!(bool, Boolean);
+impl_scalar!(u8, UInt8);
+impl_scalar!(u16, UInt16);
+impl_scalar!(u32, UInt32);
+impl_scalar!(u64, UInt64);
+
+impl From<&str> for ScalarValue {
+    fn from(value: &str) -> Self {
+        Some(value).into()
+    }
+}
+
+impl From<Option<&str>> for ScalarValue {
+    fn from(value: Option<&str>) -> Self {
+        let value = value.map(|s| s.to_string());
+        ScalarValue::Utf8(value)
+    }
+}
+
+impl FromStr for ScalarValue {
+    type Err = Infallible;
+
+    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+        Ok(s.into())
+    }
+}
+
+impl From<Vec<(&str, ScalarValue)>> for ScalarValue {
+    fn from(value: Vec<(&str, ScalarValue)>) -> Self {
+        let (fields, scalars): (Vec<_>, Vec<_>) = value
+            .into_iter()
+            .map(|(name, scalar)| {
+                (Field::new(name, scalar.get_datatype(), false), scalar)
+            })
+            .unzip();
+
+        Self::Struct(Some(Box::new(scalars)), Box::new(fields))
+    }
+}
+
+macro_rules! impl_try_from {
+    ($SCALAR:ident, $NATIVE:ident) => {
+        impl TryFrom<ScalarValue> for $NATIVE {
+            type Error = DataFusionError;
+
+            fn try_from(value: ScalarValue) -> Result<Self> {
+                match value {
+                    ScalarValue::$SCALAR(Some(inner_value)) => Ok(inner_value),
+                    _ => Err(DataFusionError::Internal(format!(
+                        "Cannot convert {:?} to {}",
+                        value,
+                        std::any::type_name::<Self>()
+                    ))),
+                }
+            }
+        }
+    };
+}
+
+impl_try_from!(Int8, i8);
+impl_try_from!(Int16, i16);
+
+// special implementation for i32 because of Date32
+impl TryFrom<ScalarValue> for i32 {
+    type Error = DataFusionError;
+
+    fn try_from(value: ScalarValue) -> Result<Self> {
+        match value {
+            ScalarValue::Int32(Some(inner_value))
+            | ScalarValue::Date32(Some(inner_value)) => Ok(inner_value),
+            _ => Err(DataFusionError::Internal(format!(
+                "Cannot convert {:?} to {}",
+                value,
+                std::any::type_name::<Self>()
+            ))),
+        }
+    }
+}
+
+// special implementation for i64 because of TimestampNanosecond
+impl TryFrom<ScalarValue> for i64 {
+    type Error = DataFusionError;
+
+    fn try_from(value: ScalarValue) -> Result<Self> {
+        match value {
+            ScalarValue::Int64(Some(inner_value))
+            | ScalarValue::Date64(Some(inner_value))
+            | ScalarValue::TimestampNanosecond(Some(inner_value), _)
+            | ScalarValue::TimestampMicrosecond(Some(inner_value), _)
+            | ScalarValue::TimestampMillisecond(Some(inner_value), _)
+            | ScalarValue::TimestampSecond(Some(inner_value), _) => Ok(inner_value),
+            _ => Err(DataFusionError::Internal(format!(
+                "Cannot convert {:?} to {}",
+                value,
+                std::any::type_name::<Self>()
+            ))),
+        }
+    }
+}
+
+// special implementation for i128 because of Decimal128
+impl TryFrom<ScalarValue> for i128 {
+    type Error = DataFusionError;
+
+    fn try_from(value: ScalarValue) -> Result<Self> {
+        match value {
+            ScalarValue::Decimal128(Some(inner_value), _, _) => Ok(inner_value),
+            _ => Err(DataFusionError::Internal(format!(
+                "Cannot convert {:?} to {}",
+                value,
+                std::any::type_name::<Self>()
+            ))),
+        }
+    }
+}
+
+impl_try_from!(UInt8, u8);
+impl_try_from!(UInt16, u16);
+impl_try_from!(UInt32, u32);
+impl_try_from!(UInt64, u64);
+impl_try_from!(Float32, f32);
+impl_try_from!(Float64, f64);
+impl_try_from!(Boolean, bool);
+
+impl TryInto<Box<dyn Scalar>> for &ScalarValue {
+    type Error = DataFusionError;
+
+    fn try_into(self) -> Result<Box<dyn Scalar>> {
+        use arrow::scalar::*;
+        match self {
+            ScalarValue::Boolean(b) => Ok(Box::new(BooleanScalar::new(*b))),
+            ScalarValue::Float32(f) => {
+                Ok(Box::new(PrimitiveScalar::<f32>::new(DataType::Float32, *f)))
+            }
+            ScalarValue::Float64(f) => {
+                Ok(Box::new(PrimitiveScalar::<f64>::new(DataType::Float64, *f)))
+            }
+            ScalarValue::Int8(i) => {
+                Ok(Box::new(PrimitiveScalar::<i8>::new(DataType::Int8, *i)))
+            }
+            ScalarValue::Int16(i) => {
+                Ok(Box::new(PrimitiveScalar::<i16>::new(DataType::Int16, *i)))
+            }
+            ScalarValue::Int32(i) => {
+                Ok(Box::new(PrimitiveScalar::<i32>::new(DataType::Int32, *i)))
+            }
+            ScalarValue::Int64(i) => {
+                Ok(Box::new(PrimitiveScalar::<i64>::new(DataType::Int64, *i)))
+            }
+            ScalarValue::UInt8(u) => {
+                Ok(Box::new(PrimitiveScalar::<u8>::new(DataType::UInt8, *u)))
+            }
+            ScalarValue::UInt16(u) => {
+                Ok(Box::new(PrimitiveScalar::<u16>::new(DataType::UInt16, *u)))
+            }
+            ScalarValue::UInt32(u) => {
+                Ok(Box::new(PrimitiveScalar::<u32>::new(DataType::UInt32, *u)))
+            }
+            ScalarValue::UInt64(u) => {
+                Ok(Box::new(PrimitiveScalar::<u64>::new(DataType::UInt64, *u)))
+            }
+            ScalarValue::Utf8(s) => Ok(Box::new(Utf8Scalar::<i32>::new(s.clone()))),
+            ScalarValue::LargeUtf8(s) => Ok(Box::new(Utf8Scalar::<i64>::new(s.clone()))),
+            ScalarValue::Binary(b) => Ok(Box::new(BinaryScalar::<i32>::new(b.clone()))),
+            ScalarValue::LargeBinary(b) => {
+                Ok(Box::new(BinaryScalar::<i64>::new(b.clone())))
+            }
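+            // For example, `(&ScalarValue::Int32(Some(1))).try_into()` yields a
+            // `Box<dyn Scalar>` wrapping an arrow2 `PrimitiveScalar::<i32>`, which
+            // arrow2 kernels that compare an array against a scalar can consume.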
+            ScalarValue::Date32(i) => {
+                Ok(Box::new(PrimitiveScalar::<i32>::new(DataType::Date32, *i)))
+            }
+            ScalarValue::Date64(i) => {
+                Ok(Box::new(PrimitiveScalar::<i64>::new(DataType::Date64, *i)))
+            }
+            ScalarValue::TimestampSecond(i, tz) => {
+                Ok(Box::new(PrimitiveScalar::<i64>::new(
+                    DataType::Timestamp(TimeUnit::Second, tz.clone()),
+                    *i,
+                )))
+            }
+            ScalarValue::TimestampMillisecond(i, tz) => {
+                Ok(Box::new(PrimitiveScalar::<i64>::new(
+                    DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
+                    *i,
+                )))
+            }
+            ScalarValue::TimestampMicrosecond(i, tz) => {
+                Ok(Box::new(PrimitiveScalar::<i64>::new(
+                    DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
+                    *i,
+                )))
+            }
+            ScalarValue::TimestampNanosecond(i, tz) => {
+                Ok(Box::new(PrimitiveScalar::<i64>::new(
+                    DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
+                    *i,
+                )))
+            }
+            ScalarValue::IntervalYearMonth(i) => {
+                Ok(Box::new(PrimitiveScalar::<i32>::new(
+                    DataType::Interval(IntervalUnit::YearMonth),
+                    *i,
+                )))
+            }
+
+            // List and IntervalDayTime comparison not possible in arrow2
+            _ => Err(DataFusionError::Internal(
+                "Conversion not possible in arrow2".to_owned(),
+            )),
+        }
+    }
+}
+
+impl TryFrom<PrimitiveScalar<i64>> for ScalarValue {
+    type Error = DataFusionError;
+
+    fn try_from(s: PrimitiveScalar<i64>) -> Result<Self> {
+        match s.data_type() {
+            DataType::Timestamp(TimeUnit::Second, tz) => {
+                let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
+                Ok(ScalarValue::TimestampSecond(s.value(), tz.clone()))
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, tz) => {
+                let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
+                Ok(ScalarValue::TimestampMicrosecond(s.value(), tz.clone()))
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, tz) => {
+                let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
+                Ok(ScalarValue::TimestampMillisecond(s.value(), tz.clone()))
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
+                let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
+                Ok(ScalarValue::TimestampNanosecond(s.value(), tz.clone()))
+            }
+            _ => Err(DataFusionError::Internal(format!(
+                "Conversion from arrow Scalar to Datafusion ScalarValue not implemented for: {:?}",
+                s
+            ))),
+        }
+    }
+}
+
+impl TryFrom<&DataType> for ScalarValue {
+    type Error = DataFusionError;
+
+    /// Create a Null instance of ScalarValue for this datatype
+    fn try_from(datatype: &DataType) -> Result<Self> {
+        Ok(match datatype {
+            DataType::Boolean => ScalarValue::Boolean(None),
+            DataType::Float64 => ScalarValue::Float64(None),
+            DataType::Float32 => ScalarValue::Float32(None),
+            DataType::Int8 => ScalarValue::Int8(None),
+            DataType::Int16 => ScalarValue::Int16(None),
+            DataType::Int32 => ScalarValue::Int32(None),
+            DataType::Int64 => ScalarValue::Int64(None),
+            DataType::UInt8 => ScalarValue::UInt8(None),
+            DataType::UInt16 => ScalarValue::UInt16(None),
+            DataType::UInt32 => ScalarValue::UInt32(None),
+            DataType::UInt64 => ScalarValue::UInt64(None),
+            DataType::Decimal(precision, scale) => {
+                ScalarValue::Decimal128(None, *precision, *scale)
+            }
+            DataType::Utf8 => ScalarValue::Utf8(None),
+            DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
+            DataType::Date32 => ScalarValue::Date32(None),
+            DataType::Date64 => ScalarValue::Date64(None),
+            DataType::Timestamp(TimeUnit::Second, tz_opt) => {
+                ScalarValue::TimestampSecond(None, tz_opt.clone())
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
+                ScalarValue::TimestampMillisecond(None, tz_opt.clone())
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
+                ScalarValue::TimestampMicrosecond(None, tz_opt.clone())
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
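+                // the null scalar keeps the unit and timezone of the target type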
ScalarValue::TimestampNanosecond(None, tz_opt.clone()) + } + DataType::Dictionary(_index_type, value_type, _) => { + value_type.as_ref().try_into()? + } + DataType::List(ref nested_type) => { + ScalarValue::List(None, Box::new(nested_type.data_type().clone())) + } + DataType::Struct(fields) => { + ScalarValue::Struct(None, Box::new(fields.clone())) + } + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Can't create a scalar from data_type \"{:?}\"", + datatype + ))); + } + }) + } +} + +macro_rules! format_option { + ($F:expr, $EXPR:expr) => {{ + match $EXPR { + Some(e) => write!($F, "{}", e), + None => write!($F, "NULL"), + } + }}; +} + +impl fmt::Display for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ScalarValue::Decimal128(v, p, s) => { + write!(f, "{}", format_args!("{:?},{:?},{:?}", v, p, s))?; + } + ScalarValue::Boolean(e) => format_option!(f, e)?, + ScalarValue::Float32(e) => format_option!(f, e)?, + ScalarValue::Float64(e) => format_option!(f, e)?, + ScalarValue::Int8(e) => format_option!(f, e)?, + ScalarValue::Int16(e) => format_option!(f, e)?, + ScalarValue::Int32(e) => format_option!(f, e)?, + ScalarValue::Int64(e) => format_option!(f, e)?, + ScalarValue::UInt8(e) => format_option!(f, e)?, + ScalarValue::UInt16(e) => format_option!(f, e)?, + ScalarValue::UInt32(e) => format_option!(f, e)?, + ScalarValue::UInt64(e) => format_option!(f, e)?, + ScalarValue::TimestampSecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, + ScalarValue::Utf8(e) => format_option!(f, e)?, + ScalarValue::LargeUtf8(e) => format_option!(f, e)?, + ScalarValue::Binary(e) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::LargeBinary(e) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::List(e, _) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::Date32(e) => format_option!(f, e)?, + ScalarValue::Date64(e) => format_option!(f, e)?, + ScalarValue::IntervalDayTime(e) => format_option!(f, e)?, + ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?, + ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?, + ScalarValue::Struct(e, fields) => match e { + Some(l) => write!( + f, + "{{{}}}", + l.iter() + .zip(fields.iter()) + .map(|(value, field)| format!("{}:{}", field.name(), value)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + }; + Ok(()) + } +} + +impl fmt::Debug for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self), + ScalarValue::Boolean(_) => write!(f, "Boolean({})", self), + ScalarValue::Float32(_) => write!(f, "Float32({})", self), + ScalarValue::Float64(_) => write!(f, "Float64({})", self), + ScalarValue::Int8(_) => write!(f, "Int8({})", self), + ScalarValue::Int16(_) => write!(f, "Int16({})", self), + ScalarValue::Int32(_) => write!(f, "Int32({})", self), + ScalarValue::Int64(_) => write!(f, "Int64({})", self), + ScalarValue::UInt8(_) => write!(f, "UInt8({})", 
self), + ScalarValue::UInt16(_) => write!(f, "UInt16({})", self), + ScalarValue::UInt32(_) => write!(f, "UInt32({})", self), + ScalarValue::UInt64(_) => write!(f, "UInt64({})", self), + ScalarValue::TimestampSecond(_, tz_opt) => { + write!(f, "TimestampSecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampMillisecond(_, tz_opt) => { + write!(f, "TimestampMillisecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampMicrosecond(_, tz_opt) => { + write!(f, "TimestampMicrosecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampNanosecond(_, tz_opt) => { + write!(f, "TimestampNanosecond({}, {:?})", self, tz_opt) + } + ScalarValue::Utf8(None) => write!(f, "Utf8({})", self), + ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self), + ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({})", self), + ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{}\")", self), + ScalarValue::Binary(None) => write!(f, "Binary({})", self), + ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{}\")", self), + ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({})", self), + ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{}\")", self), + ScalarValue::List(_, dt) => write!(f, "List[{:?}]([{}])", dt, self), + ScalarValue::Date32(_) => write!(f, "Date32(\"{}\")", self), + ScalarValue::Date64(_) => write!(f, "Date64(\"{}\")", self), + ScalarValue::IntervalDayTime(_) => { + write!(f, "IntervalDayTime(\"{}\")", self) + } + ScalarValue::IntervalYearMonth(_) => { + write!(f, "IntervalYearMonth(\"{}\")", self) + } + ScalarValue::IntervalMonthDayNano(_) => { + write!(f, "IntervalMonthDayNano(\"{}\")", self) + } + ScalarValue::Struct(e, fields) => { + // Use Debug representation of field values + match e { + Some(l) => write!( + f, + "Struct({{{}}})", + l.iter() + .zip(fields.iter()) + .map(|(value, field)| format!("{}:{:?}", field.name(), value)) + .collect::>() + .join(",") + ), + None => write!(f, "Struct(NULL)"), + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field_util::struct_array_from; + + #[test] + fn scalar_decimal_test() { + let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1); + assert_eq!(DataType::Decimal(10, 1), decimal_value.get_datatype()); + let try_into_value: i128 = decimal_value.clone().try_into().unwrap(); + assert_eq!(123_i128, try_into_value); + assert!(!decimal_value.is_null()); + let neg_decimal_value = decimal_value.arithmetic_negate(); + match neg_decimal_value { + ScalarValue::Decimal128(v, _, _) => { + assert_eq!(-123, v.unwrap()); + } + _ => { + unreachable!(); + } + } + + // decimal scalar to array + let array = decimal_value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(1, array.len()); + assert_eq!(DataType::Decimal(10, 1), array.data_type().clone()); + assert_eq!(123i128, array.value(0)); + + // decimal scalar to array with size + let array = decimal_value.to_array_of_size(10); + let array_decimal = array.as_any().downcast_ref::().unwrap(); + assert_eq!(10, array.len()); + assert_eq!(DataType::Decimal(10, 1), array.data_type().clone()); + assert_eq!(123i128, array_decimal.value(0)); + assert_eq!(123i128, array_decimal.value(9)); + // test eq array + assert!(decimal_value.eq_array(&array, 1)); + assert!(decimal_value.eq_array(&array, 5)); + // test try from array + assert_eq!( + decimal_value, + ScalarValue::try_from_array(&array, 5).unwrap() + ); + + assert_eq!( + decimal_value, + ScalarValue::try_new_decimal128(123, 10, 1).unwrap() + ); + + // test compare + let 
left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + assert!(!left.eq(&right)); + let result = left < right; + assert!(result); + let result = left <= right; + assert!(result); + let right = ScalarValue::Decimal128(Some(124), 10, 3); + // make sure that two decimals with diff datatype can't be compared. + let result = left.partial_cmp(&right); + assert_eq!(None, result); + + let decimal_vec = vec![ + ScalarValue::Decimal128(Some(1), 10, 2), + ScalarValue::Decimal128(Some(2), 10, 2), + ScalarValue::Decimal128(Some(3), 10, 2), + ]; + // convert the vec to decimal array and check the result + let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap(); + assert_eq!(3, array.len()); + assert_eq!(DataType::Decimal(10, 2), array.data_type().clone()); + + let decimal_vec = vec![ + ScalarValue::Decimal128(Some(1), 10, 2), + ScalarValue::Decimal128(Some(2), 10, 2), + ScalarValue::Decimal128(Some(3), 10, 2), + ScalarValue::Decimal128(None, 10, 2), + ]; + let array: ArrayRef = + ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap(); + assert_eq!(4, array.len()); + assert_eq!(DataType::Decimal(10, 2), array.data_type().clone()); + + assert!(ScalarValue::try_new_decimal128(1, 10, 2) + .unwrap() + .eq_array(&array, 0)); + assert!(ScalarValue::try_new_decimal128(2, 10, 2) + .unwrap() + .eq_array(&array, 1)); + assert!(ScalarValue::try_new_decimal128(3, 10, 2) + .unwrap() + .eq_array(&array, 2)); + assert_eq!( + ScalarValue::Decimal128(None, 10, 2), + ScalarValue::try_from_array(&array, 3).unwrap() + ); + assert_eq!( + ScalarValue::Decimal128(None, 10, 2), + ScalarValue::try_from_array(&array, 4).unwrap() + ); + } + + #[test] + fn scalar_value_to_array_u64() { + let value = ScalarValue::UInt64(Some(13u64)); + let array = value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(!array.is_null(0)); + assert_eq!(array.value(0), 13); + + let value = ScalarValue::UInt64(None); + let array = value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + } + + #[test] + fn scalar_value_to_array_u32() { + let value = ScalarValue::UInt32(Some(13u32)); + let array = value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(!array.is_null(0)); + assert_eq!(array.value(0), 13); + + let value = ScalarValue::UInt32(None); + let array = value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + } + + #[test] + fn scalar_list_null_to_array() { + let list_array_ref = + ScalarValue::List(None, Box::new(DataType::UInt64)).to_array(); + let list_array = list_array_ref + .as_any() + .downcast_ref::>() + .unwrap(); + + assert!(list_array.is_null(0)); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 0); + } + + #[test] + fn scalar_list_to_array() { + let list_array_ref = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ])), + Box::new(DataType::UInt64), + ) + .to_array(); + + let list_array = list_array_ref + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 3); + + let prim_array_ref = list_array.value(0); + let prim_array = prim_array_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(prim_array.len(), 
3); + assert_eq!(prim_array.value(0), 100); + assert!(prim_array.is_null(1)); + assert_eq!(prim_array.value(2), 101); + } + + /// Creates array directly and via ScalarValue and ensures they are the same + macro_rules! check_scalar_iter { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = + $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected = $ARRAYTYPE::from($INPUT).as_arc(); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they are the same + /// but for variants that carry a timezone field. + macro_rules! check_scalar_iter_tz { + ($SCALAR_T:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(*v, None)) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: Arc = Arc::new(Int64Array::from($INPUT)); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they + /// are the same, for string arrays + macro_rules! check_scalar_iter_string { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string()))) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: Arc = Arc::new($ARRAYTYPE::from($INPUT)); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they + /// are the same, for binary arrays + macro_rules! check_scalar_iter_binary { + ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec()))) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: $ARRAYTYPE = + $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect(); + + let expected: Arc = Arc::new(expected); + + assert_eq!(&array, &expected); + }}; + } + + #[test] + fn scalar_iter_to_array_boolean() { + check_scalar_iter!( + Boolean, + MutableBooleanArray, + vec![Some(true), None, Some(false)] + ); + check_scalar_iter!(Float32, Float32Vec, vec![Some(1.9), None, Some(-2.1)]); + check_scalar_iter!(Float64, Float64Vec, vec![Some(1.9), None, Some(-2.1)]); + + check_scalar_iter!(Int8, Int8Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int16, Int16Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int32, Int32Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int64, Int64Vec, vec![Some(1), None, Some(3)]); + + check_scalar_iter!(UInt8, UInt8Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt16, UInt16Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt32, UInt32Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt64, UInt64Vec, vec![Some(1), None, Some(3)]); + + check_scalar_iter_tz!(TimestampSecond, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampMillisecond, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampMicrosecond, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampNanosecond, vec![Some(1), None, Some(3)]); + + check_scalar_iter_string!( + Utf8, + StringArray, + vec![Some("foo"), None, Some("bar")] + ); + check_scalar_iter_string!( + LargeUtf8, + LargeStringArray, + vec![Some("foo"), None, Some("bar")] + ); + check_scalar_iter_binary!( + Binary, + SmallBinaryArray, + 
vec![Some(b"foo"), None, Some(b"bar")] + ); + check_scalar_iter_binary!( + LargeBinary, + LargeBinaryArray, + vec![Some(b"foo"), None, Some(b"bar")] + ); + } + + #[test] + fn scalar_iter_to_array_empty() { + let scalars = vec![] as Vec; + + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); + assert!( + result + .to_string() + .contains("Empty iterator passed to ScalarValue::iter_to_array"), + "{}", + result + ); + } + + #[test] + fn scalar_iter_to_array_mismatched_types() { + use ScalarValue::*; + // If the scalar values are not all the correct type, error here + let scalars: Vec = vec![Boolean(Some(true)), Int32(Some(5))]; + + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); + assert!(result.to_string().contains("Inconsistent types in ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"), + "{}", result); + } + + #[test] + fn scalar_try_from_array_null() { + let array = vec![Some(33), None].into_iter().collect::(); + let array: ArrayRef = Arc::new(array); + + assert_eq!( + ScalarValue::Int64(Some(33)), + ScalarValue::try_from_array(&array, 0).unwrap() + ); + assert_eq!( + ScalarValue::Int64(None), + ScalarValue::try_from_array(&array, 1).unwrap() + ); + } + + #[test] + fn scalar_try_from_dict_datatype() { + let data_type = + DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8), false); + let data_type = &data_type; + assert_eq!(ScalarValue::Utf8(None), data_type.try_into().unwrap()) + } + + #[test] + fn size_of_scalar() { + // Since ScalarValues are used in a non trivial number of places, + // making it larger means significant more memory consumption + // per distinct value. + #[cfg(target_arch = "aarch64")] + assert_eq!(std::mem::size_of::(), 64); + + #[cfg(target_arch = "amd64")] + assert_eq!(std::mem::size_of::(), 48); + } + + #[test] + fn scalar_eq_array() { + // Validate that eq_array has the same semantics as ScalarValue::eq + macro_rules! make_typed_vec { + ($INPUT:expr, $TYPE:ident) => {{ + $INPUT + .iter() + .map(|v| v.map(|v| v as $TYPE)) + .collect::>() + }}; + } + + let bool_vals = vec![Some(true), None, Some(false)]; + let f32_vals = vec![Some(-1.0), None, Some(1.0)]; + let f64_vals = make_typed_vec!(f32_vals, f64); + + let i8_vals = vec![Some(-1), None, Some(1)]; + let i16_vals = make_typed_vec!(i8_vals, i16); + let i32_vals = make_typed_vec!(i8_vals, i32); + let i64_vals = make_typed_vec!(i8_vals, i64); + let days_ms_vals = &[Some(days_ms::new(1, 2)), None, Some(days_ms::new(10, 0))]; + + let u8_vals = vec![Some(0), None, Some(1)]; + let u16_vals = make_typed_vec!(u8_vals, u16); + let u32_vals = make_typed_vec!(u8_vals, u32); + let u64_vals = make_typed_vec!(u8_vals, u64); + + let str_vals = &[Some("foo"), None, Some("bar")]; + + /// Test each value in `scalar` with the corresponding element + /// at `array`. Assumes each element is unique (aka not equal + /// with all other indexes) + struct TestCase { + array: ArrayRef, + scalars: Vec, + } + + /// Create a test case for casing the input to the specified array type + macro_rules! 
make_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()), + scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), + } + }}; + + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident, $TZ:expr) => {{ + let tz = $TZ; + TestCase { + array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(*v, tz.clone())) + .collect(), + } + }}; + } + + macro_rules! make_date_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($ARRAY_TY::from($INPUT).to(DataType::$SCALAR_TY)), + scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), + } + }}; + } + + macro_rules! make_ts_test_case { + ($INPUT:expr, $ARROW_TU:ident, $SCALAR_TY:ident, $TZ:expr) => {{ + TestCase { + array: Arc::new( + Int64Array::from($INPUT) + .to(DataType::Timestamp(TimeUnit::$ARROW_TU, $TZ)), + ), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(*v, $TZ)) + .collect(), + } + }}; + } + + macro_rules! make_temporal_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $ARROW_TU:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new( + $ARRAY_TY::from($INPUT) + .to(DataType::Interval(IntervalUnit::$ARROW_TU)), + ), + scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), + } + }}; + } + + macro_rules! make_str_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string()))) + .collect(), + } + }}; + } + + macro_rules! make_binary_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| { + ScalarValue::$SCALAR_TY(v.map(|v| v.as_bytes().to_vec())) + }) + .collect(), + } + }}; + } + + /// create a test case for DictionaryArray<$INDEX_TY> + macro_rules! 
make_str_dict_test_case { + ($INPUT:expr, $INDEX_TY:ty, $SCALAR_TY:ident) => {{ + TestCase { + array: { + let mut array = MutableDictionaryArray::< + $INDEX_TY, + MutableUtf8Array, + >::new(); + array.try_extend(*($INPUT)).unwrap(); + let array: DictionaryArray<$INDEX_TY> = array.into(); + Arc::new(array) + }, + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string()))) + .collect(), + } + }}; + } + let utc_tz = Some("UTC".to_owned()); + let cases = vec![ + make_test_case!(bool_vals, BooleanArray, Boolean), + make_test_case!(f32_vals, Float32Array, Float32), + make_test_case!(f64_vals, Float64Array, Float64), + make_test_case!(i8_vals, Int8Array, Int8), + make_test_case!(i16_vals, Int16Array, Int16), + make_test_case!(i32_vals, Int32Array, Int32), + make_test_case!(i64_vals, Int64Array, Int64), + make_test_case!(u8_vals, UInt8Array, UInt8), + make_test_case!(u16_vals, UInt16Array, UInt16), + make_test_case!(u32_vals, UInt32Array, UInt32), + make_test_case!(u64_vals, UInt64Array, UInt64), + make_str_test_case!(str_vals, StringArray, Utf8), + make_str_test_case!(str_vals, LargeStringArray, LargeUtf8), + make_binary_test_case!(str_vals, SmallBinaryArray, Binary), + make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary), + make_date_test_case!(&i32_vals, Int32Array, Date32), + make_date_test_case!(&i64_vals, Int64Array, Date64), + make_ts_test_case!(&i64_vals, Second, TimestampSecond, utc_tz.clone()), + make_ts_test_case!( + &i64_vals, + Millisecond, + TimestampMillisecond, + utc_tz.clone() + ), + make_ts_test_case!( + &i64_vals, + Microsecond, + TimestampMicrosecond, + utc_tz.clone() + ), + make_ts_test_case!( + &i64_vals, + Nanosecond, + TimestampNanosecond, + utc_tz.clone() + ), + make_ts_test_case!(&i64_vals, Second, TimestampSecond, None), + make_ts_test_case!(&i64_vals, Millisecond, TimestampMillisecond, None), + make_ts_test_case!(&i64_vals, Microsecond, TimestampMicrosecond, None), + make_ts_test_case!(&i64_vals, Nanosecond, TimestampNanosecond, None), + make_temporal_test_case!(&i32_vals, Int32Array, YearMonth, IntervalYearMonth), + make_temporal_test_case!(days_ms_vals, DaysMsArray, DayTime, IntervalDayTime), + make_str_dict_test_case!(str_vals, i8, Utf8), + make_str_dict_test_case!(str_vals, i16, Utf8), + make_str_dict_test_case!(str_vals, i32, Utf8), + make_str_dict_test_case!(str_vals, i64, Utf8), + make_str_dict_test_case!(str_vals, u8, Utf8), + make_str_dict_test_case!(str_vals, u16, Utf8), + make_str_dict_test_case!(str_vals, u32, Utf8), + make_str_dict_test_case!(str_vals, u64, Utf8), + ]; + + for case in cases { + let TestCase { array, scalars } = case; + assert_eq!(array.len(), scalars.len()); + + for (index, scalar) in scalars.into_iter().enumerate() { + assert!( + scalar.eq_array(&array, index), + "Expected {:?} to be equal to {:?} at index {}", + scalar, + array, + index + ); + + // test that all other elements are *not* equal + for other_index in 0..array.len() { + if index != other_index { + assert!( + !scalar.eq_array(&array, other_index), + "Expected {:?} to be NOT equal to {:?} at index {}", + scalar, + array, + other_index + ); + } + } + } + } + } + + #[test] + fn scalar_partial_ordering() { + use ScalarValue::*; + + assert_eq!( + Int64(Some(33)).partial_cmp(&Int64(Some(0))), + Some(Ordering::Greater) + ); + assert_eq!( + Int64(Some(0)).partial_cmp(&Int64(Some(33))), + Some(Ordering::Less) + ); + assert_eq!( + Int64(Some(33)).partial_cmp(&Int64(Some(33))), + Some(Ordering::Equal) + ); + // For different data type, `partial_cmp` 
returns None. + assert_eq!(Int64(Some(33)).partial_cmp(&Int32(Some(33))), None); + assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None); + + assert_eq!( + List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + )), + Some(Ordering::Equal) + ); + + assert_eq!( + List( + Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + )), + Some(Ordering::Greater) + ); + + assert_eq!( + List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), + Box::new(DataType::Int32), + )), + Some(Ordering::Less) + ); + + // For different data type, `partial_cmp` returns None. + assert_eq!( + List( + Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])), + Box::new(DataType::Int64), + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32), + )), + None + ); + + assert_eq!( + ScalarValue::from(vec![ + ("A", ScalarValue::from(1.0)), + ("B", ScalarValue::from("Z")), + ]) + .partial_cmp(&ScalarValue::from(vec![ + ("A", ScalarValue::from(2.0)), + ("B", ScalarValue::from("A")), + ])), + Some(Ordering::Less) + ); + + // For different struct fields, `partial_cmp` returns None. + assert_eq!( + ScalarValue::from(vec![ + ("A", ScalarValue::from(1.0)), + ("B", ScalarValue::from("Z")), + ]) + .partial_cmp(&ScalarValue::from(vec![ + ("a", ScalarValue::from(2.0)), + ("b", ScalarValue::from("A")), + ])), + None + ); + } + + #[test] + fn test_scalar_struct() { + let field_a = Field::new("A", DataType::Int32, false); + let field_b = Field::new("B", DataType::Boolean, false); + let field_c = Field::new("C", DataType::Utf8, false); + + let field_e = Field::new("e", DataType::Int16, false); + let field_f = Field::new("f", DataType::Int64, false); + let field_d = Field::new( + "D", + DataType::Struct(vec![field_e.clone(), field_f.clone()]), + false, + ); + + let scalar = ScalarValue::Struct( + Some(Box::new(vec![ + ScalarValue::Int32(Some(23)), + ScalarValue::Boolean(Some(false)), + ScalarValue::Utf8(Some("Hello".to_string())), + ScalarValue::from(vec![ + ("e", ScalarValue::from(2i16)), + ("f", ScalarValue::from(3i64)), + ]), + ])), + Box::new(vec![ + field_a.clone(), + field_b.clone(), + field_c.clone(), + field_d.clone(), + ]), + ); + let _dt = scalar.get_datatype(); + let _sub_dt = field_d.data_type.clone(); + + // Check Display + assert_eq!( + format!("{}", scalar), + String::from("{A:23,B:false,C:Hello,D:{e:2,f:3}}") + ); + + // Check Debug + assert_eq!( + format!("{:?}", scalar), + String::from( + r#"Struct({A:Int32(23),B:Boolean(false),C:Utf8("Hello"),D:Struct({e:Int16(2),f:Int64(3)})})"# + ) + ); + + // Convert to length-2 array + let array = scalar.to_array_of_size(2); + let expected_vals = vec![ + (field_a.clone(), Int32Vec::from_slice(&[23, 23]).as_arc()), + ( + field_b.clone(), + Arc::new(BooleanArray::from_slice(&vec![false, false])) as ArrayRef, + ), + ( + field_c.clone(), + Arc::new(StringArray::from_slice(&vec!["Hello", "Hello"])) as ArrayRef, + ), + ( + field_d.clone(), + Arc::new(StructArray::from_data( + DataType::Struct(vec![field_e.clone(), field_f.clone()]), + vec![ + Int16Vec::from_slice(&[2, 2]).as_arc(), + Int64Vec::from_slice(&[3, 3]).as_arc(), + ], 
+ None, + )) as ArrayRef, + ), + ]; + + let expected = Arc::new(struct_array_from(expected_vals)) as ArrayRef; + assert_eq!(&array, &expected); + + // Construct from second element of ArrayRef + let constructed = ScalarValue::try_from_array(&expected, 1).unwrap(); + assert_eq!(constructed, scalar); + + // None version + let none_scalar = ScalarValue::try_from(array.data_type()).unwrap(); + assert!(none_scalar.is_null()); + assert_eq!(format!("{:?}", none_scalar), String::from("Struct(NULL)")); + + // Construct with convenience From> + let constructed = ScalarValue::from(vec![ + ("A", ScalarValue::from(23i32)), + ("B", ScalarValue::from(false)), + ("C", ScalarValue::from("Hello")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(2i16)), + ("f", ScalarValue::from(3i64)), + ]), + ), + ]); + assert_eq!(constructed, scalar); + + // Build Array from Vec of structs + let scalars = vec![ + ScalarValue::from(vec![ + ("A", ScalarValue::from(23i32)), + ("B", ScalarValue::from(false)), + ("C", ScalarValue::from("Hello")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(2i16)), + ("f", ScalarValue::from(3i64)), + ]), + ), + ]), + ScalarValue::from(vec![ + ("A", ScalarValue::from(7i32)), + ("B", ScalarValue::from(true)), + ("C", ScalarValue::from("World")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(4i16)), + ("f", ScalarValue::from(5i64)), + ]), + ), + ]), + ScalarValue::from(vec![ + ("A", ScalarValue::from(-1000i32)), + ("B", ScalarValue::from(true)), + ("C", ScalarValue::from("!!!!!")), + ( + "D", + ScalarValue::from(vec![ + ("e", ScalarValue::from(6i16)), + ("f", ScalarValue::from(7i64)), + ]), + ), + ]), + ]; + let array: ArrayRef = ScalarValue::iter_to_array(scalars).unwrap(); + + let expected = Arc::new(struct_array_from(vec![ + (field_a, Int32Vec::from_slice(&[23, 7, -1000]).as_arc()), + ( + field_b, + Arc::new(BooleanArray::from_slice(&vec![false, true, true])) as ArrayRef, + ), + ( + field_c, + Arc::new(StringArray::from_slice(&vec!["Hello", "World", "!!!!!"])) + as ArrayRef, + ), + ( + field_d, + Arc::new(StructArray::from_data( + DataType::Struct(vec![field_e, field_f]), + vec![ + Int16Vec::from_slice(&[2, 4, 6]).as_arc(), + Int64Vec::from_slice(&[3, 5, 7]).as_arc(), + ], + None, + )) as ArrayRef, + ), + ])) as ArrayRef; + + assert_eq!(&array, &expected); + } + + #[test] + fn test_lists_in_struct() { + let field_a = Field::new("A", DataType::Utf8, false); + let field_primitive_list = Field::new( + "primitive_list", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + false, + ); + + // Define primitive list scalars + let l0 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ])), + Box::new(DataType::Int32), + ); + + let l1 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(4i32), + ScalarValue::from(5i32), + ])), + Box::new(DataType::Int32), + ); + + let l2 = ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(6i32)])), + Box::new(DataType::Int32), + ); + + // Define struct scalars + let s0 = ScalarValue::from(vec![ + ("A", ScalarValue::Utf8(Some(String::from("First")))), + ("primitive_list", l0), + ]); + + let s1 = ScalarValue::from(vec![ + ("A", ScalarValue::Utf8(Some(String::from("Second")))), + ("primitive_list", l1), + ]); + + let s2 = ScalarValue::from(vec![ + ("A", ScalarValue::Utf8(Some(String::from("Third")))), + ("primitive_list", l2), + ]); + + // iter_to_array for struct scalars + let array = + 
ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap();
+        let array = array.as_any().downcast_ref::<StructArray>().unwrap();
+
+        let mut list_array =
+            MutableListArray::<i32, Int32Vec>::new_with_capacity(Int32Vec::new(), 5);
+        list_array
+            .try_extend(vec![
+                Some(vec![Some(1), Some(2), Some(3)]),
+                Some(vec![Some(4), Some(5)]),
+                Some(vec![Some(6)]),
+            ])
+            .unwrap();
+        let expected = struct_array_from(vec![
+            (
+                field_a.clone(),
+                Arc::new(StringArray::from_slice(&vec!["First", "Second", "Third"]))
+                    as ArrayRef,
+            ),
+            (field_primitive_list.clone(), list_array.as_arc()),
+        ]);
+
+        assert_eq!(array, &expected);
+
+        // Define list-of-structs scalars
+        let nl0 = ScalarValue::List(
+            Some(Box::new(vec![s0.clone(), s1.clone()])),
+            Box::new(s0.get_datatype()),
+        );
+
+        let nl1 =
+            ScalarValue::List(Some(Box::new(vec![s2])), Box::new(s0.get_datatype()));
+
+        let nl2 =
+            ScalarValue::List(Some(Box::new(vec![s1])), Box::new(s0.get_datatype()));
+
+        // iter_to_array for list-of-struct
+        let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
+        let array = array.as_any().downcast_ref::<ListArray<i32>>().unwrap();
+
+        // Construct expected array with array builders
+        let field_a_builder =
+            Utf8Array::<i32>::from_slice(&vec!["First", "Second", "Third", "Second"]);
+        let primitive_value_builder = Int32Vec::with_capacity(5);
+        let mut field_primitive_list_builder =
+            MutableListArray::<i32, Int32Vec>::new_with_capacity(
+                primitive_value_builder,
+                0,
+            );
+        field_primitive_list_builder
+            .try_push(Some(vec![1, 2, 3].into_iter().map(Option::Some)))
+            .unwrap();
+        field_primitive_list_builder
+            .try_push(Some(vec![4, 5].into_iter().map(Option::Some)))
+            .unwrap();
+        field_primitive_list_builder
+            .try_push(Some(vec![6].into_iter().map(Option::Some)))
+            .unwrap();
+        field_primitive_list_builder
+            .try_push(Some(vec![4, 5].into_iter().map(Option::Some)))
+            .unwrap();
+        let _element_builder = StructArray::from_data(
+            DataType::Struct(vec![field_a, field_primitive_list]),
+            vec![
+                Arc::new(field_a_builder),
+                field_primitive_list_builder.as_arc(),
+            ],
+            None,
+        );
+        // TODO: compare against an expected ListArray of structs; for now the
+        // result is only printed for manual inspection.
+        //let expected = ListArray::(element_builder, 5);
+        eprintln!("array = {:?}", array);
+        //assert_eq!(array, &expected);
+    }
+
+    #[test]
+    fn test_nested_lists() {
+        // Define inner list scalars
+        let l1 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(1i32),
+                        ScalarValue::from(2i32),
+                        ScalarValue::from(3i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(4i32),
+                        ScalarValue::from(5i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let l2 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![ScalarValue::from(6i32)])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(7i32),
+                        ScalarValue::from(8i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let l3 = ScalarValue::List(
+            Some(Box::new(vec![ScalarValue::List(
+                Some(Box::new(vec![ScalarValue::from(9i32)])),
+                Box::new(DataType::Int32),
+            )])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
+
+        // Construct expected array with array builders
+        let inner_builder = Int32Vec::with_capacity(8);
+        let middle_builder =
+            MutableListArray::<i32, Int32Vec>::new_with_capacity(inner_builder, 0);
+        let mut outer_builder =
+            MutableListArray::<i32, MutableListArray<i32, Int32Vec>>::new_with_capacity(
+                middle_builder,
+                0,
+            );
+        outer_builder
+            .try_push(Some(vec![
+                Some(vec![Some(1), Some(2), Some(3)]),
+                Some(vec![Some(4), Some(5)]),
+            ]))
+            .unwrap();
+        outer_builder
+            .try_push(Some(vec![
+                Some(vec![Some(6)]),
+                Some(vec![Some(7), Some(8)]),
+            ]))
+            .unwrap();
+        outer_builder
+            .try_push(Some(vec![Some(vec![Some(9)])]))
+            .unwrap();
+
+        let expected = outer_builder.as_arc();
+
+        assert_eq!(&array, &expected);
+    }
+
+    #[test]
+    fn scalar_timestamp_ns_utc_timezone() {
+        let scalar = ScalarValue::TimestampNanosecond(
+            Some(1599566400000000000),
+            Some("UTC".to_owned()),
+        );
+
+        assert_eq!(
+            scalar.get_datatype(),
+            DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned()))
+        );
+
+        let array = scalar.to_array();
+        assert_eq!(array.len(), 1);
+        assert_eq!(
+            array.data_type(),
+            &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned()))
+        );
+
+        let newscalar = ScalarValue::try_from_array(&array, 0).unwrap();
+        assert_eq!(
+            newscalar.get_datatype(),
+            DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned()))
+        );
+    }
+}
diff --git a/datafusion-expr/Cargo.toml b/datafusion-expr/Cargo.toml
new file mode 100644
index 000000000000..7da7c2602791
--- /dev/null
+++ b/datafusion-expr/Cargo.toml
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "datafusion-expr"
+description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model"
+version = "6.0.0"
+homepage = "https://github.com/apache/arrow-datafusion"
+repository = "https://github.com/apache/arrow-datafusion"
+readme = "../README.md"
+authors = ["Apache Arrow <dev@arrow.apache.org>"]
+license = "Apache-2.0"
+keywords = [ "arrow", "query", "sql" ]
+edition = "2021"
+rust-version = "1.58"
+
+[lib]
+name = "datafusion_expr"
+path = "src/lib.rs"
+
+[features]
+
+[dependencies]
+datafusion-common = { path = "../datafusion-common", version = "6.0.0" }
+arrow = { package = "arrow2", version = "0.9", default-features = false }
+sqlparser = "0.13"
+ahash = { version = "0.7", default-features = false }
diff --git a/datafusion-expr/README.md b/datafusion-expr/README.md
new file mode 100644
index 000000000000..25ac79c223c1
--- /dev/null
+++ b/datafusion-expr/README.md
@@ -0,0 +1,24 @@
+<!---
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements. See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership. The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied. See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+# DataFusion Expr
+
+This is an internal module for fundamental expression types of [DataFusion][df].
+
+[df]: https://crates.io/crates/datafusion
diff --git a/datafusion-expr/src/accumulator.rs b/datafusion-expr/src/accumulator.rs
new file mode 100644
index 000000000000..599bd363fb61
--- /dev/null
+++ b/datafusion-expr/src/accumulator.rs
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::ArrayRef;
+use datafusion_common::{Result, ScalarValue};
+use std::fmt::Debug;
+
+/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
+/// generically accumulates values.
+///
+/// An accumulator knows how to:
+/// * update its state from inputs via `update_batch`
+/// * convert its internal state to a vector of scalar values
+/// * update its state from multiple accumulators' states via `merge_batch`
+/// * compute the final value from its internal state via `evaluate`
pub trait Accumulator: Send + Sync + Debug {
+    /// Returns the state of the accumulator at the end of the accumulation.
+    // in the case of an average on which we track `sum` and `n`, this function should return a vector
+    // of two values, sum and n.
+    fn state(&self) -> Result<Vec<ScalarValue>>;
+
+    /// Updates the accumulator's state from a vector of arrays.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
+
+    /// Updates the accumulator's state from a vector of states.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;
+
+    /// Returns its value based on its current state.
+    fn evaluate(&self) -> Result<ScalarValue>;
+}
diff --git a/datafusion-expr/src/aggregate_function.rs b/datafusion-expr/src/aggregate_function.rs
new file mode 100644
index 000000000000..8f12e88bf1a2
--- /dev/null
+++ b/datafusion-expr/src/aggregate_function.rs
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
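// Editor's note: a minimal sketch (not part of this patch) showing one way the
// `Accumulator` trait defined in accumulator.rs above can be implemented.
// `Int64Array` is assumed to alias arrow2's `PrimitiveArray<i64>`; the struct
// name and error message here are hypothetical.
use arrow::array::{ArrayRef, Int64Array};
use datafusion_common::{DataFusionError, Result, ScalarValue};

#[derive(Debug)]
struct I64SumAccumulator {
    sum: i64,
}

impl Accumulator for I64SumAccumulator {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        // The running total is the only state that must survive a merge.
        Ok(vec![ScalarValue::Int64(Some(self.sum))])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let array = values[0]
            .as_any()
            .downcast_ref::<Int64Array>()
            .ok_or_else(|| {
                DataFusionError::Internal("expected Int64Array".to_string())
            })?;
        // Sum the non-null values in this batch into the running total.
        self.sum += array.iter().flatten().sum::<i64>();
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // States emitted by `state()` have the same layout as inputs here.
        self.update_batch(states)
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.sum)))
    }
}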
+
+use datafusion_common::{DataFusionError, Result};
+use std::{fmt, str::FromStr};
+
+/// Enum of all built-in aggregate functions
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub enum AggregateFunction {
+    /// count
+    Count,
+    /// sum
+    Sum,
+    /// min
+    Min,
+    /// max
+    Max,
+    /// avg
+    Avg,
+    /// approx_distinct
+    ApproxDistinct,
+    /// array_agg
+    ArrayAgg,
+    /// Variance (Sample)
+    Variance,
+    /// Variance (Population)
+    VariancePop,
+    /// Standard Deviation (Sample)
+    Stddev,
+    /// Standard Deviation (Population)
+    StddevPop,
+    /// Covariance (Sample)
+    Covariance,
+    /// Covariance (Population)
+    CovariancePop,
+    /// Correlation
+    Correlation,
+    /// Approximate continuous percentile function
+    ApproxPercentileCont,
+}
+
+impl fmt::Display for AggregateFunction {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        // uppercase of the debug.
+        write!(f, "{}", format!("{:?}", self).to_uppercase())
+    }
+}
+
+impl FromStr for AggregateFunction {
+    type Err = DataFusionError;
+    fn from_str(name: &str) -> Result<AggregateFunction> {
+        Ok(match name {
+            "min" => AggregateFunction::Min,
+            "max" => AggregateFunction::Max,
+            "count" => AggregateFunction::Count,
+            "avg" => AggregateFunction::Avg,
+            "sum" => AggregateFunction::Sum,
+            "approx_distinct" => AggregateFunction::ApproxDistinct,
+            "array_agg" => AggregateFunction::ArrayAgg,
+            "var" => AggregateFunction::Variance,
+            "var_samp" => AggregateFunction::Variance,
+            "var_pop" => AggregateFunction::VariancePop,
+            "stddev" => AggregateFunction::Stddev,
+            "stddev_samp" => AggregateFunction::Stddev,
+            "stddev_pop" => AggregateFunction::StddevPop,
+            "covar" => AggregateFunction::Covariance,
+            "covar_samp" => AggregateFunction::Covariance,
+            "covar_pop" => AggregateFunction::CovariancePop,
+            "corr" => AggregateFunction::Correlation,
+            "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
+            _ => {
+                return Err(DataFusionError::Plan(format!(
+                    "There is no built-in function named {}",
+                    name
+                )));
+            }
+        })
+    }
+}
diff --git a/datafusion-expr/src/built_in_function.rs b/datafusion-expr/src/built_in_function.rs
new file mode 100644
index 000000000000..0d5ee9792ecb
--- /dev/null
+++ b/datafusion-expr/src/built_in_function.rs
@@ -0,0 +1,330 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//!
Built-in functions + +use crate::Volatility; +use datafusion_common::{DataFusionError, Result}; +use std::fmt; +use std::str::FromStr; + +/// Enum of all built-in scalar functions +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum BuiltinScalarFunction { + // math functions + /// abs + Abs, + /// acos + Acos, + /// asin + Asin, + /// atan + Atan, + /// ceil + Ceil, + /// cos + Cos, + /// Digest + Digest, + /// exp + Exp, + /// floor + Floor, + /// ln, Natural logarithm + Ln, + /// log, same as log10 + Log, + /// log10 + Log10, + /// log2 + Log2, + /// round + Round, + /// signum + Signum, + /// sin + Sin, + /// sqrt + Sqrt, + /// tan + Tan, + /// trunc + Trunc, + + // string functions + /// construct an array from columns + Array, + /// ascii + Ascii, + /// bit_length + BitLength, + /// btrim + Btrim, + /// character_length + CharacterLength, + /// chr + Chr, + /// concat + Concat, + /// concat_ws + ConcatWithSeparator, + /// date_part + DatePart, + /// date_trunc + DateTrunc, + /// initcap + InitCap, + /// left + Left, + /// lpad + Lpad, + /// lower + Lower, + /// ltrim + Ltrim, + /// md5 + MD5, + /// nullif + NullIf, + /// octet_length + OctetLength, + /// random + Random, + /// regexp_replace + RegexpReplace, + /// repeat + Repeat, + /// replace + Replace, + /// reverse + Reverse, + /// right + Right, + /// rpad + Rpad, + /// rtrim + Rtrim, + /// sha224 + SHA224, + /// sha256 + SHA256, + /// sha384 + SHA384, + /// Sha512 + SHA512, + /// split_part + SplitPart, + /// starts_with + StartsWith, + /// strpos + Strpos, + /// substr + Substr, + /// to_hex + ToHex, + /// to_timestamp + ToTimestamp, + /// to_timestamp_millis + ToTimestampMillis, + /// to_timestamp_micros + ToTimestampMicros, + /// to_timestamp_seconds + ToTimestampSeconds, + ///now + Now, + /// translate + Translate, + /// trim + Trim, + /// upper + Upper, + /// regexp_match + RegexpMatch, +} + +impl BuiltinScalarFunction { + /// an allowlist of functions to take zero arguments, so that they will get special treatment + /// while executing. + pub fn supports_zero_argument(&self) -> bool { + matches!( + self, + BuiltinScalarFunction::Random | BuiltinScalarFunction::Now + ) + } + /// Returns the [Volatility] of the builtin function. 
+ pub fn volatility(&self) -> Volatility { + match self { + //Immutable scalar builtins + BuiltinScalarFunction::Abs => Volatility::Immutable, + BuiltinScalarFunction::Acos => Volatility::Immutable, + BuiltinScalarFunction::Asin => Volatility::Immutable, + BuiltinScalarFunction::Atan => Volatility::Immutable, + BuiltinScalarFunction::Ceil => Volatility::Immutable, + BuiltinScalarFunction::Cos => Volatility::Immutable, + BuiltinScalarFunction::Exp => Volatility::Immutable, + BuiltinScalarFunction::Floor => Volatility::Immutable, + BuiltinScalarFunction::Ln => Volatility::Immutable, + BuiltinScalarFunction::Log => Volatility::Immutable, + BuiltinScalarFunction::Log10 => Volatility::Immutable, + BuiltinScalarFunction::Log2 => Volatility::Immutable, + BuiltinScalarFunction::Round => Volatility::Immutable, + BuiltinScalarFunction::Signum => Volatility::Immutable, + BuiltinScalarFunction::Sin => Volatility::Immutable, + BuiltinScalarFunction::Sqrt => Volatility::Immutable, + BuiltinScalarFunction::Tan => Volatility::Immutable, + BuiltinScalarFunction::Trunc => Volatility::Immutable, + BuiltinScalarFunction::Array => Volatility::Immutable, + BuiltinScalarFunction::Ascii => Volatility::Immutable, + BuiltinScalarFunction::BitLength => Volatility::Immutable, + BuiltinScalarFunction::Btrim => Volatility::Immutable, + BuiltinScalarFunction::CharacterLength => Volatility::Immutable, + BuiltinScalarFunction::Chr => Volatility::Immutable, + BuiltinScalarFunction::Concat => Volatility::Immutable, + BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, + BuiltinScalarFunction::DatePart => Volatility::Immutable, + BuiltinScalarFunction::DateTrunc => Volatility::Immutable, + BuiltinScalarFunction::InitCap => Volatility::Immutable, + BuiltinScalarFunction::Left => Volatility::Immutable, + BuiltinScalarFunction::Lpad => Volatility::Immutable, + BuiltinScalarFunction::Lower => Volatility::Immutable, + BuiltinScalarFunction::Ltrim => Volatility::Immutable, + BuiltinScalarFunction::MD5 => Volatility::Immutable, + BuiltinScalarFunction::NullIf => Volatility::Immutable, + BuiltinScalarFunction::OctetLength => Volatility::Immutable, + BuiltinScalarFunction::RegexpReplace => Volatility::Immutable, + BuiltinScalarFunction::Repeat => Volatility::Immutable, + BuiltinScalarFunction::Replace => Volatility::Immutable, + BuiltinScalarFunction::Reverse => Volatility::Immutable, + BuiltinScalarFunction::Right => Volatility::Immutable, + BuiltinScalarFunction::Rpad => Volatility::Immutable, + BuiltinScalarFunction::Rtrim => Volatility::Immutable, + BuiltinScalarFunction::SHA224 => Volatility::Immutable, + BuiltinScalarFunction::SHA256 => Volatility::Immutable, + BuiltinScalarFunction::SHA384 => Volatility::Immutable, + BuiltinScalarFunction::SHA512 => Volatility::Immutable, + BuiltinScalarFunction::Digest => Volatility::Immutable, + BuiltinScalarFunction::SplitPart => Volatility::Immutable, + BuiltinScalarFunction::StartsWith => Volatility::Immutable, + BuiltinScalarFunction::Strpos => Volatility::Immutable, + BuiltinScalarFunction::Substr => Volatility::Immutable, + BuiltinScalarFunction::ToHex => Volatility::Immutable, + BuiltinScalarFunction::ToTimestamp => Volatility::Immutable, + BuiltinScalarFunction::ToTimestampMillis => Volatility::Immutable, + BuiltinScalarFunction::ToTimestampMicros => Volatility::Immutable, + BuiltinScalarFunction::ToTimestampSeconds => Volatility::Immutable, + BuiltinScalarFunction::Translate => Volatility::Immutable, + BuiltinScalarFunction::Trim => Volatility::Immutable, + 
BuiltinScalarFunction::Upper => Volatility::Immutable,
+            BuiltinScalarFunction::RegexpMatch => Volatility::Immutable,
+
+            // Stable builtin functions
+            BuiltinScalarFunction::Now => Volatility::Stable,
+
+            // Volatile builtin functions
+            BuiltinScalarFunction::Random => Volatility::Volatile,
+        }
+    }
+}
+
+impl fmt::Display for BuiltinScalarFunction {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        // lowercase of the debug.
+        write!(f, "{}", format!("{:?}", self).to_lowercase())
+    }
+}
+
+impl FromStr for BuiltinScalarFunction {
+    type Err = DataFusionError;
+    fn from_str(name: &str) -> Result<BuiltinScalarFunction> {
+        Ok(match name {
+            // math functions
+            "abs" => BuiltinScalarFunction::Abs,
+            "acos" => BuiltinScalarFunction::Acos,
+            "asin" => BuiltinScalarFunction::Asin,
+            "atan" => BuiltinScalarFunction::Atan,
+            "ceil" => BuiltinScalarFunction::Ceil,
+            "cos" => BuiltinScalarFunction::Cos,
+            "exp" => BuiltinScalarFunction::Exp,
+            "floor" => BuiltinScalarFunction::Floor,
+            "ln" => BuiltinScalarFunction::Ln,
+            "log" => BuiltinScalarFunction::Log,
+            "log10" => BuiltinScalarFunction::Log10,
+            "log2" => BuiltinScalarFunction::Log2,
+            "round" => BuiltinScalarFunction::Round,
+            "signum" => BuiltinScalarFunction::Signum,
+            "sin" => BuiltinScalarFunction::Sin,
+            "sqrt" => BuiltinScalarFunction::Sqrt,
+            "tan" => BuiltinScalarFunction::Tan,
+            "trunc" => BuiltinScalarFunction::Trunc,
+
+            // string functions
+            "array" => BuiltinScalarFunction::Array,
+            "ascii" => BuiltinScalarFunction::Ascii,
+            "bit_length" => BuiltinScalarFunction::BitLength,
+            "btrim" => BuiltinScalarFunction::Btrim,
+            "char_length" => BuiltinScalarFunction::CharacterLength,
+            "character_length" => BuiltinScalarFunction::CharacterLength,
+            "concat" => BuiltinScalarFunction::Concat,
+            "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator,
+            "chr" => BuiltinScalarFunction::Chr,
+            "date_part" | "datepart" => BuiltinScalarFunction::DatePart,
+            "date_trunc" | "datetrunc" => BuiltinScalarFunction::DateTrunc,
+            "initcap" => BuiltinScalarFunction::InitCap,
+            "left" => BuiltinScalarFunction::Left,
+            "length" => BuiltinScalarFunction::CharacterLength,
+            "lower" => BuiltinScalarFunction::Lower,
+            "lpad" => BuiltinScalarFunction::Lpad,
+            "ltrim" => BuiltinScalarFunction::Ltrim,
+            "md5" => BuiltinScalarFunction::MD5,
+            "nullif" => BuiltinScalarFunction::NullIf,
+            "octet_length" => BuiltinScalarFunction::OctetLength,
+            "random" => BuiltinScalarFunction::Random,
+            "regexp_replace" => BuiltinScalarFunction::RegexpReplace,
+            "repeat" => BuiltinScalarFunction::Repeat,
+            "replace" => BuiltinScalarFunction::Replace,
+            "reverse" => BuiltinScalarFunction::Reverse,
+            "right" => BuiltinScalarFunction::Right,
+            "rpad" => BuiltinScalarFunction::Rpad,
+            "rtrim" => BuiltinScalarFunction::Rtrim,
+            "sha224" => BuiltinScalarFunction::SHA224,
+            "sha256" => BuiltinScalarFunction::SHA256,
+            "sha384" => BuiltinScalarFunction::SHA384,
+            "sha512" => BuiltinScalarFunction::SHA512,
+            "digest" => BuiltinScalarFunction::Digest,
+            "split_part" => BuiltinScalarFunction::SplitPart,
+            "starts_with" => BuiltinScalarFunction::StartsWith,
+            "strpos" => BuiltinScalarFunction::Strpos,
+            "substr" => BuiltinScalarFunction::Substr,
+            "to_hex" => BuiltinScalarFunction::ToHex,
+            "to_timestamp" => BuiltinScalarFunction::ToTimestamp,
+            "to_timestamp_millis" => BuiltinScalarFunction::ToTimestampMillis,
+            "to_timestamp_micros" => BuiltinScalarFunction::ToTimestampMicros,
+            "to_timestamp_seconds" => BuiltinScalarFunction::ToTimestampSeconds,
+            "now" => BuiltinScalarFunction::Now,
+            "translate" =>
BuiltinScalarFunction::Translate, + "trim" => BuiltinScalarFunction::Trim, + "upper" => BuiltinScalarFunction::Upper, + "regexp_match" => BuiltinScalarFunction::RegexpMatch, + _ => { + return Err(DataFusionError::Plan(format!( + "There is no built-in function named {}", + name + ))) + } + }) + } +} diff --git a/datafusion-expr/src/columnar_value.rs b/datafusion-expr/src/columnar_value.rs new file mode 100644 index 000000000000..fb00f0c12b91 --- /dev/null +++ b/datafusion-expr/src/columnar_value.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::array::NullArray; +use arrow::datatypes::DataType; +use datafusion_common::record_batch::RecordBatch; +use datafusion_common::ScalarValue; +use std::sync::Arc; + +/// Represents the result from an expression +#[derive(Clone)] +pub enum ColumnarValue { + /// Array of values + Array(ArrayRef), + /// A single value + Scalar(ScalarValue), +} + +impl ColumnarValue { + pub fn data_type(&self) -> DataType { + match self { + ColumnarValue::Array(array_value) => array_value.data_type().clone(), + ColumnarValue::Scalar(scalar_value) => scalar_value.get_datatype(), + } + } + + /// Convert a columnar value into an ArrayRef + pub fn into_array(self, num_rows: usize) -> ArrayRef { + match self { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), + } + } +} + +/// null columnar values are implemented as a null array in order to pass batch +/// num_rows +pub type NullColumnarValue = ColumnarValue; + +impl From<&RecordBatch> for NullColumnarValue { + fn from(batch: &RecordBatch) -> Self { + let num_rows = batch.num_rows(); + ColumnarValue::Array(Arc::new(NullArray::new_null( + DataType::Struct(batch.schema().fields.to_vec()), + num_rows, + ))) + } +} diff --git a/datafusion-expr/src/expr.rs b/datafusion-expr/src/expr.rs new file mode 100644 index 000000000000..f26f1dfa9746 --- /dev/null +++ b/datafusion-expr/src/expr.rs @@ -0,0 +1,698 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::aggregate_function;
+use crate::built_in_function;
+use crate::expr_fn::binary_expr;
+use crate::window_frame;
+use crate::window_function;
+use crate::AggregateUDF;
+use crate::Operator;
+use crate::ScalarUDF;
+use arrow::datatypes::DataType;
+use datafusion_common::Column;
+use datafusion_common::{DFSchema, Result};
+use datafusion_common::{DataFusionError, ScalarValue};
+use std::fmt;
+use std::hash::{BuildHasher, Hash, Hasher};
+use std::ops::Not;
+use std::sync::Arc;
+
+/// `Expr` is a central type of DataFusion's query API, and
+/// represents logical expressions such as `A + 1`, or `CAST(c1 AS
+/// int)`.
+///
+/// An `Expr` can compute its [DataType](arrow::datatypes::DataType)
+/// and nullability, and has functions for building up complex
+/// expressions.
+///
+/// # Examples
+///
+/// ## Create an expression `c1` referring to column named "c1"
+/// ```
+/// # use datafusion_common::Column;
+/// # use datafusion_expr::{lit, col, Expr};
+/// let expr = col("c1");
+/// assert_eq!(expr, Expr::Column(Column::from_name("c1")));
+/// ```
+///
+/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together
+/// ```
+/// # use datafusion_expr::{lit, col, Operator, Expr};
+/// let expr = col("c1") + col("c2");
+///
+/// assert!(matches!(expr, Expr::BinaryExpr { ..} ));
+/// if let Expr::BinaryExpr { left, right, op } = expr {
+///     assert_eq!(*left, col("c1"));
+///     assert_eq!(*right, col("c2"));
+///     assert_eq!(op, Operator::Plus);
+/// }
+/// ```
+///
+/// ## Create expression `c1 = 42` to compare the value in column "c1" to the literal value `42`
+/// ```
+/// # use datafusion_common::ScalarValue;
+/// # use datafusion_expr::{lit, col, Operator, Expr};
+/// let expr = col("c1").eq(lit(42_i32));
+///
+/// assert!(matches!(expr, Expr::BinaryExpr { .. } ));
+/// if let Expr::BinaryExpr { left, right, op } = expr {
+///     assert_eq!(*left, col("c1"));
+///     let scalar = ScalarValue::Int32(Some(42));
+///     assert_eq!(*right, Expr::Literal(scalar));
+///     assert_eq!(op, Operator::Eq);
+/// }
+/// ```
+#[derive(Clone, PartialEq, Hash)]
+pub enum Expr {
+    /// An expression with a specific name.
+    Alias(Box<Expr>, String),
+    /// A named reference to a qualified field in a schema.
+    Column(Column),
+    /// A named reference to a variable in a registry.
+    ScalarVariable(Vec<String>),
+    /// A constant value.
+    Literal(ScalarValue),
+    /// A binary expression such as "age > 21"
+    BinaryExpr {
+        /// Left-hand side of the expression
+        left: Box<Expr>,
+        /// The comparison operator
+        op: Operator,
+        /// Right-hand side of the expression
+        right: Box<Expr>,
+    },
+    /// Negation of an expression. The expression's type must be a boolean to make sense.
+    Not(Box<Expr>),
+    /// Whether an expression is not Null. This expression is never null.
+    IsNotNull(Box<Expr>),
+    /// Whether an expression is Null. This expression is never null.
+    IsNull(Box<Expr>),
+    /// Arithmetic negation of an expression; the operand must be of a signed numeric data type
+    Negative(Box<Expr>),
+    /// Returns the field of a [`ListArray`] or [`StructArray`] by key
+    GetIndexedField {
+        /// The expression to take the field from
+        expr: Box<Expr>,
+        /// The name of the field to take
+        key: ScalarValue,
+    },
+    /// Whether an expression is between a given range.
+    Between {
+        /// The value to compare
+        expr: Box<Expr>,
+        /// Whether the expression is negated
+        negated: bool,
+        /// The low end of the range
+        low: Box<Expr>,
+        /// The high end of the range
+        high: Box<Expr>,
+    },
+    /// The CASE expression is similar to a series of nested if/else and there are two forms that
+    /// can be used. The first form consists of a series of boolean "when" expressions with
+    /// corresponding "then" expressions, and an optional "else" expression.
+    ///
+    /// CASE WHEN condition THEN result
+    ///      [WHEN ...]
+    ///      [ELSE result]
+    /// END
+    ///
+    /// The second form uses a base expression and then a series of "when" clauses that match on a
+    /// literal value.
+    ///
+    /// CASE expression
+    ///     WHEN value THEN result
+    ///     [WHEN ...]
+    ///     [ELSE result]
+    /// END
+    Case {
+        /// Optional base expression that can be compared to literal values in the "when" expressions
+        expr: Option<Box<Expr>>,
+        /// One or more when/then expressions
+        when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
+        /// Optional "else" expression
+        else_expr: Option<Box<Expr>>,
+    },
+    /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast.
+    /// This expression is guaranteed to have a fixed type.
+    Cast {
+        /// The expression being cast
+        expr: Box<Expr>,
+        /// The `DataType` the expression will yield
+        data_type: DataType,
+    },
+    /// Casts the expression to a given type and will return a null value if the expression cannot be cast.
+    /// This expression is guaranteed to have a fixed type.
+    TryCast {
+        /// The expression being cast
+        expr: Box<Expr>,
+        /// The `DataType` the expression will yield
+        data_type: DataType,
+    },
+    /// A sort expression, that can be used to sort values.
+    Sort {
+        /// The expression to sort on
+        expr: Box<Expr>,
+        /// The direction of the sort
+        asc: bool,
+        /// Whether to put Nulls before all other data values
+        nulls_first: bool,
+    },
+    /// Represents the call of a built-in scalar function with a set of arguments.
+    ScalarFunction {
+        /// The function
+        fun: built_in_function::BuiltinScalarFunction,
+        /// List of expressions to feed to the functions as arguments
+        args: Vec<Expr>,
+    },
+    /// Represents the call of a user-defined scalar function with arguments.
+    ScalarUDF {
+        /// The function
+        fun: Arc<ScalarUDF>,
+        /// List of expressions to feed to the functions as arguments
+        args: Vec<Expr>,
+    },
+    /// Represents the call of an aggregate built-in function with arguments.
+    AggregateFunction {
+        /// Name of the function
+        fun: aggregate_function::AggregateFunction,
+        /// List of expressions to feed to the functions as arguments
+        args: Vec<Expr>,
+        /// Whether this is a DISTINCT aggregation or not
+        distinct: bool,
+    },
+    /// Represents the call of a window function with arguments.
+    WindowFunction {
+        /// Name of the function
+        fun: window_function::WindowFunction,
+        /// List of expressions to feed to the functions as arguments
+        args: Vec<Expr>,
+        /// List of partition by expressions
+        partition_by: Vec<Expr>,
+        /// List of order by expressions
+        order_by: Vec<Expr>,
+        /// Window frame
+        window_frame: Option<window_frame::WindowFrame>,
+    },
+    /// Represents the call of a user-defined aggregate function with arguments.
+    AggregateUDF {
+        /// The function
+        fun: Arc<AggregateUDF>,
+        /// List of expressions to feed to the functions as arguments
+        args: Vec<Expr>,
+    },
+    /// Returns whether the list contains the expr value.
+    InList {
+        /// The expression to compare
+        expr: Box<Expr>,
+        /// A list of values to compare against
+        list: Vec<Expr>,
+        /// Whether the expression is negated
+        negated: bool,
+    },
+    /// Represents a reference to all fields in a schema.
+    Wildcard,
+}
+
+/// Fixed seed for the hashing so that Ords are consistent across runs
+const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);
+
+impl PartialOrd for Expr {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        let mut hasher = SEED.build_hasher();
+        self.hash(&mut hasher);
+        let s = hasher.finish();
+
+        let mut hasher = SEED.build_hasher();
+        other.hash(&mut hasher);
+        let o = hasher.finish();
+
+        Some(s.cmp(&o))
+    }
+}
+
+impl Expr {
+    /// Returns the name of this expression based on [DFSchema].
+    ///
+    /// This represents how a column with this expression is named when no alias is chosen
+    pub fn name(&self, input_schema: &DFSchema) -> Result<String> {
+        create_name(self, input_schema)
+    }
+
+    /// Return `self == other`
+    pub fn eq(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::Eq, other)
+    }
+
+    /// Return `self != other`
+    pub fn not_eq(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::NotEq, other)
+    }
+
+    /// Return `self > other`
+    pub fn gt(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::Gt, other)
+    }
+
+    /// Return `self >= other`
+    pub fn gt_eq(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::GtEq, other)
+    }
+
+    /// Return `self < other`
+    pub fn lt(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::Lt, other)
+    }
+
+    /// Return `self <= other`
+    pub fn lt_eq(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::LtEq, other)
+    }
+
+    /// Return `self && other`
+    pub fn and(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::And, other)
+    }
+
+    /// Return `self || other`
+    pub fn or(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::Or, other)
+    }
+
+    /// Return `!self`
+    #[allow(clippy::should_implement_trait)]
+    pub fn not(self) -> Expr {
+        !self
+    }
+
+    /// Calculate the modulus of two expressions.
+    /// Return `self % other`
+    pub fn modulus(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::Modulo, other)
+    }
+
+    /// Return `self LIKE other`
+    pub fn like(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::Like, other)
+    }
+
+    /// Return `self NOT LIKE other`
+    pub fn not_like(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::NotLike, other)
+    }
+
+    /// Return `self AS name` alias expression
+    pub fn alias(self, name: &str) -> Expr {
+        Expr::Alias(Box::new(self), name.to_owned())
+    }
+
+    /// Return `self IN <list>` if `negated` is false, otherwise
+    /// return `self NOT IN <list>`.
+    pub fn in_list(self, list: Vec<Expr>, negated: bool) -> Expr {
+        Expr::InList {
+            expr: Box::new(self),
+            list,
+            negated,
+        }
+    }
+
+    /// Return `IsNull(Box(self))`
+    #[allow(clippy::wrong_self_convention)]
+    pub fn is_null(self) -> Expr {
+        Expr::IsNull(Box::new(self))
+    }
+
+    /// Return `IsNotNull(Box(self))`
+    #[allow(clippy::wrong_self_convention)]
+    pub fn is_not_null(self) -> Expr {
+        Expr::IsNotNull(Box::new(self))
+    }
+
+    /// Create a sort expression from an existing expression.
+    ///
+    /// ```
+    /// # use datafusion_expr::col;
+    /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST
+    /// ```
+    pub fn sort(self, asc: bool, nulls_first: bool) -> Expr {
+        Expr::Sort {
+            expr: Box::new(self),
+            asc,
+            nulls_first,
+        }
+    }
+}
+
+impl Not for Expr {
+    type Output = Self;
+
+    fn not(self) -> Self::Output {
+        Expr::Not(Box::new(self))
+    }
+}
+
+impl std::fmt::Display for Expr {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            Expr::BinaryExpr {
+                ref left,
+                ref right,
+                ref op,
+            } => write!(f, "{} {} {}", left, op, right),
+            Expr::AggregateFunction {
+                // Name of the function
+                ref fun,
+                // List of expressions to feed to the functions as arguments
+                ref args,
+                // Whether this is a DISTINCT aggregation or not
+                ref distinct,
+            } => fmt_function(f, &fun.to_string(), *distinct, args, true),
+            Expr::ScalarFunction {
+                // Name of the function
+                ref fun,
+                // List of expressions to feed to the functions as arguments
+                ref args,
+            } => fmt_function(f, &fun.to_string(), false, args, true),
+            _ => write!(f, "{:?}", self),
+        }
+    }
+}
+
+impl fmt::Debug for Expr {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
+            Expr::Column(c) => write!(f, "{}", c),
+            Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")),
+            Expr::Literal(v) => write!(f, "{:?}", v),
+            Expr::Case {
+                expr,
+                when_then_expr,
+                else_expr,
+                ..
+            } => {
+                write!(f, "CASE ")?;
+                if let Some(e) = expr {
+                    write!(f, "{:?} ", e)?;
+                }
+                for (w, t) in when_then_expr {
+                    write!(f, "WHEN {:?} THEN {:?} ", w, t)?;
+                }
+                if let Some(e) = else_expr {
+                    write!(f, "ELSE {:?} ", e)?;
+                }
+                write!(f, "END")
+            }
+            Expr::Cast { expr, data_type } => {
+                write!(f, "CAST({:?} AS {:?})", expr, data_type)
+            }
+            Expr::TryCast { expr, data_type } => {
+                write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type)
+            }
+            Expr::Not(expr) => write!(f, "NOT {:?}", expr),
+            Expr::Negative(expr) => write!(f, "(- {:?})", expr),
+            Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr),
+            Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr),
+            Expr::BinaryExpr { left, op, right } => {
+                write!(f, "{:?} {} {:?}", left, op, right)
+            }
+            Expr::Sort {
+                expr,
+                asc,
+                nulls_first,
+            } => {
+                if *asc {
+                    write!(f, "{:?} ASC", expr)?;
+                } else {
+                    write!(f, "{:?} DESC", expr)?;
+                }
+                if *nulls_first {
+                    write!(f, " NULLS FIRST")
+                } else {
+                    write!(f, " NULLS LAST")
+                }
+            }
+            Expr::ScalarFunction { fun, args, .. } => {
+                fmt_function(f, &fun.to_string(), false, args, false)
+            }
+            Expr::ScalarUDF { fun, ref args, .. } => {
+                fmt_function(f, &fun.name, false, args, false)
+            }
+            Expr::WindowFunction {
+                fun,
+                args,
+                partition_by,
+                order_by,
+                window_frame,
+            } => {
+                fmt_function(f, &fun.to_string(), false, args, false)?;
+                if !partition_by.is_empty() {
+                    write!(f, " PARTITION BY {:?}", partition_by)?;
+                }
+                if !order_by.is_empty() {
+                    write!(f, " ORDER BY {:?}", order_by)?;
+                }
+                if let Some(window_frame) = window_frame {
+                    write!(
+                        f,
+                        " {} BETWEEN {} AND {}",
+                        window_frame.units,
+                        window_frame.start_bound,
+                        window_frame.end_bound
+                    )?;
+                }
+                Ok(())
+            }
+            Expr::AggregateFunction {
+                fun,
+                distinct,
+                ref args,
+                ..
+            } => fmt_function(f, &fun.to_string(), *distinct, args, true),
+            Expr::AggregateUDF { fun, ref args, ..
            } => {
+                fmt_function(f, &fun.name, false, args, false)
+            }
+            Expr::Between {
+                expr,
+                negated,
+                low,
+                high,
+            } => {
+                if *negated {
+                    write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high)
+                } else {
+                    write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high)
+                }
+            }
+            Expr::InList {
+                expr,
+                list,
+                negated,
+            } => {
+                if *negated {
+                    write!(f, "{:?} NOT IN ({:?})", expr, list)
+                } else {
+                    write!(f, "{:?} IN ({:?})", expr, list)
+                }
+            }
+            Expr::Wildcard => write!(f, "*"),
+            Expr::GetIndexedField { ref expr, key } => {
+                write!(f, "({:?})[{}]", expr, key)
+            }
+        }
+    }
+}
+
+fn fmt_function(
+    f: &mut fmt::Formatter,
+    fun: &str,
+    distinct: bool,
+    args: &[Expr],
+    display: bool,
+) -> fmt::Result {
+    let args: Vec<String> = match display {
+        true => args.iter().map(|arg| format!("{}", arg)).collect(),
+        false => args.iter().map(|arg| format!("{:?}", arg)).collect(),
+    };
+
+    let distinct_str = match distinct {
+        true => "DISTINCT ",
+        false => "",
+    };
+    write!(f, "{}({}{})", fun, distinct_str, args.join(", "))
+}
+
+fn create_function_name(
+    fun: &str,
+    distinct: bool,
+    args: &[Expr],
+    input_schema: &DFSchema,
+) -> Result<String> {
+    let names: Vec<String> = args
+        .iter()
+        .map(|e| create_name(e, input_schema))
+        .collect::<Result<Vec<String>>>()?;
+    let distinct_str = match distinct {
+        true => "DISTINCT ",
+        false => "",
+    };
+    Ok(format!("{}({}{})", fun, distinct_str, names.join(",")))
+}
+
+/// Returns a readable name of an expression based on the input schema.
+/// This function recursively traverses the expression for names such as "CAST(a > 2)".
+fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
+    match e {
+        Expr::Alias(_, name) => Ok(name.clone()),
+        Expr::Column(c) => Ok(c.flat_name()),
+        Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
+        Expr::Literal(value) => Ok(format!("{:?}", value)),
+        Expr::BinaryExpr { left, op, right } => {
+            let left = create_name(left, input_schema)?;
+            let right = create_name(right, input_schema)?;
+            Ok(format!("{} {} {}", left, op, right))
+        }
+        Expr::Case {
+            expr,
+            when_then_expr,
+            else_expr,
+        } => {
+            let mut name = "CASE ".to_string();
+            if let Some(e) = expr {
+                let e = create_name(e, input_schema)?;
+                name += &format!("{} ", e);
+            }
+            for (w, t) in when_then_expr {
+                let when = create_name(w, input_schema)?;
+                let then = create_name(t, input_schema)?;
+                name += &format!("WHEN {} THEN {} ", when, then);
+            }
+            if let Some(e) = else_expr {
+                let e = create_name(e, input_schema)?;
+                name += &format!("ELSE {} ", e);
+            }
+            name += "END";
+            Ok(name)
+        }
+        Expr::Cast { expr, data_type } => {
+            let expr = create_name(expr, input_schema)?;
+            Ok(format!("CAST({} AS {:?})", expr, data_type))
+        }
+        Expr::TryCast { expr, data_type } => {
+            let expr = create_name(expr, input_schema)?;
+            Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
+        }
+        Expr::Not(expr) => {
+            let expr = create_name(expr, input_schema)?;
+            Ok(format!("NOT {}", expr))
+        }
+        Expr::Negative(expr) => {
+            let expr = create_name(expr, input_schema)?;
+            Ok(format!("(- {})", expr))
+        }
+        Expr::IsNull(expr) => {
+            let expr = create_name(expr, input_schema)?;
+            Ok(format!("{} IS NULL", expr))
+        }
+        Expr::IsNotNull(expr) => {
+            let expr = create_name(expr, input_schema)?;
+            Ok(format!("{} IS NOT NULL", expr))
+        }
+        Expr::GetIndexedField { expr, key } => {
+            let expr = create_name(expr, input_schema)?;
+            Ok(format!("{}[{}]", expr, key))
+        }
+        Expr::ScalarFunction { fun, args, ..
        } => {
+            create_function_name(&fun.to_string(), false, args, input_schema)
+        }
+        Expr::ScalarUDF { fun, args, .. } => {
+            create_function_name(&fun.name, false, args, input_schema)
+        }
+        Expr::WindowFunction {
+            fun,
+            args,
+            window_frame,
+            partition_by,
+            order_by,
+        } => {
+            let mut parts: Vec<String> = vec![create_function_name(
+                &fun.to_string(),
+                false,
+                args,
+                input_schema,
+            )?];
+            if !partition_by.is_empty() {
+                parts.push(format!("PARTITION BY {:?}", partition_by));
+            }
+            if !order_by.is_empty() {
+                parts.push(format!("ORDER BY {:?}", order_by));
+            }
+            if let Some(window_frame) = window_frame {
+                parts.push(format!("{}", window_frame));
+            }
+            Ok(parts.join(" "))
+        }
+        Expr::AggregateFunction {
+            fun,
+            distinct,
+            args,
+            ..
+        } => create_function_name(&fun.to_string(), *distinct, args, input_schema),
+        Expr::AggregateUDF { fun, args } => {
+            let mut names = Vec::with_capacity(args.len());
+            for e in args {
+                names.push(create_name(e, input_schema)?);
+            }
+            Ok(format!("{}({})", fun.name, names.join(",")))
+        }
+        Expr::InList {
+            expr,
+            list,
+            negated,
+        } => {
+            let expr = create_name(expr, input_schema)?;
+            let list = list.iter().map(|expr| create_name(expr, input_schema));
+            if *negated {
+                Ok(format!("{} NOT IN ({:?})", expr, list))
+            } else {
+                Ok(format!("{} IN ({:?})", expr, list))
+            }
+        }
+        Expr::Between {
+            expr,
+            negated,
+            low,
+            high,
+        } => {
+            let expr = create_name(expr, input_schema)?;
+            let low = create_name(low, input_schema)?;
+            let high = create_name(high, input_schema)?;
+            if *negated {
+                Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high))
+            } else {
+                Ok(format!("{} BETWEEN {} AND {}", expr, low, high))
+            }
+        }
+        Expr::Sort { .. } => Err(DataFusionError::Internal(
+            "Create name does not support sort expression".to_string(),
+        )),
+        Expr::Wildcard => Err(DataFusionError::Internal(
+            "Create name does not support wildcard".to_string(),
+        )),
+    }
+}
diff --git a/datafusion-expr/src/expr_fn.rs b/datafusion-expr/src/expr_fn.rs
new file mode 100644
index 000000000000..469a82d0ff24
--- /dev/null
+++ b/datafusion-expr/src/expr_fn.rs
@@ -0,0 +1,32 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
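// Editor's note: a small sketch (not part of this patch) of the fluent API
// that expr.rs above provides. `col` is defined just below in expr_fn.rs and
// `lit` later in literal.rs; both are re-exported from the crate root.
fn example_predicate() -> crate::Expr {
    use crate::{col, lit};
    // Builds the tree for `(c1 = 42) AND (c2 < 5)` out of nested BinaryExpr
    // nodes, with no SQL parsing involved.
    col("c1").eq(lit(42_i32)).and(col("c2").lt(lit(5_i32)))
}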
+
+use crate::{Expr, Operator};
+
+/// Create a column expression based on a qualified or unqualified column name
+pub fn col(ident: &str) -> Expr {
+    Expr::Column(ident.into())
+}
+
+/// Return a new expression `l <op> r`
+pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr {
+    Expr::BinaryExpr {
+        left: Box::new(l),
+        op,
+        right: Box::new(r),
+    }
+}
diff --git a/datafusion-expr/src/function.rs b/datafusion-expr/src/function.rs
new file mode 100644
index 000000000000..2bacd6ae6227
--- /dev/null
+++ b/datafusion-expr/src/function.rs
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::Accumulator;
+use crate::ColumnarValue;
+use arrow::datatypes::DataType;
+use datafusion_common::Result;
+use std::sync::Arc;
+
+/// Scalar function
+///
+/// The Fn param is the wrapped function, but be aware that the function will
+/// be passed with the slice / vec of columnar values (either scalar or array)
+/// with the exception of zero-param functions, where a singular element vec
+/// will be passed. In that case the single element is a null array to indicate
+/// the batch's row count (so that the generative zero-argument function can know
+/// the result array size).
+pub type ScalarFunctionImplementation =
+    Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;
+
+/// A function's return type
+pub type ReturnTypeFunction =
+    Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;
+
+/// The implementation of an aggregate function
+pub type AccumulatorFunctionImplementation =
+    Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>;
+
+/// This signature corresponds to the types with which an aggregator serializes
+/// its state, given its return datatype.
+pub type StateTypeFunction =
+    Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs
new file mode 100644
index 000000000000..709fa634d52d
--- /dev/null
+++ b/datafusion-expr/src/lib.rs
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
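// Editor's note: a sketch (not part of this patch) of building one of the
// closure aliases from function.rs above. `identity_impl` is a hypothetical
// helper that simply passes its first argument through unchanged.
fn identity_impl() -> crate::ScalarFunctionImplementation {
    use crate::ColumnarValue;
    std::sync::Arc::new(|args: &[ColumnarValue]| {
        // `ColumnarValue` is `Clone`, so scalar inputs stay scalars and array
        // inputs stay cheaply ref-counted arrays; a real implementation would
        // validate `args.len()` before indexing.
        Ok(args[0].clone())
    })
}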
+
+mod accumulator;
+mod aggregate_function;
+mod built_in_function;
+mod columnar_value;
+pub mod expr;
+pub mod expr_fn;
+mod function;
+mod literal;
+mod operator;
+mod signature;
+mod udaf;
+mod udf;
+mod window_frame;
+mod window_function;
+
+pub use accumulator::Accumulator;
+pub use aggregate_function::AggregateFunction;
+pub use built_in_function::BuiltinScalarFunction;
+pub use columnar_value::{ColumnarValue, NullColumnarValue};
+pub use expr::Expr;
+pub use expr_fn::col;
+pub use function::{
+    AccumulatorFunctionImplementation, ReturnTypeFunction, ScalarFunctionImplementation,
+    StateTypeFunction,
+};
+pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
+pub use operator::Operator;
+pub use signature::{Signature, TypeSignature, Volatility};
+pub use udaf::AggregateUDF;
+pub use udf::ScalarUDF;
+pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
+pub use window_function::{BuiltInWindowFunction, WindowFunction};
diff --git a/datafusion-expr/src/literal.rs b/datafusion-expr/src/literal.rs
new file mode 100644
index 000000000000..02c75af69573
--- /dev/null
+++ b/datafusion-expr/src/literal.rs
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::Expr;
+use datafusion_common::ScalarValue;
+
+/// Create a literal expression
+pub fn lit<T: Literal>(n: T) -> Expr {
+    n.lit()
+}
+
+/// Create a literal timestamp expression
+pub fn lit_timestamp_nano<T: TimestampLiteral>(n: T) -> Expr {
+    n.lit_timestamp_nano()
+}
+
+/// Trait for converting a type to a [`Literal`] expression.
+pub trait Literal {
+    /// Convert the value to a Literal expression
+    fn lit(&self) -> Expr;
+}
+
+/// Trait for converting a type to a literal timestamp
+pub trait TimestampLiteral {
+    fn lit_timestamp_nano(&self) -> Expr;
+}
+
+impl Literal for &str {
+    fn lit(&self) -> Expr {
+        Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned())))
+    }
+}
+
+impl Literal for String {
+    fn lit(&self) -> Expr {
+        Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned())))
+    }
+}
+
+impl Literal for Vec<u8> {
+    fn lit(&self) -> Expr {
+        Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())))
+    }
+}
+
+impl Literal for &[u8] {
+    fn lit(&self) -> Expr {
+        Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())))
+    }
+}
+
+impl Literal for ScalarValue {
+    fn lit(&self) -> Expr {
+        Expr::Literal(self.clone())
+    }
+}
+
+macro_rules! make_literal {
+    ($TYPE:ty, $SCALAR:ident, $DOC: expr) => {
+        #[doc = $DOC]
+        impl Literal for $TYPE {
+            fn lit(&self) -> Expr {
+                Expr::Literal(ScalarValue::$SCALAR(Some(self.clone())))
+            }
+        }
+    };
+}
+
+macro_rules!
make_timestamp_literal { + ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { + #[doc = $DOC] + impl TimestampLiteral for $TYPE { + fn lit_timestamp_nano(&self) -> Expr { + Expr::Literal(ScalarValue::TimestampNanosecond( + Some((self.clone()).into()), + None, + )) + } + } + }; +} + +make_literal!(bool, Boolean, "literal expression containing a bool"); +make_literal!(f32, Float32, "literal expression containing an f32"); +make_literal!(f64, Float64, "literal expression containing an f64"); +make_literal!(i8, Int8, "literal expression containing an i8"); +make_literal!(i16, Int16, "literal expression containing an i16"); +make_literal!(i32, Int32, "literal expression containing an i32"); +make_literal!(i64, Int64, "literal expression containing an i64"); +make_literal!(u8, UInt8, "literal expression containing a u8"); +make_literal!(u16, UInt16, "literal expression containing a u16"); +make_literal!(u32, UInt32, "literal expression containing a u32"); +make_literal!(u64, UInt64, "literal expression containing a u64"); + +make_timestamp_literal!(i8, Int8, "literal expression containing an i8"); +make_timestamp_literal!(i16, Int16, "literal expression containing an i16"); +make_timestamp_literal!(i32, Int32, "literal expression containing an i32"); +make_timestamp_literal!(i64, Int64, "literal expression containing an i64"); +make_timestamp_literal!(u8, UInt8, "literal expression containing a u8"); +make_timestamp_literal!(u16, UInt16, "literal expression containing a u16"); +make_timestamp_literal!(u32, UInt32, "literal expression containing a u32"); + +#[cfg(test)] +mod test { + use super::*; + use crate::expr_fn::col; + use datafusion_common::ScalarValue; + + #[test] + fn test_lit_timestamp_nano() { + let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32 + let expected = + col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10), None))); + assert_eq!(expr, expected); + + let i: i64 = 10; + let expr = col("time").eq(lit_timestamp_nano(i)); + assert_eq!(expr, expected); + + let i: u32 = 10; + let expr = col("time").eq(lit_timestamp_nano(i)); + assert_eq!(expr, expected); + } +} diff --git a/datafusion-expr/src/operator.rs b/datafusion-expr/src/operator.rs new file mode 100644 index 000000000000..a1cad76cdd97 --- /dev/null +++ b/datafusion-expr/src/operator.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
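// Editor's note: a sketch (not part of this patch) combining the literal
// helpers from literal.rs above with the `std::ops` impls defined in
// operator.rs below, so arithmetic on `Expr` reads like ordinary Rust.
fn example_arithmetic() -> crate::Expr {
    use crate::{col, lit};
    // (price * 0.9) + 1.0: `*` binds tighter than `+`, exactly as in Rust,
    // and each operator builds a BinaryExpr node via `binary_expr`.
    col("price") * lit(0.9_f64) + lit(1.0_f64)
}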
+ +use crate::expr_fn::binary_expr; +use crate::Expr; +use std::fmt; +use std::ops; + +/// Operators applied to expressions +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum Operator { + /// Expressions are equal + Eq, + /// Expressions are not equal + NotEq, + /// Left side is smaller than right side + Lt, + /// Left side is smaller or equal to right side + LtEq, + /// Left side is greater than right side + Gt, + /// Left side is greater or equal to right side + GtEq, + /// Addition + Plus, + /// Subtraction + Minus, + /// Multiplication operator, like `*` + Multiply, + /// Division operator, like `/` + Divide, + /// Remainder operator, like `%` + Modulo, + /// Logical AND, like `&&` + And, + /// Logical OR, like `||` + Or, + /// Matches a wildcard pattern + Like, + /// Does not match a wildcard pattern + NotLike, + /// IS DISTINCT FROM + IsDistinctFrom, + /// IS NOT DISTINCT FROM + IsNotDistinctFrom, + /// Case sensitive regex match + RegexMatch, + /// Case insensitive regex match + RegexIMatch, + /// Case sensitive regex not match + RegexNotMatch, + /// Case insensitive regex not match + RegexNotIMatch, + /// Bitwise and, like `&` + BitwiseAnd, +} + +impl fmt::Display for Operator { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let display = match &self { + Operator::Eq => "=", + Operator::NotEq => "!=", + Operator::Lt => "<", + Operator::LtEq => "<=", + Operator::Gt => ">", + Operator::GtEq => ">=", + Operator::Plus => "+", + Operator::Minus => "-", + Operator::Multiply => "*", + Operator::Divide => "/", + Operator::Modulo => "%", + Operator::And => "AND", + Operator::Or => "OR", + Operator::Like => "LIKE", + Operator::NotLike => "NOT LIKE", + Operator::RegexMatch => "~", + Operator::RegexIMatch => "~*", + Operator::RegexNotMatch => "!~", + Operator::RegexNotIMatch => "!~*", + Operator::IsDistinctFrom => "IS DISTINCT FROM", + Operator::IsNotDistinctFrom => "IS NOT DISTINCT FROM", + Operator::BitwiseAnd => "&", + }; + write!(f, "{}", display) + } +} + +impl ops::Add for Expr { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + binary_expr(self, Operator::Plus, rhs) + } +} + +impl ops::Sub for Expr { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + binary_expr(self, Operator::Minus, rhs) + } +} + +impl ops::Mul for Expr { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + binary_expr(self, Operator::Multiply, rhs) + } +} + +impl ops::Div for Expr { + type Output = Self; + + fn div(self, rhs: Self) -> Self { + binary_expr(self, Operator::Divide, rhs) + } +} + +impl ops::Rem for Expr { + type Output = Self; + + fn rem(self, rhs: Self) -> Self { + binary_expr(self, Operator::Modulo, rhs) + } +} diff --git a/datafusion-expr/src/signature.rs b/datafusion-expr/src/signature.rs new file mode 100644 index 000000000000..5c27f422c105 --- /dev/null +++ b/datafusion-expr/src/signature.rs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::DataType;
+
+/// A function's volatility, which defines the function's eligibility for certain optimizations
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
+pub enum Volatility {
+    /// Immutable - An immutable function will always return the same output when given the same input. An example of this is [BuiltinScalarFunction::Cos].
+    Immutable,
+    /// Stable - A stable function may return different values given the same input across different queries but must return the same value for a given input within a query. An example of this is [BuiltinScalarFunction::Now].
+    Stable,
+    /// Volatile - A volatile function may change the return value from evaluation to evaluation. Multiple invocations of a volatile function may return different results when used in the same query. An example of this is [BuiltinScalarFunction::Random].
+    Volatile,
+}
+
+/// A function's type signature, which defines the function's supported argument types.
+#[derive(Debug, Clone, PartialEq, Hash)]
+pub enum TypeSignature {
+    /// arbitrary number of arguments of a common type out of a list of valid types
+    // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])`
+    Variadic(Vec<DataType>),
+    /// arbitrary number of arguments of an arbitrary but equal type
+    // A function such as `array` is `VariadicEqual`
+    // The first argument decides the type used for coercion
+    VariadicEqual,
+    /// fixed number of arguments of an arbitrary but equal type out of a list of valid types
+    // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
+    // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])`
+    Uniform(usize, Vec<DataType>),
+    /// exact number of arguments of an exact type
+    Exact(Vec<DataType>),
+    /// fixed number of arguments of arbitrary types
+    Any(usize),
+    /// One of a list of signatures
+    OneOf(Vec<TypeSignature>),
+}
+
+/// The Signature of a function defines its supported input types as well as its volatility.
+#[derive(Debug, Clone, PartialEq, Hash)]
+pub struct Signature {
+    /// type_signature - The types that the function accepts. See [TypeSignature] for more information.
+    pub type_signature: TypeSignature,
+    /// volatility - The volatility of the function. See [Volatility] for more information.
+    pub volatility: Volatility,
+}
+
+impl Signature {
+    /// new - Creates a new Signature from any type signature and the volatility.
+    pub fn new(type_signature: TypeSignature, volatility: Volatility) -> Self {
+        Signature {
+            type_signature,
+            volatility,
+        }
+    }
+    /// variadic - Creates a variadic signature that represents an arbitrary number of arguments all from a type in common_types.
+    pub fn variadic(common_types: Vec<DataType>, volatility: Volatility) -> Self {
+        Self {
+            type_signature: TypeSignature::Variadic(common_types),
+            volatility,
+        }
+    }
+    /// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type.
+    pub fn variadic_equal(volatility: Volatility) -> Self {
+        Self {
+            type_signature: TypeSignature::VariadicEqual,
+            volatility,
+        }
+    }
+    /// uniform - Creates a function with a fixed number of arguments of the same type, which must be from valid_types.
+    pub fn uniform(
+        arg_count: usize,
+        valid_types: Vec<DataType>,
+        volatility: Volatility,
+    ) -> Self {
+        Self {
+            type_signature: TypeSignature::Uniform(arg_count, valid_types),
+            volatility,
+        }
+    }
+    /// exact - Creates a signature which must match the types in exact_types in order.
+    pub fn exact(exact_types: Vec<DataType>, volatility: Volatility) -> Self {
+        Signature {
+            type_signature: TypeSignature::Exact(exact_types),
+            volatility,
+        }
+    }
+    /// any - Creates a signature which can be made of any type, but with a specified number of arguments.
+    pub fn any(arg_count: usize, volatility: Volatility) -> Self {
+        Signature {
+            type_signature: TypeSignature::Any(arg_count),
+            volatility,
+        }
+    }
+    /// one_of - Creates a signature which can match any of the [TypeSignature]s which are passed in.
+    pub fn one_of(type_signatures: Vec<TypeSignature>, volatility: Volatility) -> Self {
+        Signature {
+            type_signature: TypeSignature::OneOf(type_signatures),
+            volatility,
+        }
+    }
+}
diff --git a/datafusion-expr/src/udaf.rs b/datafusion-expr/src/udaf.rs
new file mode 100644
index 000000000000..a39d58b622f3
--- /dev/null
+++ b/datafusion-expr/src/udaf.rs
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This module contains functions and structs supporting user-defined aggregate functions.
+
+use crate::Expr;
+use crate::{
+    AccumulatorFunctionImplementation, ReturnTypeFunction, Signature, StateTypeFunction,
+};
+use std::fmt::{self, Debug, Formatter};
+use std::sync::Arc;
+
+/// Logical representation of a user-defined aggregate function (UDAF).
+/// A UDAF is different from a UDF in that it is stateful across batches.
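+///
+/// A brief usage sketch (the `signature`, `return_type`, `accumulator` and
+/// `state_type` values are elided here, and `col` is assumed from the
+/// expression API; this example is not part of the original file):
+/// ```ignore
+/// let my_avg = AggregateUDF::new("my_avg", &signature, &return_type, &accumulator, &state_type);
+/// // Build a logical expression calling the UDAF without a function registry:
+/// let expr = my_avg.call(vec![col("a")]);
+/// ```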
+#[derive(Clone)]
+pub struct AggregateUDF {
+    /// name
+    pub name: String,
+    /// signature
+    pub signature: Signature,
+    /// Return type
+    pub return_type: ReturnTypeFunction,
+    /// actual implementation
+    pub accumulator: AccumulatorFunctionImplementation,
+    /// the accumulator's state's description as a function of the return type
+    pub state_type: StateTypeFunction,
+}
+
+impl Debug for AggregateUDF {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        f.debug_struct("AggregateUDF")
+            .field("name", &self.name)
+            .field("signature", &self.signature)
+            .field("fun", &"<FUNC>")
+            .finish()
+    }
+}
+
+impl PartialEq for AggregateUDF {
+    fn eq(&self, other: &Self) -> bool {
+        self.name == other.name && self.signature == other.signature
+    }
+}
+
+impl std::hash::Hash for AggregateUDF {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.name.hash(state);
+        self.signature.hash(state);
+    }
+}
+
+impl AggregateUDF {
+    /// Create a new AggregateUDF
+    pub fn new(
+        name: &str,
+        signature: &Signature,
+        return_type: &ReturnTypeFunction,
+        accumulator: &AccumulatorFunctionImplementation,
+        state_type: &StateTypeFunction,
+    ) -> Self {
+        Self {
+            name: name.to_owned(),
+            signature: signature.clone(),
+            return_type: return_type.clone(),
+            accumulator: accumulator.clone(),
+            state_type: state_type.clone(),
+        }
+    }
+
+    /// Creates a logical expression with a call of the UDAF.
+    /// This utility allows using the UDAF without requiring access to the registry.
+    pub fn call(&self, args: Vec<Expr>) -> Expr {
+        Expr::AggregateUDF {
+            fun: Arc::new(self.clone()),
+            args,
+        }
+    }
+}
diff --git a/datafusion-expr/src/udf.rs b/datafusion-expr/src/udf.rs
new file mode 100644
index 000000000000..79a17a4a2b4b
--- /dev/null
+++ b/datafusion-expr/src/udf.rs
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! UDF support
+
+use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
+use std::fmt;
+use std::fmt::Debug;
+use std::fmt::Formatter;
+use std::sync::Arc;
+
+/// Logical representation of a UDF.
+#[derive(Clone)]
+pub struct ScalarUDF {
+    /// name
+    pub name: String,
+    /// signature
+    pub signature: Signature,
+    /// Return type
+    pub return_type: ReturnTypeFunction,
+    /// actual implementation
+    ///
+    /// The fn param is the wrapped function, but be aware that the function will
+    /// be passed the slice / vec of columnar values (either scalar or array),
+    /// with the exception of zero-param functions, where a single-element vec
+    /// will be passed. In that case the single element is a null array used to
+    /// indicate the batch's row count (so that a generative zero-argument
+    /// function can know the result array size).
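+    ///
+    /// An illustrative sketch of such a function (assuming
+    /// `ScalarFunctionImplementation = Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>`,
+    /// as in earlier DataFusion releases; this example is not part of the original file):
+    /// ```ignore
+    /// let fun: ScalarFunctionImplementation = Arc::new(|args| {
+    ///     // one ColumnarValue per argument; return a single ColumnarValue
+    ///     Ok(args[0].clone())
+    /// });
+    /// ```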
+    pub fun: ScalarFunctionImplementation,
+}
+
+impl Debug for ScalarUDF {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        f.debug_struct("ScalarUDF")
+            .field("name", &self.name)
+            .field("signature", &self.signature)
+            .field("fun", &"<FUNC>")
+            .finish()
+    }
+}
+
+impl PartialEq for ScalarUDF {
+    fn eq(&self, other: &Self) -> bool {
+        self.name == other.name && self.signature == other.signature
+    }
+}
+
+impl std::hash::Hash for ScalarUDF {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.name.hash(state);
+        self.signature.hash(state);
+    }
+}
+
+impl ScalarUDF {
+    /// Create a new ScalarUDF
+    pub fn new(
+        name: &str,
+        signature: &Signature,
+        return_type: &ReturnTypeFunction,
+        fun: &ScalarFunctionImplementation,
+    ) -> Self {
+        Self {
+            name: name.to_owned(),
+            signature: signature.clone(),
+            return_type: return_type.clone(),
+            fun: fun.clone(),
+        }
+    }
+
+    /// Creates a logical expression with a call of the UDF.
+    /// This utility allows using the UDF without requiring access to the registry.
+    pub fn call(&self, args: Vec<Expr>) -> Expr {
+        Expr::ScalarUDF {
+            fun: Arc::new(self.clone()),
+            args,
+        }
+    }
+}
diff --git a/datafusion-expr/src/window_frame.rs b/datafusion-expr/src/window_frame.rs
new file mode 100644
index 000000000000..ba65a5088b61
--- /dev/null
+++ b/datafusion-expr/src/window_frame.rs
@@ -0,0 +1,381 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Window frame
+//!
+//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts:
+//! - A frame type - either ROWS, RANGE or GROUPS,
+//! - A starting frame boundary,
+//! - An ending frame boundary,
+//! - An EXCLUDE clause.
+
+use datafusion_common::{DataFusionError, Result};
+use sqlparser::ast;
+use std::cmp::Ordering;
+use std::convert::{From, TryFrom};
+use std::fmt;
+use std::hash::{Hash, Hasher};
+
+/// The frame-spec determines which output rows are read by an aggregate window function.
+///
+/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the
+/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to
+/// CURRENT ROW.
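+///
+/// For example (illustrative, not part of the original file),
+/// `ROWS BETWEEN 2 PRECEDING AND CURRENT ROW` corresponds to:
+/// ```ignore
+/// let frame = WindowFrame {
+///     units: WindowFrameUnits::Rows,
+///     start_bound: WindowFrameBound::Preceding(Some(2)),
+///     end_bound: WindowFrameBound::CurrentRow,
+/// };
+/// assert_eq!(frame.to_string(), "ROWS BETWEEN 2 PRECEDING AND CURRENT ROW");
+/// ```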
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
+pub struct WindowFrame {
+    /// A frame type - either ROWS, RANGE or GROUPS
+    pub units: WindowFrameUnits,
+    /// A starting frame boundary
+    pub start_bound: WindowFrameBound,
+    /// An ending frame boundary
+    pub end_bound: WindowFrameBound,
+}
+
+impl fmt::Display for WindowFrame {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "{} BETWEEN {} AND {}",
+            self.units, self.start_bound, self.end_bound
+        )?;
+        Ok(())
+    }
+}
+
+impl TryFrom<ast::WindowFrame> for WindowFrame {
+    type Error = DataFusionError;
+
+    fn try_from(value: ast::WindowFrame) -> Result<Self> {
+        let start_bound = value.start_bound.into();
+        let end_bound = value
+            .end_bound
+            .map(WindowFrameBound::from)
+            .unwrap_or(WindowFrameBound::CurrentRow);
+
+        if let WindowFrameBound::Following(None) = start_bound {
+            Err(DataFusionError::Execution(
+                "Invalid window frame: start bound cannot be unbounded following"
+                    .to_owned(),
+            ))
+        } else if let WindowFrameBound::Preceding(None) = end_bound {
+            Err(DataFusionError::Execution(
+                "Invalid window frame: end bound cannot be unbounded preceding"
+                    .to_owned(),
+            ))
+        } else if start_bound > end_bound {
+            Err(DataFusionError::Execution(format!(
+                "Invalid window frame: start bound ({}) cannot be larger than end bound ({})",
+                start_bound, end_bound
+            )))
+        } else {
+            let units = value.units.into();
+            if units == WindowFrameUnits::Range {
+                for bound in &[start_bound, end_bound] {
+                    match bound {
+                        WindowFrameBound::Preceding(Some(v))
+                        | WindowFrameBound::Following(Some(v))
+                            if *v > 0 =>
+                        {
+                            Err(DataFusionError::NotImplemented(format!(
+                                "With WindowFrameUnits={}, the bound cannot be {} PRECEDING or FOLLOWING at the moment",
+                                units, v
+                            )))
+                        }
+                        _ => Ok(()),
+                    }?;
+                }
+            }
+            Ok(Self {
+                units,
+                start_bound,
+                end_bound,
+            })
+        }
+    }
+}
+
+impl Default for WindowFrame {
+    fn default() -> Self {
+        WindowFrame {
+            units: WindowFrameUnits::Range,
+            start_bound: WindowFrameBound::Preceding(None),
+            end_bound: WindowFrameBound::CurrentRow,
+        }
+    }
+}
+
+/// There are five ways to describe starting and ending frame boundaries:
+///
+/// 1. UNBOUNDED PRECEDING
+/// 2. <expr> PRECEDING
+/// 3. CURRENT ROW
+/// 4. <expr> FOLLOWING
+/// 5. UNBOUNDED FOLLOWING
+///
+/// in this implementation we'll only allow <expr> to be u64 (i.e. no dynamic boundary)
+#[derive(Debug, Clone, Copy, Eq)]
+pub enum WindowFrameBound {
+    /// 1. UNBOUNDED PRECEDING
+    /// The frame boundary is the first row in the partition.
+    ///
+    /// 2. <expr> PRECEDING
+    /// <expr> must be a non-negative constant numeric expression. The boundary is a row that
+    /// is <expr> "units" prior to the current row.
+    Preceding(Option<u64>),
+    /// 3. The current row.
+    ///
+    /// For RANGE and GROUPS frame types, peers of the current row are also
+    /// included in the frame, unless specifically excluded by the EXCLUDE clause.
+    /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame
+    /// boundary.
+    CurrentRow,
+    /// 4. This is the same as "<expr> PRECEDING" except that the boundary is <expr> units after the
+    /// current rather than before the current row.
+    ///
+    /// 5. UNBOUNDED FOLLOWING
+    /// The frame boundary is the last row in the partition.
+    Following(Option<u64>),
+}
+
+impl From<ast::WindowFrameBound> for WindowFrameBound {
+    fn from(value: ast::WindowFrameBound) -> Self {
+        match value {
+            ast::WindowFrameBound::Preceding(v) => Self::Preceding(v),
+            ast::WindowFrameBound::Following(v) => Self::Following(v),
+            ast::WindowFrameBound::CurrentRow => Self::CurrentRow,
+        }
+    }
+}
+
+impl fmt::Display for WindowFrameBound {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"),
+            WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"),
+            WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"),
+            WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n),
+            WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n),
+        }
+    }
+}
+
+impl PartialEq for WindowFrameBound {
+    fn eq(&self, other: &Self) -> bool {
+        self.cmp(other) == Ordering::Equal
+    }
+}
+
+impl PartialOrd for WindowFrameBound {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for WindowFrameBound {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.get_rank().cmp(&other.get_rank())
+    }
+}
+
+impl Hash for WindowFrameBound {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.get_rank().hash(state)
+    }
+}
+
+impl WindowFrameBound {
+    /// Get the rank of this window frame bound.
+    ///
+    /// The rank is a tuple of (u8, u64) because we first compare the kind and then the value,
+    /// which requires special handling: for PRECEDING, the larger the value the smaller the
+    /// rank, and 0 PRECEDING / 0 FOLLOWING rank the same as CURRENT ROW.
+    fn get_rank(&self) -> (u8, u64) {
+        match self {
+            WindowFrameBound::Preceding(None) => (0, 0),
+            WindowFrameBound::Following(None) => (4, 0),
+            WindowFrameBound::Preceding(Some(0))
+            | WindowFrameBound::CurrentRow
+            | WindowFrameBound::Following(Some(0)) => (2, 0),
+            WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v),
+            WindowFrameBound::Following(Some(v)) => (3, *v),
+        }
+    }
+}
+
+/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the
+/// starting and ending boundaries of the frame are measured.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
+pub enum WindowFrameUnits {
+    /// The ROWS frame type means that the starting and ending boundaries for the frame are
+    /// determined by counting individual rows relative to the current row.
+    Rows,
+    /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one
+    /// term. Call that term "X". With the RANGE frame type, the elements of the frame are
+    /// determined by computing the value of expression X for all rows in the partition and framing
+    /// those rows for which the value of X is within a certain range of the value of X for the
+    /// current row.
+    Range,
+    /// The GROUPS frame type means that the starting and ending boundaries are determined
+    /// by counting "groups" relative to the current group. A "group" is a set of rows that all have
+    /// equivalent values for all terms of the window ORDER BY clause.
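+    /// For example (illustrative): with `ORDER BY dept`, all rows sharing one value of
+    /// `dept` form a single group.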
+    Groups,
+}
+
+impl fmt::Display for WindowFrameUnits {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.write_str(match self {
+            WindowFrameUnits::Rows => "ROWS",
+            WindowFrameUnits::Range => "RANGE",
+            WindowFrameUnits::Groups => "GROUPS",
+        })
+    }
+}
+
+impl From<ast::WindowFrameUnits> for WindowFrameUnits {
+    fn from(value: ast::WindowFrameUnits) -> Self {
+        match value {
+            ast::WindowFrameUnits::Range => Self::Range,
+            ast::WindowFrameUnits::Groups => Self::Groups,
+            ast::WindowFrameUnits::Rows => Self::Rows,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_window_frame_creation() -> Result<()> {
+        let window_frame = ast::WindowFrame {
+            units: ast::WindowFrameUnits::Range,
+            start_bound: ast::WindowFrameBound::Following(None),
+            end_bound: None,
+        };
+        let result = WindowFrame::try_from(window_frame);
+        assert_eq!(
+            result.err().unwrap().to_string(),
+            "Execution error: Invalid window frame: start bound cannot be unbounded following"
+                .to_owned()
+        );
+
+        let window_frame = ast::WindowFrame {
+            units: ast::WindowFrameUnits::Range,
+            start_bound: ast::WindowFrameBound::Preceding(None),
+            end_bound: Some(ast::WindowFrameBound::Preceding(None)),
+        };
+        let result = WindowFrame::try_from(window_frame);
+        assert_eq!(
+            result.err().unwrap().to_string(),
+            "Execution error: Invalid window frame: end bound cannot be unbounded preceding"
+                .to_owned()
+        );
+
+        let window_frame = ast::WindowFrame {
+            units: ast::WindowFrameUnits::Range,
+            start_bound: ast::WindowFrameBound::Preceding(Some(1)),
+            end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))),
+        };
+        let result = WindowFrame::try_from(window_frame);
+        assert_eq!(
+            result.err().unwrap().to_string(),
+            "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned()
+        );
+
+        let window_frame = ast::WindowFrame {
+            units: ast::WindowFrameUnits::Range,
+            start_bound: ast::WindowFrameBound::Preceding(Some(2)),
+            end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))),
+        };
+        let result = WindowFrame::try_from(window_frame);
+        assert_eq!(
+            result.err().unwrap().to_string(),
+            "This feature is not implemented: With WindowFrameUnits=RANGE, the bound cannot be 2 PRECEDING or FOLLOWING at the moment".to_owned()
+        );
+
+        let window_frame = ast::WindowFrame {
+            units: ast::WindowFrameUnits::Rows,
+            start_bound: ast::WindowFrameBound::Preceding(Some(2)),
+            end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))),
+        };
+        let result = WindowFrame::try_from(window_frame);
+        assert!(result.is_ok());
+        Ok(())
+    }
+
+    #[test]
+    fn test_eq() {
+        assert_eq!(
+            WindowFrameBound::Preceding(Some(0)),
+            WindowFrameBound::CurrentRow
+        );
+        assert_eq!(
+            WindowFrameBound::CurrentRow,
+            WindowFrameBound::Following(Some(0))
+        );
+        assert_eq!(
+            WindowFrameBound::Following(Some(2)),
+            WindowFrameBound::Following(Some(2))
+        );
+        assert_eq!(
+            WindowFrameBound::Following(None),
+            WindowFrameBound::Following(None)
+        );
+        assert_eq!(
+            WindowFrameBound::Preceding(Some(2)),
+            WindowFrameBound::Preceding(Some(2))
+        );
+        assert_eq!(
+            WindowFrameBound::Preceding(None),
+            WindowFrameBound::Preceding(None)
+        );
+    }
+
+    #[test]
+    fn test_ord() {
+        assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow);
+        // ! yes this is correct!
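+        // For PRECEDING bounds a larger offset points to an earlier row, so the
+        // comparisons below are intentionally inverted relative to the raw numbers.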
+ assert!( + WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) + ); + assert!( + WindowFrameBound::Preceding(Some(u64::MAX)) + < WindowFrameBound::Preceding(Some(u64::MAX - 1)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(1000000)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(u64::MAX)) + ); + assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); + assert!( + WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) + ); + assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); + assert!( + WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) + ); + assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); + assert!( + WindowFrameBound::Following(Some(u64::MAX)) + < WindowFrameBound::Following(None) + ); + } +} diff --git a/datafusion-expr/src/window_function.rs b/datafusion-expr/src/window_function.rs new file mode 100644 index 000000000000..59523d6540b2 --- /dev/null +++ b/datafusion-expr/src/window_function.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
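+
+//! Window functions and their name resolution.
+//!
+//! A brief illustrative sketch of resolving a SQL name (assuming `FromStr` is
+//! in scope; this example is not part of the original file):
+//! ```ignore
+//! use std::str::FromStr;
+//! let fun = WindowFunction::from_str("dense_rank").unwrap();
+//! assert_eq!(fun.to_string(), "DENSE_RANK");
+//! ```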
+
+use crate::aggregate_function::AggregateFunction;
+use datafusion_common::{DataFusionError, Result};
+use std::{fmt, str::FromStr};
+
+/// A window function: either a built-in window function or an aggregate function applied over a window
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum WindowFunction {
+    /// window function that leverages an aggregate function
+    AggregateFunction(AggregateFunction),
+    /// window function that leverages a built-in window function
+    BuiltInWindowFunction(BuiltInWindowFunction),
+}
+
+impl FromStr for WindowFunction {
+    type Err = DataFusionError;
+    fn from_str(name: &str) -> Result<WindowFunction> {
+        let name = name.to_lowercase();
+        if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) {
+            Ok(WindowFunction::AggregateFunction(aggregate))
+        } else if let Ok(built_in_function) =
+            BuiltInWindowFunction::from_str(name.as_str())
+        {
+            Ok(WindowFunction::BuiltInWindowFunction(built_in_function))
+        } else {
+            Err(DataFusionError::Plan(format!(
+                "There is no window function named {}",
+                name
+            )))
+        }
+    }
+}
+
+impl fmt::Display for BuiltInWindowFunction {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            BuiltInWindowFunction::RowNumber => write!(f, "ROW_NUMBER"),
+            BuiltInWindowFunction::Rank => write!(f, "RANK"),
+            BuiltInWindowFunction::DenseRank => write!(f, "DENSE_RANK"),
+            BuiltInWindowFunction::PercentRank => write!(f, "PERCENT_RANK"),
+            BuiltInWindowFunction::CumeDist => write!(f, "CUME_DIST"),
+            BuiltInWindowFunction::Ntile => write!(f, "NTILE"),
+            BuiltInWindowFunction::Lag => write!(f, "LAG"),
+            BuiltInWindowFunction::Lead => write!(f, "LEAD"),
+            BuiltInWindowFunction::FirstValue => write!(f, "FIRST_VALUE"),
+            BuiltInWindowFunction::LastValue => write!(f, "LAST_VALUE"),
+            BuiltInWindowFunction::NthValue => write!(f, "NTH_VALUE"),
+        }
+    }
+}
+
+impl fmt::Display for WindowFunction {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            WindowFunction::AggregateFunction(fun) => fun.fmt(f),
+            WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f),
+        }
+    }
+}
+
+/// A built-in window function
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum BuiltInWindowFunction {
+    /// number of the current row within its partition, counting from 1
+    RowNumber,
+    /// rank of the current row with gaps; same as row_number of its first peer
+    Rank,
+    /// rank of the current row without gaps; this function counts peer groups
+    DenseRank,
+    /// relative rank of the current row: (rank - 1) / (total rows - 1)
+    PercentRank,
+    /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows)
+    CumeDist,
+    /// integer ranging from 1 to the argument value, dividing the partition as equally as possible
+    Ntile,
+    /// returns value evaluated at the row that is offset rows before the current row within the partition;
+    /// if there is no such row, instead return default (which must be of the same type as value).
+    /// Both offset and default are evaluated with respect to the current row.
+    /// If omitted, offset defaults to 1 and default to null
+    Lag,
+    /// returns value evaluated at the row that is offset rows after the current row within the partition;
+    /// if there is no such row, instead return default (which must be of the same type as value).
+    /// Both offset and default are evaluated with respect to the current row.
+    /// If omitted, offset defaults to 1 and default to null
+    Lead,
+    /// returns value evaluated at the row that is the first row of the window frame
+    FirstValue,
+    /// returns value evaluated at the row that is the last row of the window frame
+    LastValue,
+    /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
+    NthValue,
+}
+
+impl FromStr for BuiltInWindowFunction {
+    type Err = DataFusionError;
+    fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
+        Ok(match name.to_uppercase().as_str() {
+            "ROW_NUMBER" => BuiltInWindowFunction::RowNumber,
+            "RANK" => BuiltInWindowFunction::Rank,
+            "DENSE_RANK" => BuiltInWindowFunction::DenseRank,
+            "PERCENT_RANK" => BuiltInWindowFunction::PercentRank,
+            "CUME_DIST" => BuiltInWindowFunction::CumeDist,
+            "NTILE" => BuiltInWindowFunction::Ntile,
+            "LAG" => BuiltInWindowFunction::Lag,
+            "LEAD" => BuiltInWindowFunction::Lead,
+            "FIRST_VALUE" => BuiltInWindowFunction::FirstValue,
+            "LAST_VALUE" => BuiltInWindowFunction::LastValue,
+            "NTH_VALUE" => BuiltInWindowFunction::NthValue,
+            _ => {
+                return Err(DataFusionError::Plan(format!(
+                    "There is no built-in window function named {}",
+                    name
+                )))
+            }
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_window_function_case_insensitive() -> Result<()> {
+        let names = vec![
+            "row_number",
+            "rank",
+            "dense_rank",
+            "percent_rank",
+            "cume_dist",
+            "ntile",
+            "lag",
+            "lead",
+            "first_value",
+            "last_value",
+            "nth_value",
+            "min",
+            "max",
+            "count",
+            "avg",
+            "sum",
+        ];
+        for name in names {
+            let fun = WindowFunction::from_str(name)?;
+            let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?;
+            assert_eq!(fun, fun2);
+            assert_eq!(fun.to_string(), name.to_uppercase());
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn test_window_function_from_str() -> Result<()> {
+        assert_eq!(
+            WindowFunction::from_str("max")?,
+            WindowFunction::AggregateFunction(AggregateFunction::Max)
+        );
+        assert_eq!(
+            WindowFunction::from_str("min")?,
+            WindowFunction::AggregateFunction(AggregateFunction::Min)
+        );
+        assert_eq!(
+            WindowFunction::from_str("avg")?,
+            WindowFunction::AggregateFunction(AggregateFunction::Avg)
+        );
+        assert_eq!(
+            WindowFunction::from_str("cume_dist")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist)
+        );
+        assert_eq!(
+            WindowFunction::from_str("first_value")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue)
+        );
+        assert_eq!(
+            WindowFunction::from_str("LAST_value")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue)
+        );
+        assert_eq!(
+            WindowFunction::from_str("LAG")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag)
+        );
+        assert_eq!(
+            WindowFunction::from_str("LEAD")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead)
+        );
+        Ok(())
+    }
+}
diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index c37c204005dd..67c09e5b47c4 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -45,17 +45,18 @@ simd = ["arrow/simd"]
 crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
 regex_expressions = ["regex"]
 unicode_expressions = ["unicode-segmentation"]
-# FIXME: add pyarrow support to arrow2 pyarrow = ["pyo3", "arrow/pyarrow"]
-pyarrow = ["pyo3"]
+pyarrow = ["pyo3", "datafusion-common/pyarrow"]
 # Used for testing ONLY: causes all values to hash to the same value (test for collisions)
 force_hash_collisions = []
 # Used to enable the avro format
 avro = ["arrow/io_avro", "arrow/io_avro_async",
"arrow/io_avro_compression", "num-traits", "avro-schema"] [dependencies] +datafusion-common = { path = "../datafusion-common", version = "6.0.0" } +datafusion-expr = { path = "../datafusion-expr", version = "6.0.0" } ahash = { version = "0.7", default-features = false } hashbrown = { version = "0.11", features = ["raw"] } -parquet = { package = "parquet2", version = "0.9", default_features = false, features = ["stream"] } +parquet = { package = "parquet2", version = "0.10", default_features = false, features = ["stream"] } sqlparser = "0.13" paste = "^1.0" num_cpus = "1.13.0" @@ -70,7 +71,7 @@ md-5 = { version = "^0.10.0", optional = true } sha2 = { version = "^0.10.1", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } -ordered-float = "2.0" +ordered-float = "2.10" unicode-segmentation = { version = "^1.7.1", optional = true } regex = { version = "^1.4.3", optional = true } lazy_static = { version = "^1.4.0" } @@ -79,7 +80,7 @@ rand = "0.8" num-traits = { version = "0.2", optional = true } pyo3 = { version = "0.15", optional = true } tempfile = "3" -parking_lot = "0.11" +parking_lot = "0.12" avro-schema = { version = "0.2", optional = true } # used to print arrow arrays in a nice columnar format diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs index 7fe8e7c1f340..2013a2b1ee50 100644 --- a/datafusion/benches/sort_limit_query_sql.rs +++ b/datafusion/benches/sort_limit_query_sql.rs @@ -25,6 +25,9 @@ use datafusion::datasource::object_store::local::LocalFileSystem; use parking_lot::Mutex; use std::sync::Arc; +extern crate arrow; +extern crate datafusion; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::MemTable; diff --git a/datafusion/fuzz-utils/Cargo.toml b/datafusion/fuzz-utils/Cargo.toml index cb1e2e942a9e..b0646450cad8 100644 --- a/datafusion/fuzz-utils/Cargo.toml +++ b/datafusion/fuzz-utils/Cargo.toml @@ -23,7 +23,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +datafusion-common = { path = "../../datafusion-common" } arrow = { package = "arrow2", version="0.9", features = ["io_print"] } -datafusion = { path = ".." 
} rand = "0.8" env_logger = "0.9.0" diff --git a/datafusion/fuzz-utils/src/lib.rs b/datafusion/fuzz-utils/src/lib.rs index 81da4801f423..03b6678c917f 100644 --- a/datafusion/fuzz-utils/src/lib.rs +++ b/datafusion/fuzz-utils/src/lib.rs @@ -20,14 +20,14 @@ use arrow::array::Int32Array; use rand::prelude::StdRng; use rand::Rng; -use datafusion::record_batch::RecordBatch; +use datafusion_common::record_batch::RecordBatch; pub use env_logger; /// Extracts the i32 values from the set of batches and returns them as a single Vec pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { batches .iter() - .map(|batch| { + .flat_map(|batch| { assert_eq!(batch.num_columns(), 1); batch .column(0) @@ -37,7 +37,6 @@ pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { .iter() .map(|v| v.copied()) }) - .flatten() .collect() } @@ -45,8 +44,7 @@ pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { pub fn partitions_to_sorted_vec(partitions: &[Vec]) -> Vec> { let mut values: Vec<_> = partitions .iter() - .map(|batches| batches_to_vec(batches).into_iter()) - .flatten() + .flat_map(|batches| batches_to_vec(batches).into_iter()) .collect(); values.sort_unstable(); @@ -62,7 +60,7 @@ pub fn add_empty_batches( batches .into_iter() - .map(|batch| { + .flat_map(|batch| { // insert 0, or 1 empty batches before and after the current batch let empty_batch = RecordBatch::new_empty(schema.clone()); std::iter::repeat(empty_batch.clone()) @@ -70,6 +68,5 @@ pub fn add_empty_batches( .chain(std::iter::once(batch)) .chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2))) }) - .flatten() .collect() } diff --git a/datafusion/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/src/avro_to_arrow/arrow_array_reader.rs index 0fd50e9b2c1f..8667c77fc9a8 100644 --- a/datafusion/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/src/avro_to_arrow/arrow_array_reader.rs @@ -38,6 +38,7 @@ impl<'a, R: Read> AvroBatchReader { avro_schemas: Vec, codec: Option, file_marker: [u8; 16], + projection: Option>, ) -> Result { let reader = AvroReader::new( read::Decompressor::new( @@ -46,6 +47,7 @@ impl<'a, R: Read> AvroBatchReader { ), avro_schemas, schema.fields.clone(), + projection, ); Ok(Self { reader, schema }) } diff --git a/datafusion/src/avro_to_arrow/reader.rs b/datafusion/src/avro_to_arrow/reader.rs index 7cb640e60560..a7a8e9549dfb 100644 --- a/datafusion/src/avro_to_arrow/reader.rs +++ b/datafusion/src/avro_to_arrow/reader.rs @@ -108,22 +108,16 @@ impl ReaderBuilder { // check if schema should be inferred source.seek(SeekFrom::Start(0))?; - let (mut avro_schemas, mut schema, codec, file_marker) = + let (avro_schemas, schema, codec, file_marker) = read::read_metadata(&mut source)?; - if let Some(proj) = self.projection { - let mut indices: Vec = schema + + let projection = self.projection.map(|proj| { + schema .fields .iter() - .filter(|f| !proj.contains(&f.name)) - .enumerate() - .map(|(i, _)| i) - .collect(); - indices.sort_by(|i1, i2| i2.cmp(i1)); - for i in indices { - avro_schemas.remove(i); - schema.fields.remove(i); - } - } + .map(|f| proj.contains(&f.name)) + .collect::>() + }); Reader::try_new( source, @@ -132,6 +126,7 @@ impl ReaderBuilder { avro_schemas, codec, file_marker, + projection, ) } } @@ -155,6 +150,7 @@ impl<'a, R: Read> Reader { avro_schemas: Vec, codec: Option, file_marker: [u8; 16], + projection: Option>, ) -> Result { Ok(Self { array_reader: AvroBatchReader::try_new( @@ -163,6 +159,7 @@ impl<'a, R: Read> Reader { avro_schemas, codec, file_marker, + projection, )?, schema, batch_size, diff 
--git a/datafusion/src/avro_to_arrow/schema.rs b/datafusion/src/avro_to_arrow/schema.rs deleted file mode 100644 index 8b137891791f..000000000000 --- a/datafusion/src/avro_to_arrow/schema.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs index 9c4a4e4aeb4d..dd1ebeb9cd11 100644 --- a/datafusion/src/dataframe.rs +++ b/datafusion/src/dataframe.rs @@ -19,13 +19,14 @@ use crate::error::Result; use crate::logical_plan::{ - DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning, + DFSchema, FunctionRegistry, JoinType, LogicalPlan, Partitioning, }; use crate::record_batch::RecordBatch; use std::sync::Arc; use crate::physical_plan::SendableRecordBatchStream; use async_trait::async_trait; +use datafusion_expr::Expr; /// DataFrame represents a logical set of rows with the same named columns. /// Similar to a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) or diff --git a/datafusion/src/datasource/file_format/parquet.rs b/datafusion/src/datasource/file_format/parquet.rs index c32f7b2aa9ba..7d5eb1126639 100644 --- a/datafusion/src/datasource/file_format/parquet.rs +++ b/datafusion/src/datasource/file_format/parquet.rs @@ -25,7 +25,7 @@ use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use arrow::io::parquet::read::{get_schema, read_metadata}; +use arrow::io::parquet::read::{infer_schema, read_metadata}; use futures::TryStreamExt; use parquet::statistics::{ BinaryStatistics as ParquetBinaryStatistics, @@ -265,7 +265,7 @@ fn summarize_min_max( pub fn fetch_schema(object_reader: Arc) -> Result { let mut reader = object_reader.sync_reader()?; let meta_data = read_metadata(&mut reader)?; - let schema = get_schema(&meta_data)?; + let schema = infer_schema(&meta_data)?; Ok(schema) } @@ -273,7 +273,7 @@ pub fn fetch_schema(object_reader: Arc) -> Result { fn fetch_statistics(object_reader: Arc) -> Result { let mut reader = object_reader.sync_reader()?; let meta_data = read_metadata(&mut reader)?; - let schema = get_schema(&meta_data)?; + let schema = infer_schema(&meta_data)?; let num_fields = schema.fields().len(); let fields = schema.fields().to_vec(); diff --git a/datafusion/src/datasource/listing/helpers.rs b/datafusion/src/datasource/listing/helpers.rs index 0d52966f1065..24d3b3579135 100644 --- a/datafusion/src/datasource/listing/helpers.rs +++ b/datafusion/src/datasource/listing/helpers.rs @@ -34,7 +34,7 @@ use log::debug; use crate::{ error::Result, execution::context::ExecutionContext, - logical_plan::{self, Expr, ExpressionVisitor, Recursion}, + logical_plan::{self, Expr, ExprVisitable, ExpressionVisitor, Recursion}, physical_plan::functions::Volatility, scalar::ScalarValue, }; diff --git a/datafusion/src/datasource/memory.rs b/datafusion/src/datasource/memory.rs index 4b1e09e68e71..ddd81ffad97e 100644 --- a/datafusion/src/datasource/memory.rs +++ b/datafusion/src/datasource/memory.rs @@ -166,7 +166,6 @@ mod tests { use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; - use futures::StreamExt; use std::collections::BTreeMap; #[tokio::test] diff --git a/datafusion/src/error.rs b/datafusion/src/error.rs index fbad9a97d37b..c2c80b48781e 100644 --- a/datafusion/src/error.rs +++ b/datafusion/src/error.rs @@ -16,173 +16,4 @@ // under the License. //! 
DataFusion error types - -use std::error; -use std::fmt::{Display, Formatter}; -use std::io; -use std::result; - -use arrow::error::ArrowError; -use parquet::error::ParquetError; -use sqlparser::parser::ParserError; - -/// Result type for operations that could result in an [DataFusionError] -pub type Result = result::Result; - -/// Error type for generic operations that could result in DataFusionError::External -pub type GenericError = Box; - -/// DataFusion error -#[derive(Debug)] -#[allow(missing_docs)] -pub enum DataFusionError { - /// Error returned by arrow. - ArrowError(ArrowError), - /// Wraps an error from the Parquet crate - ParquetError(ParquetError), - /// Error associated to I/O operations and associated traits. - IoError(io::Error), - /// Error returned when SQL is syntactically incorrect. - SQL(ParserError), - /// Error returned on a branch that we know it is possible - /// but to which we still have no implementation for. - /// Often, these errors are tracked in our issue tracker. - NotImplemented(String), - /// Error returned as a consequence of an error in DataFusion. - /// This error should not happen in normal usage of DataFusion. - // DataFusions has internal invariants that we are unable to ask the compiler to check for us. - // This error is raised when one of those invariants is not verified during execution. - Internal(String), - /// This error happens whenever a plan is not valid. Examples include - /// impossible casts, schema inference not possible and non-unique column names. - Plan(String), - /// Error returned during execution of the query. - /// Examples include files not found, errors in parsing certain types. - Execution(String), - /// This error is thrown when a consumer cannot acquire memory from the Memory Manager - /// we can just cancel the execution of the partition. - ResourcesExhausted(String), - /// Errors originating from outside DataFusion's core codebase. - /// For example, a custom S3Error from the crate datafusion-objectstore-s3 - External(GenericError), -} - -impl From for DataFusionError { - fn from(e: io::Error) -> Self { - DataFusionError::IoError(e) - } -} - -impl From for DataFusionError { - fn from(e: ArrowError) -> Self { - DataFusionError::ArrowError(e) - } -} - -impl From for ArrowError { - fn from(e: DataFusionError) -> Self { - match e { - DataFusionError::ArrowError(e) => e, - DataFusionError::External(e) => ArrowError::External("".to_string(), e), - other => ArrowError::External("".to_string(), Box::new(other)), - } - } -} - -impl From for DataFusionError { - fn from(e: ParquetError) -> Self { - DataFusionError::ParquetError(e) - } -} - -impl From for DataFusionError { - fn from(e: ParserError) -> Self { - DataFusionError::SQL(e) - } -} - -impl From for DataFusionError { - fn from(err: GenericError) -> Self { - DataFusionError::External(err) - } -} - -impl Display for DataFusionError { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - match *self { - DataFusionError::ArrowError(ref desc) => write!(f, "Arrow error: {}", desc), - DataFusionError::ParquetError(ref desc) => { - write!(f, "Parquet error: {}", desc) - } - DataFusionError::IoError(ref desc) => write!(f, "IO error: {}", desc), - DataFusionError::SQL(ref desc) => { - write!(f, "SQL error: {:?}", desc) - } - DataFusionError::NotImplemented(ref desc) => { - write!(f, "This feature is not implemented: {}", desc) - } - DataFusionError::Internal(ref desc) => { - write!(f, "Internal error: {}. 
This was likely caused by a bug in DataFusion's \ - code and we would welcome that you file an bug report in our issue tracker", desc) - } - DataFusionError::Plan(ref desc) => { - write!(f, "Error during planning: {}", desc) - } - DataFusionError::Execution(ref desc) => { - write!(f, "Execution error: {}", desc) - } - DataFusionError::ResourcesExhausted(ref desc) => { - write!(f, "Resources exhausted: {}", desc) - } - DataFusionError::External(ref desc) => { - write!(f, "External error: {}", desc) - } - } - } -} - -impl error::Error for DataFusionError {} - -#[cfg(test)] -mod test { - use crate::error::DataFusionError; - use arrow::error::ArrowError; - - #[test] - fn arrow_error_to_datafusion() { - let res = return_arrow_error().unwrap_err(); - assert_eq!( - res.to_string(), - "External error: Error during planning: foo" - ); - } - - #[test] - fn datafusion_error_to_arrow() { - let res = return_datafusion_error().unwrap_err(); - assert_eq!( - res.to_string(), - "Arrow error: Invalid argument error: Schema error: bar" - ); - } - - /// Model what happens when implementing SendableRecrordBatchStream: - /// DataFusion code needs to return an ArrowError - #[allow(clippy::try_err)] - fn return_arrow_error() -> arrow::error::Result<()> { - // Expect the '?' to work - let _foo = Err(DataFusionError::Plan("foo".to_string()))?; - Ok(()) - } - - /// Model what happens when using arrow kernels in DataFusion - /// code: need to turn an ArrowError into a DataFusionError - #[allow(clippy::try_err)] - fn return_datafusion_error() -> crate::error::Result<()> { - // Expect the '?' to work - let _bar = Err(ArrowError::InvalidArgumentError( - "Schema error: bar".to_string(), - ))?; - Ok(()) - } -} +pub use datafusion_common::{DataFusionError, Result}; diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 9aa2b476bc9f..09346a322abf 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -38,6 +38,8 @@ use crate::{ hash_build_probe_order::HashBuildProbeOrder, optimizer::PhysicalOptimizerRule, }, }; +use arrow::array::ArrayRef; +use arrow::chunk::Chunk; use log::debug; use parking_lot::Mutex; use std::collections::{HashMap, HashSet}; @@ -50,12 +52,13 @@ use futures::{StreamExt, TryStreamExt}; use tokio::task::{self, JoinHandle}; use crate::record_batch::RecordBatch; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{PhysicalType, SchemaRef}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::io::csv; use arrow::io::parquet; -use arrow::io::parquet::write::FallibleStreamingIterator; -use arrow::io::parquet::write::WriteOptions; +use arrow::io::parquet::write::{ + row_group_iter, to_parquet_schema, Encoding, WriteOptions, +}; use crate::catalog::{ catalog::{CatalogProvider, MemoryCatalogProvider}, @@ -791,56 +794,54 @@ impl ExecutionContext { let stream = plan.execute(i, runtime.clone()).await?; let handle: JoinHandle> = task::spawn(async move { - let parquet_schema = parquet::write::to_parquet_schema(&schema)?; + let parquet_schema = to_parquet_schema(&schema)?; let a = parquet_schema.clone(); - let row_groups = stream.map(|batch: ArrowResult| { - // map each record batch to a row group - let r = batch.map(|batch| { - let batch_cols = batch.columns().to_vec(); - // column chunk in row group - let pages = - batch_cols - .into_iter() - .zip(a.columns().iter().cloned()) - .map(move |(array, descriptor)| { - parquet::write::array_to_pages( - array.as_ref(), - descriptor, - options, - parquet::write::Encoding::Plain, - ) 
- .map(move |pages| { - let encoded_pages = - parquet::write::DynIter::new( - pages.map(|x| Ok(x?)), - ); - let compressed_pages = - parquet::write::Compressor::new( - encoded_pages, - options.compression, - vec![], - ) - .map_err(ArrowError::from); - parquet::write::DynStreamingIterator::new( - compressed_pages, - ) - }) - }); - parquet::write::DynIter::new(pages) + let encodings: Vec = schema + .fields() + .iter() + .map(|field| match field.data_type().to_physical_type() { + PhysicalType::Binary + | PhysicalType::LargeBinary + | PhysicalType::Utf8 + | PhysicalType::LargeUtf8 => { + Encoding::DeltaLengthByteArray + } + _ => Encoding::Plain, + }) + .collect(); + + let mut row_groups = + stream.map(|batch: ArrowResult| { + // map each record batch to a row group + batch.map(|batch| { + // column chunk in row group + let chunk: Chunk = batch.into(); + let len = chunk.len(); + ( + row_group_iter( + chunk, + encodings.clone(), + a.columns().to_vec(), + options, + ), + len, + ) + }) }); - async { r } - }); - Ok(parquet::write::stream::write_stream( + let mut writer = parquet::write::FileWriter::try_new( &mut file, - row_groups, schema.as_ref().clone(), - parquet_schema, options, - None, - ) - .await?) + )?; + writer.start()?; + while let Some(row_group) = row_groups.next().await { + let (group, len) = row_group?; + writer.write(group, len)?; + } + let (written, _) = writer.end(None)?; + Ok(written) }); tasks.push(handle); } @@ -1204,7 +1205,7 @@ impl ExecutionProps { var_type: VarType, provider: Arc, ) -> Option> { - let mut var_providers = self.var_providers.take().unwrap_or_default(); + let mut var_providers = self.var_providers.take().unwrap_or_else(HashMap::new); let old_provider = var_providers.insert(var_type, provider); @@ -1345,11 +1346,9 @@ mod tests { use super::*; use crate::execution::context::QueryPlanner; use crate::field_util::{FieldExt, SchemaExt}; - use crate::logical_plan::plan::Projection; - use crate::logical_plan::TableScan; use crate::logical_plan::{binary_expr, lit, Operator}; + use crate::physical_plan::collect; use crate::physical_plan::functions::{make_scalar_function, Volatility}; - use crate::physical_plan::{collect, collect_partitioned}; use crate::record_batch::RecordBatch; use crate::test; use crate::variable::VarType; @@ -1367,8 +1366,7 @@ mod tests { use arrow::compute::arithmetics::basic::add; use arrow::datatypes::*; use arrow::io::parquet::write::{ - to_parquet_schema, write_file, Compression, Encoding, RowGroupIterator, Version, - WriteOptions, + Compression, Encoding, FileWriter, RowGroupIterator, Version, WriteOptions, }; use async_trait::async_trait; use std::collections::BTreeMap; @@ -1377,7 +1375,6 @@ mod tests { use std::thread::{self, JoinHandle}; use std::{io::prelude::*, sync::Mutex}; use tempfile::TempDir; - use test::*; #[tokio::test] async fn shared_memory_and_disk_manager() { @@ -1413,100 +1410,6 @@ mod tests { )); } - #[test] - fn optimize_explain() { - let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); - - let plan = LogicalPlanBuilder::scan_empty(Some("employee"), &schema, None) - .unwrap() - .explain(true, false) - .unwrap() - .build() - .unwrap(); - - if let LogicalPlan::Explain(e) = &plan { - assert_eq!(e.stringified_plans.len(), 1); - } else { - panic!("plan was not an explain: {:?}", plan); - } - - // now optimize the plan and expect to see more plans - let optimized_plan = ExecutionContext::new().optimize(&plan).unwrap(); - if let LogicalPlan::Explain(e) = &optimized_plan { - // should have more than one plan - 
assert!( - e.stringified_plans.len() > 1, - "plans: {:#?}", - e.stringified_plans - ); - // should have at least one optimized plan - let opt = e - .stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::OptimizedLogicalPlan { .. })); - - assert!(opt, "plans: {:#?}", e.stringified_plans); - } else { - panic!("plan was not an explain: {:?}", plan); - } - } - - #[tokio::test] - async fn parallel_projection() -> Result<()> { - let partition_count = 4; - let results = execute("SELECT c1, c2 FROM test", partition_count).await?; - - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 3 | 1 |", - "| 3 | 2 |", - "| 3 | 3 |", - "| 3 | 4 |", - "| 3 | 5 |", - "| 3 | 6 |", - "| 3 | 7 |", - "| 3 | 8 |", - "| 3 | 9 |", - "| 3 | 10 |", - "| 2 | 1 |", - "| 2 | 2 |", - "| 2 | 3 |", - "| 2 | 4 |", - "| 2 | 5 |", - "| 2 | 6 |", - "| 2 | 7 |", - "| 2 | 8 |", - "| 2 | 9 |", - "| 2 | 10 |", - "| 1 | 1 |", - "| 1 | 2 |", - "| 1 | 3 |", - "| 1 | 4 |", - "| 1 | 5 |", - "| 1 | 6 |", - "| 1 | 7 |", - "| 1 | 8 |", - "| 1 | 9 |", - "| 1 | 10 |", - "| 0 | 1 |", - "| 0 | 2 |", - "| 0 | 3 |", - "| 0 | 4 |", - "| 0 | 5 |", - "| 0 | 6 |", - "| 0 | 7 |", - "| 0 | 8 |", - "| 0 | 9 |", - "| 0 | 10 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - #[tokio::test] async fn create_variable_expr() -> Result<()> { let tmp_dir = TempDir::new()?; @@ -1551,184 +1454,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn parallel_query_with_filter() -> Result<()> { - let tmp_dir = TempDir::new()?; - let partition_count = 4; - let ctx = create_ctx(&tmp_dir, partition_count).await?; - - let logical_plan = - ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?; - let logical_plan = ctx.optimize(&logical_plan)?; - - let physical_plan = ctx.create_physical_plan(&logical_plan).await?; - - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect_partitioned(physical_plan, runtime).await?; - - // note that the order of partitions is not deterministic - let mut num_rows = 0; - for partition in &results { - for batch in partition { - num_rows += batch.num_rows(); - } - } - assert_eq!(20, num_rows); - - let results: Vec = results.into_iter().flatten().collect(); - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 1 | 1 |", - "| 1 | 10 |", - "| 1 | 2 |", - "| 1 | 3 |", - "| 1 | 4 |", - "| 1 | 5 |", - "| 1 | 6 |", - "| 1 | 7 |", - "| 1 | 8 |", - "| 1 | 9 |", - "| 2 | 1 |", - "| 2 | 10 |", - "| 2 | 2 |", - "| 2 | 3 |", - "| 2 | 4 |", - "| 2 | 5 |", - "| 2 | 6 |", - "| 2 | 7 |", - "| 2 | 8 |", - "| 2 | 9 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - - #[tokio::test] - async fn projection_on_table_scan() -> Result<()> { - let tmp_dir = TempDir::new()?; - let partition_count = 4; - let ctx = create_ctx(&tmp_dir, partition_count).await?; - let runtime = ctx.state.lock().runtime_env.clone(); - - let table = ctx.table("test")?; - let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) - .project(vec![col("c2")])? - .build()?; - - let optimized_plan = ctx.optimize(&logical_plan)?; - match &optimized_plan { - LogicalPlan::Projection(Projection { input, .. }) => match &**input { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. 
- }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be TableScan"), - }, - _ => panic!("expect optimized_plan to be projection"), - } - - let expected = "Projection: #test.c2\ - \n TableScan: test projection=Some([1])"; - assert_eq!(format!("{:?}", optimized_plan), expected); - - let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("c2", physical_plan.schema().field(0).name()); - - let batches = collect(physical_plan, runtime).await?; - assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); - - Ok(()) - } - - #[tokio::test] - async fn preserve_nullability_on_projection() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx(&tmp_dir, 1).await?; - - let schema: Schema = ctx.table("test").unwrap().schema().clone().into(); - assert!(!schema.field_with_name("c1")?.is_nullable()); - - let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)? - .project(vec![col("c1")])? - .build()?; - - let plan = ctx.optimize(&plan)?; - let physical_plan = ctx.create_physical_plan(&Arc::new(plan)).await?; - assert!(!physical_plan.schema().field_with_name("c1")?.is_nullable()); - Ok(()) - } - - #[tokio::test] - async fn projection_on_memory_scan() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ]); - let schema = SchemaRef::new(schema); - - let partitions = vec![vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])), - Arc::new(Int32Array::from_slice(&[2, 12, 12, 120])), - Arc::new(Int32Array::from_slice(&[3, 12, 12, 120])), - ], - )?]]; - - let plan = LogicalPlanBuilder::scan_memory(partitions, schema, None)? - .project(vec![col("b")])? - .build()?; - assert_fields_eq(&plan, vec!["b"]); - - let ctx = ExecutionContext::new(); - let optimized_plan = ctx.optimize(&plan)?; - match &optimized_plan { - LogicalPlan::Projection(Projection { input, .. }) => match &**input { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. 
- }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be InMemoryScan"), - }, - _ => panic!("expect optimized_plan to be projection"), - } - - let expected = format!( - "Projection: #{}.b\ - \n TableScan: {} projection=Some([1])", - UNNAMED_TABLE, UNNAMED_TABLE - ); - assert_eq!(format!("{:?}", optimized_plan), expected); - - let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("b", physical_plan.schema().field(0).name()); - - let runtime = ctx.state.lock().runtime_env.clone(); - let batches = collect(physical_plan, runtime).await?; - assert_eq!(1, batches.len()); - assert_eq!(1, batches[0].num_columns()); - assert_eq!(4, batches[0].num_rows()); - - Ok(()) - } - #[tokio::test] async fn sort() -> Result<()> { let results = @@ -3684,7 +3409,6 @@ mod tests { let ids = Arc::new(Int32Array::from_slice(&[i as i32])); let names = Arc::new(Utf8Array::::from_slice(&["test"])); let schema_ref = schema.as_ref(); - let parquet_schema = to_parquet_schema(schema_ref).unwrap(); let iter = vec![Ok(Chunk::new(vec![ids as ArrayRef, names as ArrayRef]))]; let row_groups = RowGroupIterator::try_new( iter.into_iter(), @@ -3693,16 +3417,14 @@ mod tests { vec![Encoding::Plain, Encoding::Plain], ) .unwrap(); - - let _ = write_file( - &mut file, - row_groups, - schema_ref, - parquet_schema, - options, - None, - ) - .unwrap(); + let mut writer = + FileWriter::try_new(&mut file, schema_ref.clone(), options).unwrap(); + writer.start().unwrap(); + for rg in row_groups { + let (group, len) = rg.unwrap(); + writer.write(group, len).unwrap(); + } + writer.end(None).unwrap(); } } @@ -3718,6 +3440,173 @@ mod tests { assert_eq!(result[0].schema().metadata(), result[1].schema().metadata()); } + #[tokio::test] + async fn normalized_column_identifiers() { + // create local execution context + let mut ctx = ExecutionContext::new(); + + // register csv file with the execution context + ctx.register_csv( + "case_insensitive_test", + "tests/example.csv", + CsvReadOptions::new(), + ) + .await + .unwrap(); + + let sql = "SELECT A, b FROM case_insensitive_test"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = "SELECT t.A, b FROM case_insensitive_test AS t"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Aliases + + let sql = "SELECT t.A as x, b FROM case_insensitive_test AS t"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| x | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = "SELECT t.A AS X, b FROM case_insensitive_test AS t"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| x | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t"#; + let result = plan_and_collect(&mut ctx, sql) + .await + 
.expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| X | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Order by + + let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY x"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| x | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY X"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| x | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t ORDER BY "X""#; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| X | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Where + + let sql = "SELECT a, b FROM case_insensitive_test where A IS NOT null"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Group by + + let sql = "SELECT a as x, count(*) as c FROM case_insensitive_test GROUP BY X"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| x | c |", + "+---+---+", + "| 1 | 1 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = + r#"SELECT a as "X", count(*) as c FROM case_insensitive_test GROUP BY "X""#; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| X | c |", + "+---+---+", + "| 1 | 1 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + } + struct MyPhysicalPlanner {} #[async_trait] diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 1ad95950cdd7..73252f3aafa0 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -324,12 +324,12 @@ mod tests { use super::*; use crate::execution::options::CsvReadOptions; - use crate::physical_plan::functions::ScalarFunctionImplementation; - use crate::physical_plan::functions::Volatility; use crate::physical_plan::{window_functions, ColumnarValue}; use crate::{assert_batches_sorted_eq, execution::context::ExecutionContext}; use crate::{logical_plan::*, test_util}; use arrow::datatypes::DataType; + use datafusion_expr::ScalarFunctionImplementation; + use datafusion_expr::Volatility; #[tokio::test] async fn select_columns() -> Result<()> { diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs index 2dfccb73092d..4ad799070da1 100644 --- a/datafusion/src/field_util.rs +++ b/datafusion/src/field_util.rs @@ -15,476 +15,6 @@ // specific language governing permissions and limitations // under the License. -//! Utility functions for complex field access +//! 
Field utils re-exported from datafusion-common

-use arrow::array::{ArrayRef, StructArray};
-use arrow::datatypes::{DataType, Field, Metadata, Schema};
-use arrow::error::ArrowError;
-use std::borrow::Borrow;
-use std::collections::BTreeMap;
-
-use crate::error::{DataFusionError, Result};
-use crate::scalar::ScalarValue;
-
-/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`]
-/// # Error
-/// Errors if
-/// * the `data_type` is not a Struct or a List, or
-/// * there is no field with the given key, or the key is not of the required index type
-pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result<Field> {
-    match (data_type, key) {
-        (DataType::List(lt), ScalarValue::Int64(Some(i))) => {
-            if *i < 0 {
-                Err(DataFusionError::Plan(format!(
-                    "List based indexed access requires a non-negative int, was {0}",
-                    i
-                )))
-            } else {
-                Ok(Field::new(&i.to_string(), lt.data_type().clone(), false))
-            }
-        }
-        (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
-            if s.is_empty() {
-                Err(DataFusionError::Plan(
-                    "Struct based indexed access requires a non-empty string".to_string(),
-                ))
-            } else {
-                let field = fields.iter().find(|f| f.name() == s);
-                match field {
-                    None => Err(DataFusionError::Plan(format!(
-                        "Field {} not found in struct",
-                        s
-                    ))),
-                    Some(f) => Ok(f.clone()),
-                }
-            }
-        }
-        (DataType::Struct(_), _) => Err(DataFusionError::Plan(
-            "Only utf8 strings are valid as an indexed field in a struct".to_string(),
-        )),
-        (DataType::List(_), _) => Err(DataFusionError::Plan(
-            "Only ints are valid as an indexed field in a list".to_string(),
-        )),
-        _ => Err(DataFusionError::Plan(
-            "The expression to get an indexed field is only valid for `List` and `Struct` types"
-                .to_string(),
-        )),
-    }
-}
-
-/// Imitate arrow-rs StructArray behavior by extending arrow2 StructArray
-pub trait StructArrayExt {
-    /// Return field names in this struct array
-    fn column_names(&self) -> Vec<&str>;
-    /// Return the child array whose field name matches `column_name`
-    fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef>;
-    /// Return the number of fields in this struct array
-    fn num_columns(&self) -> usize;
-    /// Return the column at the given position
-    fn column(&self, pos: usize) -> ArrayRef;
-}
-
-impl StructArrayExt for StructArray {
-    fn column_names(&self) -> Vec<&str> {
-        self.fields().iter().map(|f| f.name.as_str()).collect()
-    }
-
-    fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef> {
-        self.fields()
-            .iter()
-            .position(|c| c.name() == column_name)
-            .map(|pos| self.values()[pos].borrow())
-    }
-
-    fn num_columns(&self) -> usize {
-        self.fields().len()
-    }
-
-    fn column(&self, pos: usize) -> ArrayRef {
-        self.values()[pos].clone()
-    }
-}
-
-/// Converts a list of field/array pairs to a struct array
-pub fn struct_array_from(pairs: Vec<(Field, ArrayRef)>) -> StructArray {
-    let fields: Vec<Field> = pairs.iter().map(|v| v.0.clone()).collect();
-    let values = pairs.iter().map(|v| v.1.clone()).collect();
-    StructArray::from_data(DataType::Struct(fields), values, None)
-}
-
-/// Imitate arrow-rs Schema behavior by extending arrow2 Schema
-pub trait SchemaExt {
-    /// Creates a new [`Schema`] from a sequence of [`Field`] values.
- /// - /// # Example - /// - /// ``` - /// use arrow::datatypes::{Field, DataType, Schema}; - /// use datafusion::field_util::SchemaExt; - /// let field_a = Field::new("a", DataType::Int64, false); - /// let field_b = Field::new("b", DataType::Boolean, false); - /// - /// let schema = Schema::new(vec![field_a, field_b]); - /// ``` - fn new(fields: Vec) -> Self; - - /// Creates a new [`Schema`] from a sequence of [`Field`] values and [`arrow::datatypes::Metadata`] - /// - /// # Example - /// - /// ``` - /// use std::collections::BTreeMap; - /// use arrow::datatypes::{Field, DataType, Schema}; - /// use datafusion::field_util::SchemaExt; - /// - /// let field_a = Field::new("a", DataType::Int64, false); - /// let field_b = Field::new("b", DataType::Boolean, false); - /// - /// let schema_metadata: BTreeMap = - /// vec![("baz".to_string(), "barf".to_string())] - /// .into_iter() - /// .collect(); - /// let schema = Schema::new_with_metadata(vec![field_a, field_b], schema_metadata); - /// ``` - fn new_with_metadata(fields: Vec, metadata: Metadata) -> Self; - - /// Creates an empty [`Schema`]. - fn empty() -> Self; - - /// Look up a column by name and return a immutable reference to the column along with - /// its index. - fn column_with_name(&self, name: &str) -> Option<(usize, &Field)>; - - /// Returns the first [`Field`] named `name`. - fn field_with_name(&self, name: &str) -> Result<&Field>; - - /// Find the index of the column with the given name. - fn index_of(&self, name: &str) -> Result; - - /// Returns the [`Field`] at position `i`. - /// # Panics - /// Panics iff `i` is larger than the number of fields in this [`Schema`]. - fn field(&self, index: usize) -> &Field; - - /// Returns all [`Field`]s in this schema. - fn fields(&self) -> &[Field]; - - /// Returns an immutable reference to the Map of custom metadata key-value pairs. - fn metadata(&self) -> &BTreeMap; - - /// Merge schema into self if it is compatible. Struct fields will be merged recursively. 
- /// - /// Example: - /// - /// ``` - /// use arrow::datatypes::*; - /// use datafusion::field_util::SchemaExt; - /// - /// let merged = Schema::try_merge(vec![ - /// Schema::new(vec![ - /// Field::new("c1", DataType::Int64, false), - /// Field::new("c2", DataType::Utf8, false), - /// ]), - /// Schema::new(vec![ - /// Field::new("c1", DataType::Int64, true), - /// Field::new("c2", DataType::Utf8, false), - /// Field::new("c3", DataType::Utf8, false), - /// ]), - /// ]).unwrap(); - /// - /// assert_eq!( - /// merged, - /// Schema::new(vec![ - /// Field::new("c1", DataType::Int64, true), - /// Field::new("c2", DataType::Utf8, false), - /// Field::new("c3", DataType::Utf8, false), - /// ]), - /// ); - /// ``` - fn try_merge(schemas: impl IntoIterator) -> Result - where - Self: Sized; - - /// Return the field names - fn field_names(&self) -> Vec; - - /// Returns a new schema with only the specified columns in the new schema - /// This carries metadata from the parent schema over as well - fn project(&self, indices: &[usize]) -> Result; -} - -impl SchemaExt for Schema { - fn new(fields: Vec) -> Self { - Self::from(fields) - } - - fn new_with_metadata(fields: Vec, metadata: Metadata) -> Self { - Self::new(fields).with_metadata(metadata) - } - - fn empty() -> Self { - Self::from(vec![]) - } - - fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> { - self.fields.iter().enumerate().find(|(_, f)| f.name == name) - } - - fn field_with_name(&self, name: &str) -> Result<&Field> { - Ok(&self.fields[self.index_of(name)?]) - } - - fn index_of(&self, name: &str) -> Result { - self.column_with_name(name).map(|(i, _f)| i).ok_or_else(|| { - DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!( - "Unable to get field named \"{}\". Valid fields: {:?}", - name, - self.field_names() - ))) - }) - } - - fn field(&self, index: usize) -> &Field { - &self.fields[index] - } - - #[inline] - fn fields(&self) -> &[Field] { - &self.fields - } - - #[inline] - fn metadata(&self) -> &BTreeMap { - &self.metadata - } - - fn try_merge(schemas: impl IntoIterator) -> Result { - schemas - .into_iter() - .try_fold(Self::empty(), |mut merged, schema| { - let Schema { metadata, fields } = schema; - for (key, value) in metadata.into_iter() { - // merge metadata - if let Some(old_val) = merged.metadata.get(&key) { - if old_val != &value { - return Err(DataFusionError::ArrowError( - ArrowError::InvalidArgumentError( - "Fail to merge schema due to conflicting metadata." - .to_string(), - ), - )); - } - } - merged.metadata.insert(key, value); - } - // merge fields - for field in fields.into_iter() { - let mut new_field = true; - for merged_field in &mut merged.fields { - if field.name() != merged_field.name() { - continue; - } - new_field = false; - merged_field.try_merge(&field)? 
- } - // found a new field, add to field list - if new_field { - merged.fields.push(field); - } - } - Ok(merged) - }) - } - - fn field_names(&self) -> Vec { - self.fields.iter().map(|f| f.name.to_string()).collect() - } - - fn project(&self, indices: &[usize]) -> Result { - let new_fields = indices - .iter() - .map(|i| { - self.fields.get(*i).cloned().ok_or_else(|| { - DataFusionError::ArrowError(ArrowError::InvalidArgumentError( - format!( - "project index {} out of bounds, max field {}", - i, - self.fields().len() - ), - )) - }) - }) - .collect::>>()?; - Ok(Self::new_with_metadata(new_fields, self.metadata.clone())) - } -} - -/// Imitate arrow-rs Field behavior by extending arrow2 Field -pub trait FieldExt { - /// The field name - fn name(&self) -> &str; - - /// Whether the field is nullable - fn is_nullable(&self) -> bool; - - /// Returns the field metadata - fn metadata(&self) -> &BTreeMap; - - /// Merge field into self if it is compatible. Struct will be merged recursively. - /// NOTE: `self` may be updated to unexpected state in case of merge failure. - /// - /// Example: - /// - /// ``` - /// use arrow2::datatypes::*; - /// - /// let mut field = Field::new("c1", DataType::Int64, false); - /// assert!(field.try_merge(&Field::new("c1", DataType::Int64, true)).is_ok()); - /// assert!(field.is_nullable()); - /// ``` - fn try_merge(&mut self, from: &Field) -> Result<()>; - - /// Sets the `Field`'s optional custom metadata. - /// The metadata is set as `None` for empty map. - fn set_metadata(&mut self, metadata: Option>); -} - -impl FieldExt for Field { - #[inline] - fn name(&self) -> &str { - &self.name - } - - #[inline] - fn is_nullable(&self) -> bool { - self.is_nullable - } - - #[inline] - fn metadata(&self) -> &BTreeMap { - &self.metadata - } - - fn try_merge(&mut self, from: &Field) -> Result<()> { - // merge metadata - for (key, from_value) in from.metadata() { - if let Some(self_value) = self.metadata.get(key) { - if self_value != from_value { - return Err(DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!( - "Fail to merge field due to conflicting metadata data value for key {}", - key - )))); - } - } else { - self.metadata.insert(key.clone(), from_value.clone()); - } - } - - match &mut self.data_type { - DataType::Struct(nested_fields) => match &from.data_type { - DataType::Struct(from_nested_fields) => { - for from_field in from_nested_fields { - let mut is_new_field = true; - for self_field in nested_fields.iter_mut() { - if self_field.name != from_field.name { - continue; - } - is_new_field = false; - self_field.try_merge(from_field)?; - } - if is_new_field { - nested_fields.push(from_field.clone()); - } - } - } - _ => { - return Err(DataFusionError::ArrowError( - ArrowError::InvalidArgumentError( - "Fail to merge schema Field due to conflicting datatype" - .to_string(), - ), - )); - } - }, - DataType::Union(nested_fields, _, _) => match &from.data_type { - DataType::Union(from_nested_fields, _, _) => { - for from_field in from_nested_fields { - let mut is_new_field = true; - for self_field in nested_fields.iter_mut() { - if from_field == self_field { - is_new_field = false; - break; - } - } - if is_new_field { - nested_fields.push(from_field.clone()); - } - } - } - _ => { - return Err(DataFusionError::ArrowError( - ArrowError::InvalidArgumentError( - "Fail to merge schema Field due to conflicting datatype" - .to_string(), - ), - )); - } - }, - DataType::Null - | DataType::Boolean - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 
- | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float16 - | DataType::Float32 - | DataType::Float64 - | DataType::Timestamp(_, _) - | DataType::Date32 - | DataType::Date64 - | DataType::Time32(_) - | DataType::Time64(_) - | DataType::Duration(_) - | DataType::Binary - | DataType::LargeBinary - | DataType::Interval(_) - | DataType::LargeList(_) - | DataType::List(_) - | DataType::Dictionary(_, _, _) - | DataType::FixedSizeList(_, _) - | DataType::FixedSizeBinary(_) - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Extension(_, _, _) - | DataType::Map(_, _) - | DataType::Decimal(_, _) => { - if self.data_type != from.data_type { - return Err(DataFusionError::ArrowError( - ArrowError::InvalidArgumentError( - "Fail to merge schema Field due to conflicting datatype" - .to_string(), - ), - )); - } - } - } - if from.is_nullable { - self.is_nullable = from.is_nullable; - } - - Ok(()) - } - - #[inline] - fn set_metadata(&mut self, metadata: Option>) { - if let Some(v) = metadata { - if !v.is_empty() { - self.metadata = v; - } - } - } -} +pub use datafusion_common::field_util::*; diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 6b839807f9db..682675ea4bc0 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -160,8 +160,8 @@ //! * Sort: [`SortExec`](physical_plan::sort::SortExec) //! * Coalesce partitions: [`CoalescePartitionsExec`](physical_plan::coalesce_partitions::CoalescePartitionsExec) //! * Limit: [`LocalLimitExec`](physical_plan::limit::LocalLimitExec) and [`GlobalLimitExec`](physical_plan::limit::GlobalLimitExec) -//! * Scan a CSV: [`CsvExec`](physical_plan::csv::CsvExec) -//! * Scan a Parquet: [`ParquetExec`](physical_plan::parquet::ParquetExec) +//! * Scan a CSV: [`CsvExec`](physical_plan::file_format::CsvExec) +//! * Scan a Parquet: [`ParquetExec`](physical_plan::file_format::ParquetExec) //! * Scan from memory: [`MemoryExec`](physical_plan::memory::MemoryExec) //! * Explain the plan: [`ExplainExec`](physical_plan::explain::ExplainExec) //! 
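For context, the re-exported field utilities and the relocated `file_format` operators above are reached through the same `ExecutionContext` API the new tests use. The following is a minimal, hypothetical sketch (not part of this patch): it assumes this era's `ExecutionContext::register_csv`/`sql`/`collect` signatures and reuses the `tests/example.csv` fixture from the tests earlier in this diff.

// Minimal sketch (not part of this patch) of the code path touched above.
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    // Registering a CSV goes through the listing/file_format code path, which
    // is why the rustdoc links above now point at
    // `physical_plan::file_format::CsvExec` instead of `physical_plan::csv`.
    ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new())
        .await?;
    // Unquoted identifiers are normalized to lower case while quoted ones are
    // preserved, matching the `normalized_column_identifiers` test added above.
    let df = ctx.sql(r#"SELECT A AS "X", b FROM example"#).await?;
    let batches = df.collect().await?;
    assert!(!batches.is_empty());
    Ok(())
}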
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 3a64f7630a84..db4573cace56 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -26,6 +26,7 @@ use crate::datasource::{
 };
 use crate::error::{DataFusionError, Result};
 use crate::field_util::SchemaExt;
+use crate::logical_plan::expr_schema::ExprSchemable;
 use crate::logical_plan::plan::{
     Aggregate, Analyze, EmptyRelation, Explain, Filter, Join, Projection, Sort,
     TableScan, ToStringifiedPlan, Union, Window,
@@ -594,6 +595,17 @@ impl LogicalPlanBuilder {
         self.join_detailed(right, join_type, join_keys, false)
     }
 
+    fn normalize(
+        plan: &LogicalPlan,
+        column: impl Into<Column> + Clone,
+    ) -> Result<Column> {
+        let schemas = plan.all_schemas();
+        let using_columns = plan.using_columns()?;
+        column
+            .into()
+            .normalize_with_schemas(&schemas, &using_columns)
+    }
+
     /// Apply a join with on constraint and specified null equality
     /// If null_equals_null is true then null == null, else null != null
     pub fn join_detailed(
@@ -632,7 +644,10 @@ impl LogicalPlanBuilder {
                     match (l_is_left, l_is_right, r_is_left, r_is_right) {
                         (_, Ok(_), Ok(_), _) => (Ok(r), Ok(l)),
                         (Ok(_), _, _, Ok(_)) => (Ok(l), Ok(r)),
-                        _ => (l.normalize(&self.plan), r.normalize(right)),
+                        _ => (
+                            Self::normalize(&self.plan, l),
+                            Self::normalize(right, r),
+                        ),
                     }
                 }
                 (Some(lr), None) => {
@@ -642,9 +657,12 @@ impl LogicalPlanBuilder {
                         right.schema().field_with_qualified_name(lr, &l.name);
 
                     match (l_is_left, l_is_right) {
-                        (Ok(_), _) => (Ok(l), r.normalize(right)),
-                        (_, Ok(_)) => (r.normalize(&self.plan), Ok(l)),
-                        _ => (l.normalize(&self.plan), r.normalize(right)),
+                        (Ok(_), _) => (Ok(l), Self::normalize(right, r)),
+                        (_, Ok(_)) => (Self::normalize(&self.plan, r), Ok(l)),
+                        _ => (
+                            Self::normalize(&self.plan, l),
+                            Self::normalize(right, r),
+                        ),
                     }
                 }
                 (None, Some(rr)) => {
@@ -654,22 +672,25 @@ impl LogicalPlanBuilder {
                         right.schema().field_with_qualified_name(rr, &r.name);
 
                     match (r_is_left, r_is_right) {
-                        (Ok(_), _) => (Ok(r), l.normalize(right)),
-                        (_, Ok(_)) => (l.normalize(&self.plan), Ok(r)),
-                        _ => (l.normalize(&self.plan), r.normalize(right)),
+                        (Ok(_), _) => (Ok(r), Self::normalize(right, l)),
+                        (_, Ok(_)) => (Self::normalize(&self.plan, l), Ok(r)),
+                        _ => (
+                            Self::normalize(&self.plan, l),
+                            Self::normalize(right, r),
+                        ),
                     }
                 }
                 (None, None) => {
                     let mut swap = false;
-                    let left_key =
-                        l.clone().normalize(&self.plan).or_else(|_| {
+                    let left_key = Self::normalize(&self.plan, l.clone())
+                        .or_else(|_| {
                             swap = true;
-                            l.normalize(right)
+                            Self::normalize(right, l)
                         });
                     if swap {
-                        (r.normalize(&self.plan), left_key)
+                        (Self::normalize(&self.plan, r), left_key)
                     } else {
-                        (left_key, r.normalize(right))
+                        (left_key, Self::normalize(right, r))
                     }
                 }
             }
@@ -704,11 +725,11 @@ impl LogicalPlanBuilder {
         let left_keys: Vec<Column> = using_keys
             .clone()
             .into_iter()
-            .map(|c| c.into().normalize(&self.plan))
+            .map(|c| Self::normalize(&self.plan, c))
             .collect::<Result<Vec<Column>>>()?;
         let right_keys: Vec<Column> = using_keys
             .into_iter()
-            .map(|c| c.into().normalize(right))
+            .map(|c| Self::normalize(right, c))
             .collect::<Result<Vec<Column>>>()?;
         let on: Vec<(_, _)> =
             left_keys.into_iter().zip(right_keys.into_iter()).collect();
diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs
index b89b2399e67a..eb624283ea4f 100644
--- a/datafusion/src/logical_plan/dfschema.rs
+++ b/datafusion/src/logical_plan/dfschema.rs
@@ -18,669 +18,4 @@
 
 //! DFSchema is an extended schema struct that DataFusion uses to provide support for
 //!
fields with optional relation names. -use std::collections::HashSet; -use std::convert::TryFrom; -use std::sync::Arc; - -use crate::error::{DataFusionError, Result}; -use crate::logical_plan::Column; - -use crate::field_util::{FieldExt, SchemaExt}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use std::fmt::{Display, Formatter}; - -/// A reference-counted reference to a `DFSchema`. -pub type DFSchemaRef = Arc; - -/// DFSchema wraps an Arrow schema and adds relation names -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DFSchema { - /// Fields - fields: Vec, -} - -impl DFSchema { - /// Creates an empty `DFSchema` - pub fn empty() -> Self { - Self { fields: vec![] } - } - - /// Create a new `DFSchema` - pub fn new(fields: Vec) -> Result { - let mut qualified_names = HashSet::new(); - let mut unqualified_names = HashSet::new(); - - for field in &fields { - if let Some(qualifier) = field.qualifier() { - if !qualified_names.insert((qualifier, field.name())) { - return Err(DataFusionError::Plan(format!( - "Schema contains duplicate qualified field name '{}'", - field.qualified_name() - ))); - } - } else if !unqualified_names.insert(field.name()) { - return Err(DataFusionError::Plan(format!( - "Schema contains duplicate unqualified field name '{}'", - field.name() - ))); - } - } - - // check for mix of qualified and unqualified field with same unqualified name - // note that we need to sort the contents of the HashSet first so that errors are - // deterministic - let mut qualified_names = qualified_names - .iter() - .map(|(l, r)| (l.as_str(), r.to_owned())) - .collect::>(); - qualified_names.sort_by(|a, b| { - let a = format!("{}.{}", a.0, a.1); - let b = format!("{}.{}", b.0, b.1); - a.cmp(&b) - }); - for (qualifier, name) in &qualified_names { - if unqualified_names.contains(name) { - return Err(DataFusionError::Plan(format!( - "Schema contains qualified field name '{}.{}' \ - and unqualified field name '{}' which would be ambiguous", - qualifier, name, name - ))); - } - } - Ok(Self { fields }) - } - - /// Create a `DFSchema` from an Arrow schema - pub fn try_from_qualified_schema(qualifier: &str, schema: &Schema) -> Result { - Self::new( - schema - .fields() - .iter() - .map(|f| DFField::from_qualified(qualifier, f.clone())) - .collect(), - ) - } - - /// Combine two schemas - pub fn join(&self, schema: &DFSchema) -> Result { - let mut fields = self.fields.clone(); - fields.extend_from_slice(schema.fields().as_slice()); - Self::new(fields) - } - - /// Merge a schema into self - pub fn merge(&mut self, other_schema: &DFSchema) { - for field in other_schema.fields() { - // skip duplicate columns - let duplicated_field = match field.qualifier() { - Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(), - // for unqualifed columns, check as unqualified name - None => self.field_with_unqualified_name(field.name()).is_ok(), - }; - if !duplicated_field { - self.fields.push(field.clone()); - } - } - } - - /// Get a list of fields - pub fn fields(&self) -> &Vec { - &self.fields - } - - /// Returns an immutable reference of a specific `Field` instance selected using an - /// offset within the internal `fields` vector - pub fn field(&self, i: usize) -> &DFField { - &self.fields[i] - } - - /// Find the index of the column with the given unqualified name - pub fn index_of(&self, name: &str) -> Result { - for i in 0..self.fields.len() { - if self.fields[i].name() == name { - return Ok(i); - } - } - Err(DataFusionError::Plan(format!( - "No field named '{}'. 
Valid fields are {}.", - name, - self.get_field_names() - ))) - } - - fn index_of_column_by_name( - &self, - qualifier: Option<&str>, - name: &str, - ) -> Result { - let mut matches = self - .fields - .iter() - .enumerate() - .filter(|(_, field)| match (qualifier, &field.qualifier) { - // field to lookup is qualified. - // current field is qualified and not shared between relations, compare both - // qualifier and name. - (Some(q), Some(field_q)) => q == field_q && field.name() == name, - // field to lookup is qualified but current field is unqualified. - (Some(_), None) => false, - // field to lookup is unqualified, no need to compare qualifier - (None, Some(_)) | (None, None) => field.name() == name, - }) - .map(|(idx, _)| idx); - match matches.next() { - None => Err(DataFusionError::Plan(format!( - "No field named '{}.{}'. Valid fields are {}.", - qualifier.unwrap_or(""), - name, - self.get_field_names() - ))), - Some(idx) => match matches.next() { - None => Ok(idx), - // found more than one matches - Some(_) => Err(DataFusionError::Internal(format!( - "Ambiguous reference to qualified field named '{}.{}'", - qualifier.unwrap_or(""), - name - ))), - }, - } - } - - /// Find the index of the column with the given qualifier and name - pub fn index_of_column(&self, col: &Column) -> Result { - self.index_of_column_by_name(col.relation.as_deref(), &col.name) - } - - /// Find the field with the given name - pub fn field_with_name( - &self, - qualifier: Option<&str>, - name: &str, - ) -> Result<&DFField> { - if let Some(qualifier) = qualifier { - self.field_with_qualified_name(qualifier, name) - } else { - self.field_with_unqualified_name(name) - } - } - - /// Find all fields match the given name - pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { - self.fields - .iter() - .filter(|field| field.name() == name) - .collect() - } - - /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { - let matches = self.fields_with_unqualified_name(name); - match matches.len() { - 0 => Err(DataFusionError::Plan(format!( - "No field with unqualified name '{}'. 
Valid fields are {}.", - name, - self.get_field_names() - ))), - 1 => Ok(matches[0]), - _ => Err(DataFusionError::Plan(format!( - "Ambiguous reference to field named '{}'", - name - ))), - } - } - - /// Find the field with the given qualified name - pub fn field_with_qualified_name( - &self, - qualifier: &str, - name: &str, - ) -> Result<&DFField> { - let idx = self.index_of_column_by_name(Some(qualifier), name)?; - Ok(self.field(idx)) - } - - /// Find the field with the given qualified column - pub fn field_from_column(&self, column: &Column) -> Result<&DFField> { - match &column.relation { - Some(r) => self.field_with_qualified_name(r, &column.name), - None => self.field_with_unqualified_name(&column.name), - } - } - - /// Check to see if unqualified field names matches field names in Arrow schema - pub fn matches_arrow_schema(&self, arrow_schema: &Schema) -> bool { - self.fields - .iter() - .zip(arrow_schema.fields().iter()) - .all(|(dffield, arrowfield)| dffield.name() == arrowfield.name()) - } - - /// Strip all field qualifier in schema - pub fn strip_qualifiers(self) -> Self { - DFSchema { - fields: self - .fields - .into_iter() - .map(|f| f.strip_qualifier()) - .collect(), - } - } - - /// Replace all field qualifier with new value in schema - pub fn replace_qualifier(self, qualifier: &str) -> Self { - DFSchema { - fields: self - .fields - .into_iter() - .map(|f| { - DFField::new( - Some(qualifier), - f.name(), - f.data_type().to_owned(), - f.is_nullable(), - ) - }) - .collect(), - } - } - - /// Get comma-seperated list of field names for use in error messages - fn get_field_names(&self) -> String { - self.fields - .iter() - .map(|f| match f.qualifier() { - Some(qualifier) => format!("'{}.{}'", qualifier, f.name()), - None => format!("'{}'", f.name()), - }) - .collect::>() - .join(", ") - } -} - -impl From for Schema { - /// Convert DFSchema into a Schema - fn from(df_schema: DFSchema) -> Self { - Schema::new( - df_schema - .fields - .into_iter() - .map(|f| { - if f.qualifier().is_some() { - Field::new(f.name(), f.data_type().to_owned(), f.is_nullable()) - } else { - f.field - } - }) - .collect(), - ) - } -} - -impl From<&DFSchema> for Schema { - /// Convert DFSchema reference into a Schema - fn from(df_schema: &DFSchema) -> Self { - Schema::new(df_schema.fields.iter().map(|f| f.field.clone()).collect()) - } -} - -/// Create a `DFSchema` from an Arrow schema -impl TryFrom for DFSchema { - type Error = DataFusionError; - fn try_from(schema: Schema) -> std::result::Result { - Self::new( - schema - .fields() - .iter() - .map(|f| DFField::from(f.clone())) - .collect(), - ) - } -} - -impl From for SchemaRef { - fn from(df_schema: DFSchema) -> Self { - SchemaRef::new(df_schema.into()) - } -} - -/// Convenience trait to convert Schema like things to DFSchema and DFSchemaRef with fewer keystrokes -pub trait ToDFSchema -where - Self: Sized, -{ - /// Attempt to create a DSSchema - #[allow(clippy::wrong_self_convention)] - fn to_dfschema(self) -> Result; - - /// Attempt to create a DSSchemaRef - #[allow(clippy::wrong_self_convention)] - fn to_dfschema_ref(self) -> Result { - Ok(Arc::new(self.to_dfschema()?)) - } -} - -impl ToDFSchema for Schema { - #[allow(clippy::wrong_self_convention)] - fn to_dfschema(self) -> Result { - DFSchema::try_from(self) - } -} - -impl ToDFSchema for SchemaRef { - #[allow(clippy::wrong_self_convention)] - fn to_dfschema(self) -> Result { - // Attempt to use the Schema directly if there are no other - // references, otherwise clone - match Self::try_unwrap(self) { - 
Ok(schema) => DFSchema::try_from(schema), - Err(schemaref) => DFSchema::try_from(schemaref.as_ref().clone()), - } - } -} - -impl ToDFSchema for Vec { - fn to_dfschema(self) -> Result { - DFSchema::new(self) - } -} - -impl Display for DFSchema { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - write!( - f, - "{}", - self.fields - .iter() - .map(|field| field.qualified_name()) - .collect::>() - .join(", ") - ) - } -} - -/// DFField wraps an Arrow field and adds an optional qualifier -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DFField { - /// Optional qualifier (usually a table or relation name) - qualifier: Option, - /// Arrow field definition - field: Field, -} - -impl DFField { - /// Creates a new `DFField` - pub fn new( - qualifier: Option<&str>, - name: &str, - data_type: DataType, - nullable: bool, - ) -> Self { - DFField { - qualifier: qualifier.map(|s| s.to_owned()), - field: Field::new(name, data_type, nullable), - } - } - - /// Create an unqualified field from an existing Arrow field - pub fn from(field: Field) -> Self { - Self { - qualifier: None, - field, - } - } - - /// Create a qualified field from an existing Arrow field - pub fn from_qualified(qualifier: &str, field: Field) -> Self { - Self { - qualifier: Some(qualifier.to_owned()), - field, - } - } - - /// Returns an immutable reference to the `DFField`'s unqualified name - pub fn name(&self) -> &str { - self.field.name() - } - - /// Returns an immutable reference to the `DFField`'s data-type - pub fn data_type(&self) -> &DataType { - self.field.data_type() - } - - /// Indicates whether this `DFField` supports null values - pub fn is_nullable(&self) -> bool { - self.field.is_nullable() - } - - /// Returns a string to the `DFField`'s qualified name - pub fn qualified_name(&self) -> String { - if let Some(qualifier) = &self.qualifier { - format!("{}.{}", qualifier, self.field.name()) - } else { - self.field.name().to_owned() - } - } - - /// Builds a qualified column based on self - pub fn qualified_column(&self) -> Column { - Column { - relation: self.qualifier.clone(), - name: self.field.name().to_string(), - } - } - - /// Builds an unqualified column based on self - pub fn unqualified_column(&self) -> Column { - Column { - relation: None, - name: self.field.name().to_string(), - } - } - - /// Get the optional qualifier - pub fn qualifier(&self) -> Option<&String> { - self.qualifier.as_ref() - } - - /// Get the arrow field - pub fn field(&self) -> &Field { - &self.field - } - - /// Return field with qualifier stripped - pub fn strip_qualifier(mut self) -> Self { - self.qualifier = None; - self - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::DataType; - - #[test] - fn from_unqualified_field() { - let field = Field::new("c0", DataType::Boolean, true); - let field = DFField::from(field); - assert_eq!("c0", field.name()); - assert_eq!("c0", field.qualified_name()); - } - - #[test] - fn from_qualified_field() { - let field = Field::new("c0", DataType::Boolean, true); - let field = DFField::from_qualified("t1", field); - assert_eq!("c0", field.name()); - assert_eq!("t1.c0", field.qualified_name()); - } - - #[test] - fn from_unqualified_schema() -> Result<()> { - let schema = DFSchema::try_from(test_schema_1())?; - assert_eq!("c0, c1", schema.to_string()); - Ok(()) - } - - #[test] - fn from_qualified_schema() -> Result<()> { - let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert_eq!("t1.c0, t1.c1", schema.to_string()); - Ok(()) - } - - #[test] - fn 
from_qualified_schema_into_arrow_schema() -> Result<()> { - let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let arrow_schema: Schema = schema.into(); - let expected = - "[Field { name: \"c0\", data_type: Boolean, is_nullable: true, metadata: {} }, \ - Field { name: \"c1\", data_type: Boolean, is_nullable: true, metadata: {} }]"; - assert_eq!(expected, format!("{:?}", arrow_schema.fields)); - Ok(()) - } - - #[test] - fn join_qualified() -> Result<()> { - let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let right = DFSchema::try_from_qualified_schema("t2", &test_schema_1())?; - let join = left.join(&right)?; - assert_eq!("t1.c0, t1.c1, t2.c0, t2.c1", join.to_string()); - // test valid access - assert!(join.field_with_qualified_name("t1", "c0").is_ok()); - assert!(join.field_with_qualified_name("t2", "c0").is_ok()); - // test invalid access - assert!(join.field_with_unqualified_name("c0").is_err()); - assert!(join.field_with_unqualified_name("t1.c0").is_err()); - assert!(join.field_with_unqualified_name("t2.c0").is_err()); - Ok(()) - } - - #[test] - fn join_qualified_duplicate() -> Result<()> { - let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let join = left.join(&right); - assert!(join.is_err()); - assert_eq!( - "Error during planning: Schema contains duplicate \ - qualified field name \'t1.c0\'", - &format!("{}", join.err().unwrap()) - ); - Ok(()) - } - - #[test] - fn join_unqualified_duplicate() -> Result<()> { - let left = DFSchema::try_from(test_schema_1())?; - let right = DFSchema::try_from(test_schema_1())?; - let join = left.join(&right); - assert!(join.is_err()); - assert_eq!( - "Error during planning: Schema contains duplicate \ - unqualified field name \'c0\'", - &format!("{}", join.err().unwrap()) - ); - Ok(()) - } - - #[test] - fn join_mixed() -> Result<()> { - let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let right = DFSchema::try_from(test_schema_2())?; - let join = left.join(&right)?; - assert_eq!("t1.c0, t1.c1, c100, c101", join.to_string()); - // test valid access - assert!(join.field_with_qualified_name("t1", "c0").is_ok()); - assert!(join.field_with_unqualified_name("c0").is_ok()); - assert!(join.field_with_unqualified_name("c100").is_ok()); - assert!(join.field_with_name(None, "c100").is_ok()); - // test invalid access - assert!(join.field_with_unqualified_name("t1.c0").is_err()); - assert!(join.field_with_unqualified_name("t1.c100").is_err()); - assert!(join.field_with_qualified_name("", "c100").is_err()); - Ok(()) - } - - #[test] - fn join_mixed_duplicate() -> Result<()> { - let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let right = DFSchema::try_from(test_schema_1())?; - let join = left.join(&right); - assert!(join.is_err()); - assert_eq!( - "Error during planning: Schema contains qualified \ - field name \'t1.c0\' and unqualified field name \'c0\' which would be ambiguous", - &format!("{}", join.err().unwrap()) - ); - Ok(()) - } - - #[test] - fn helpful_error_messages() -> Result<()> { - let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let expected_help = "Valid fields are \'t1.c0\', \'t1.c1\'."; - assert!(schema - .field_with_qualified_name("x", "y") - .unwrap_err() - .to_string() - .contains(expected_help)); - assert!(schema - .field_with_unqualified_name("y") - .unwrap_err() - .to_string() - .contains(expected_help)); - 
assert!(schema - .index_of("y") - .unwrap_err() - .to_string() - .contains(expected_help)); - Ok(()) - } - - #[test] - fn into() { - // Demonstrate how to convert back and forth between Schema, SchemaRef, DFSchema, and DFSchemaRef - let arrow_schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]); - let arrow_schema_ref = Arc::new(arrow_schema.clone()); - - let df_schema = - DFSchema::new(vec![DFField::new(None, "c0", DataType::Int64, true)]).unwrap(); - let df_schema_ref = Arc::new(df_schema.clone()); - - { - let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); - - assert_eq!(df_schema, arrow_schema.to_dfschema().unwrap()); - assert_eq!(df_schema, arrow_schema_ref.to_dfschema().unwrap()); - } - - { - let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); - - assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); - assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); - } - - // Now, consume the refs - assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); - assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); - } - - fn test_schema_1() -> Schema { - Schema::new(vec![ - Field::new("c0", DataType::Boolean, true), - Field::new("c1", DataType::Boolean, true), - ]) - } - - fn test_schema_2() -> Schema { - Schema::new(vec![ - Field::new("c100", DataType::Boolean, true), - Field::new("c101", DataType::Boolean, true), - ]) - } -} +pub use datafusion_common::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 2dd9f9eb3c41..3826e45f1443 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,1147 +20,25 @@ pub use super::Operator; -use arrow::{compute::cast::can_cast_types, datatypes::DataType}; +use arrow::datatypes::DataType; use crate::error::{DataFusionError, Result}; -use crate::execution::context::ExecutionProps; -use crate::field_util::{get_indexed_field, FieldExt}; -use crate::logical_plan::{ - plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan, +use crate::logical_plan::ExprSchemable; +use crate::logical_plan::{DFField, DFSchema}; +use crate::physical_plan::udaf::AggregateUDF; +use crate::physical_plan::{aggregates, functions, udf::ScalarUDF}; +pub use datafusion_common::{Column, ExprSchema}; +pub use datafusion_expr::expr_fn::col; +use datafusion_expr::AccumulatorFunctionImplementation; +pub use datafusion_expr::Expr; +use datafusion_expr::StateTypeFunction; +pub use datafusion_expr::{lit, lit_timestamp_nano, Literal}; +use datafusion_expr::{ + ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility, }; -use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier}; -use crate::physical_plan::functions::Volatility; -use crate::physical_plan::{ - aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, - window_functions, -}; -use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; -use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; -use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; -use std::collections::{HashMap, HashSet}; -use std::convert::Infallible; -use std::fmt; -use std::hash::{BuildHasher, Hash, Hasher}; -use std::ops::Not; -use std::str::FromStr; +use std::collections::HashSet; use std::sync::Arc; -/// A named reference to a qualified field in a schema. 
-#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
-pub struct Column {
-    /// relation/table name.
-    pub relation: Option<String>,
-    /// field/column name.
-    pub name: String,
-}
-
-impl Column {
-    /// Create Column from unqualified name.
-    pub fn from_name(name: impl Into<String>) -> Self {
-        Self {
-            relation: None,
-            name: name.into(),
-        }
-    }
-
-    /// Deserialize a fully qualified name string into a column
-    pub fn from_qualified_name(flat_name: &str) -> Self {
-        use sqlparser::tokenizer::Token;
-
-        let dialect = sqlparser::dialect::GenericDialect {};
-        let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name);
-        if let Ok(tokens) = tokenizer.tokenize() {
-            if let [Token::Word(relation), Token::Period, Token::Word(name)] =
-                tokens.as_slice()
-            {
-                return Column {
-                    relation: Some(relation.value.clone()),
-                    name: name.value.clone(),
-                };
-            }
-        }
-        // any expression that's not in the form of `foo.bar` will be treated as an unqualified
-        // column name
-        Column {
-            relation: None,
-            name: String::from(flat_name),
-        }
-    }
-
-    /// Serialize column into a flat name string
-    pub fn flat_name(&self) -> String {
-        match &self.relation {
-            Some(r) => format!("{}.{}", r, self.name),
-            None => self.name.clone(),
-        }
-    }
-
-    /// Normalizes `self` if it is unqualified (has no relation name),
-    /// taking an explicit qualifier from the first matching input
-    /// schema.
-    ///
-    /// For example, `foo` will be normalized to `t.foo` if there is a
-    /// column named `foo` in a relation named `t` found in `schemas`
-    pub fn normalize(self, plan: &LogicalPlan) -> Result<Self> {
-        let schemas = plan.all_schemas();
-        let using_columns = plan.using_columns()?;
-        self.normalize_with_schemas(&schemas, &using_columns)
-    }
-
-    // Internal implementation of normalize
-    fn normalize_with_schemas(
-        self,
-        schemas: &[&Arc<DFSchema>],
-        using_columns: &[HashSet<Column>],
-    ) -> Result<Self> {
-        if self.relation.is_some() {
-            return Ok(self);
-        }
-
-        for schema in schemas {
-            let fields = schema.fields_with_unqualified_name(&self.name);
-            match fields.len() {
-                0 => continue,
-                1 => {
-                    return Ok(fields[0].qualified_column());
-                }
-                _ => {
-                    // More than one field in this schema has its name set to self.name.
-                    //
-                    // This should only happen when a JOIN query with USING constraint references
-                    // join columns using an unqualified column name. For example:
-                    //
-                    // ```sql
-                    // SELECT id FROM t1 JOIN t2 USING(id)
-                    // ```
-                    //
-                    // In this case, both `t1.id` and `t2.id` will match unqualified column `id`.
-                    // We will use the relation from the first matched field to normalize self.
-
-                    // Compare matched fields with one USING JOIN clause at a time
-                    for using_col in using_columns {
-                        let all_matched = fields
-                            .iter()
-                            .all(|f| using_col.contains(&f.qualified_column()));
-                        // All matched fields belong to the same using column set, in other words
-                        // the same join clause. We simply pick the qualifier from the first match.
- if all_matched { - return Ok(fields[0].qualified_column()); - } - } - } - } - } - - Err(DataFusionError::Plan(format!( - "Column {} not found in provided schemas", - self - ))) - } -} - -impl From<&str> for Column { - fn from(c: &str) -> Self { - Self::from_qualified_name(c) - } -} - -impl FromStr for Column { - type Err = Infallible; - - fn from_str(s: &str) -> std::result::Result { - Ok(s.into()) - } -} - -impl fmt::Display for Column { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match &self.relation { - Some(r) => write!(f, "#{}.{}", r, self.name), - None => write!(f, "#{}", self.name), - } - } -} - -/// `Expr` is a central struct of DataFusion's query API, and -/// represent logical expressions such as `A + 1`, or `CAST(c1 AS -/// int)`. -/// -/// An `Expr` can compute its [DataType](arrow::datatypes::DataType) -/// and nullability, and has functions for building up complex -/// expressions. -/// -/// # Examples -/// -/// ## Create an expression `c1` referring to column named "c1" -/// ``` -/// # use datafusion::logical_plan::*; -/// let expr = col("c1"); -/// assert_eq!(expr, Expr::Column(Column::from_name("c1"))); -/// ``` -/// -/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together -/// ``` -/// # use datafusion::logical_plan::*; -/// let expr = col("c1") + col("c2"); -/// -/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); -/// if let Expr::BinaryExpr { left, right, op } = expr { -/// assert_eq!(*left, col("c1")); -/// assert_eq!(*right, col("c2")); -/// assert_eq!(op, Operator::Plus); -/// } -/// ``` -/// -/// ## Create expression `c1 = 42` to compare the value in coumn "c1" to the literal value `42` -/// ``` -/// # use datafusion::logical_plan::*; -/// # use datafusion::scalar::*; -/// let expr = col("c1").eq(lit(42)); -/// -/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); -/// if let Expr::BinaryExpr { left, right, op } = expr { -/// assert_eq!(*left, col("c1")); -/// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*right, Expr::Literal(scalar)); -/// assert_eq!(op, Operator::Eq); -/// } -/// ``` -#[derive(Clone, PartialEq, Hash)] -pub enum Expr { - /// An expression with a specific name. - Alias(Box, String), - /// A named reference to a qualified filed in a schema. - Column(Column), - /// A named reference to a variable in a registry. - ScalarVariable(Vec), - /// A constant value. - Literal(ScalarValue), - /// A binary expression such as "age > 21" - BinaryExpr { - /// Left-hand side of the expression - left: Box, - /// The comparison operator - op: Operator, - /// Right-hand side of the expression - right: Box, - }, - /// Negation of an expression. The expression's type must be a boolean to make sense. - Not(Box), - /// Whether an expression is not Null. This expression is never null. - IsNotNull(Box), - /// Whether an expression is Null. This expression is never null. - IsNull(Box), - /// arithmetic negation of an expression, the operand must be of a signed numeric data type - Negative(Box), - /// Returns the field of a [`ListArray`] or [`StructArray`] by key - GetIndexedField { - /// the expression to take the field from - expr: Box, - /// The name of the field to take - key: ScalarValue, - }, - /// Whether an expression is between a given range. 
- Between { - /// The value to compare - expr: Box, - /// Whether the expression is negated - negated: bool, - /// The low end of the range - low: Box, - /// The high end of the range - high: Box, - }, - /// The CASE expression is similar to a series of nested if/else and there are two forms that - /// can be used. The first form consists of a series of boolean "when" expressions with - /// corresponding "then" expressions, and an optional "else" expression. - /// - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// - /// The second form uses a base expression and then a series of "when" clauses that match on a - /// literal value. - /// - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - Case { - /// Optional base expression that can be compared to literal values in the "when" expressions - expr: Option>, - /// One or more when/then expressions - when_then_expr: Vec<(Box, Box)>, - /// Optional "else" expression - else_expr: Option>, - }, - /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. - /// This expression is guaranteed to have a fixed type. - Cast { - /// The expression being cast - expr: Box, - /// The `DataType` the expression will yield - data_type: DataType, - }, - /// Casts the expression to a given type and will return a null value if the expression cannot be cast. - /// This expression is guaranteed to have a fixed type. - TryCast { - /// The expression being cast - expr: Box, - /// The `DataType` the expression will yield - data_type: DataType, - }, - /// A sort expression, that can be used to sort values. - Sort { - /// The expression to sort on - expr: Box, - /// The direction of the sort - asc: bool, - /// Whether to put Nulls before all other data values - nulls_first: bool, - }, - /// Represents the call of a built-in scalar function with a set of arguments. - ScalarFunction { - /// The function - fun: functions::BuiltinScalarFunction, - /// List of expressions to feed to the functions as arguments - args: Vec, - }, - /// Represents the call of a user-defined scalar function with arguments. - ScalarUDF { - /// The function - fun: Arc, - /// List of expressions to feed to the functions as arguments - args: Vec, - }, - /// Represents the call of an aggregate built-in function with arguments. - AggregateFunction { - /// Name of the function - fun: aggregates::AggregateFunction, - /// List of expressions to feed to the functions as arguments - args: Vec, - /// Whether this is a DISTINCT aggregation or not - distinct: bool, - }, - /// Represents the call of a window function with arguments. - WindowFunction { - /// Name of the function - fun: window_functions::WindowFunction, - /// List of expressions to feed to the functions as arguments - args: Vec, - /// List of partition by expressions - partition_by: Vec, - /// List of order by expressions - order_by: Vec, - /// Window frame - window_frame: Option, - }, - /// aggregate function - AggregateUDF { - /// The function - fun: Arc, - /// List of expressions to feed to the functions as arguments - args: Vec, - }, - /// Returns whether the list contains the expr value. - InList { - /// The expression to compare - expr: Box, - /// A list of values to compare against - list: Vec, - /// Whether the expression is negated - negated: bool, - }, - /// Represents a reference to all fields in a schema. 
- Wildcard, -} - -/// Fixed seed for the hashing so that Ords are consistent across runs -const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); - -impl PartialOrd for Expr { - fn partial_cmp(&self, other: &Self) -> Option { - let mut hasher = SEED.build_hasher(); - self.hash(&mut hasher); - let s = hasher.finish(); - - let mut hasher = SEED.build_hasher(); - other.hash(&mut hasher); - let o = hasher.finish(); - - Some(s.cmp(&o)) - } -} - -impl Expr { - /// Returns the [arrow::datatypes::DataType] of the expression based on [arrow::datatypes::Schema]. - /// - /// # Errors - /// - /// This function errors when it is not possible to compute its [arrow::datatypes::DataType]. - /// This happens when e.g. the expression refers to a column that does not exist in the schema, or when - /// the expression is incorrectly typed (e.g. `[utf8] + [bool]`). - pub fn get_type(&self, schema: &DFSchema) -> Result { - match self { - Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => { - expr.get_type(schema) - } - Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()), - Expr::ScalarVariable(_) => Ok(DataType::Utf8), - Expr::Literal(l) => Ok(l.get_datatype()), - Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), - Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => { - Ok(data_type.clone()) - } - Expr::ScalarUDF { fun, args } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) - } - Expr::ScalarFunction { fun, args } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - functions::return_type(fun, &data_types) - } - Expr::WindowFunction { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - window_functions::return_type(fun, &data_types) - } - Expr::AggregateFunction { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - aggregates::return_type(fun, &data_types) - } - Expr::AggregateUDF { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) - } - Expr::Not(_) - | Expr::IsNull(_) - | Expr::Between { .. } - | Expr::InList { .. } - | Expr::IsNotNull(_) => Ok(DataType::Boolean), - Expr::BinaryExpr { - ref left, - ref right, - ref op, - } => binary_operator_data_type( - &left.get_type(schema)?, - op, - &right.get_type(schema)?, - ), - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - Expr::GetIndexedField { ref expr, key } => { - let data_type = expr.get_type(schema)?; - - get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) - } - } - } - - /// Returns the nullability of the expression based on [arrow::datatypes::Schema]. - /// - /// # Errors - /// - /// This function errors when it is not possible to compute its nullability. - /// This happens when the expression refers to a column that does not exist in the schema. - pub fn nullable(&self, input_schema: &DFSchema) -> Result { - match self { - Expr::Alias(expr, _) - | Expr::Not(expr) - | Expr::Negative(expr) - | Expr::Sort { expr, .. } - | Expr::Between { expr, .. } - | Expr::InList { expr, .. 
} => expr.nullable(input_schema), - Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()), - Expr::Literal(value) => Ok(value.is_null()), - Expr::Case { - when_then_expr, - else_expr, - .. - } => { - // this expression is nullable if any of the input expressions are nullable - let then_nullable = when_then_expr - .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { - Ok(true) - } else if let Some(e) = else_expr { - e.nullable(input_schema) - } else { - Ok(false) - } - } - Expr::Cast { expr, .. } => expr.nullable(input_schema), - Expr::ScalarVariable(_) - | Expr::TryCast { .. } - | Expr::ScalarFunction { .. } - | Expr::ScalarUDF { .. } - | Expr::WindowFunction { .. } - | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } => Ok(true), - Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false), - Expr::BinaryExpr { - ref left, - ref right, - .. - } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - Expr::GetIndexedField { ref expr, key } => { - let data_type = expr.get_type(input_schema)?; - get_indexed_field(&data_type, key).map(|x| x.is_nullable()) - } - } - } - - /// Returns the name of this expression based on [crate::logical_plan::DFSchema]. - /// - /// This represents how a column with this expression is named when no alias is chosen - pub fn name(&self, input_schema: &DFSchema) -> Result { - create_name(self, input_schema) - } - - /// Returns a [arrow::datatypes::Field] compatible with this expression. - pub fn to_field(&self, input_schema: &DFSchema) -> Result { - match self { - Expr::Column(c) => Ok(DFField::new( - c.relation.as_deref(), - &c.name, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )), - _ => Ok(DFField::new( - None, - &self.name(input_schema)?, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )), - } - } - - /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. - /// - /// # Errors - /// - /// This function errors when it is impossible to cast the - /// expression to the target [arrow::datatypes::DataType]. - pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result { - // TODO(kszucs): most of the operations do not validate the type correctness - // like all of the binary expressions below. Perhaps Expr should track the - // type of the expression? 
- let this_type = self.get_type(schema)?; - if this_type == *cast_to_type { - Ok(self) - } else if can_cast_types(&this_type, cast_to_type) { - Ok(Expr::Cast { - expr: Box::new(self), - data_type: cast_to_type.clone(), - }) - } else { - Err(DataFusionError::Plan(format!( - "Cannot automatically convert {:?} to {:?}", - this_type, cast_to_type - ))) - } - } - - /// Return `self == other` - pub fn eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::Eq, other) - } - - /// Return `self != other` - pub fn not_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::NotEq, other) - } - - /// Return `self > other` - pub fn gt(self, other: Expr) -> Expr { - binary_expr(self, Operator::Gt, other) - } - - /// Return `self >= other` - pub fn gt_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::GtEq, other) - } - - /// Return `self < other` - pub fn lt(self, other: Expr) -> Expr { - binary_expr(self, Operator::Lt, other) - } - - /// Return `self <= other` - pub fn lt_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::LtEq, other) - } - - /// Return `self && other` - pub fn and(self, other: Expr) -> Expr { - binary_expr(self, Operator::And, other) - } - - /// Return `self || other` - pub fn or(self, other: Expr) -> Expr { - binary_expr(self, Operator::Or, other) - } - - /// Return `!self` - #[allow(clippy::should_implement_trait)] - pub fn not(self) -> Expr { - !self - } - - /// Calculate the modulus of two expressions. - /// Return `self % other` - pub fn modulus(self, other: Expr) -> Expr { - binary_expr(self, Operator::Modulo, other) - } - - /// Return `self LIKE other` - pub fn like(self, other: Expr) -> Expr { - binary_expr(self, Operator::Like, other) - } - - /// Return `self NOT LIKE other` - pub fn not_like(self, other: Expr) -> Expr { - binary_expr(self, Operator::NotLike, other) - } - - /// Return `self AS name` alias expression - pub fn alias(self, name: &str) -> Expr { - Expr::Alias(Box::new(self), name.to_owned()) - } - - /// Return `self IN ` if `negated` is false, otherwise - /// return `self NOT IN `.a - pub fn in_list(self, list: Vec, negated: bool) -> Expr { - Expr::InList { - expr: Box::new(self), - list, - negated, - } - } - - /// Return `IsNull(Box(self)) - #[allow(clippy::wrong_self_convention)] - pub fn is_null(self) -> Expr { - Expr::IsNull(Box::new(self)) - } - - /// Return `IsNotNull(Box(self)) - #[allow(clippy::wrong_self_convention)] - pub fn is_not_null(self) -> Expr { - Expr::IsNotNull(Box::new(self)) - } - - /// Create a sort expression from an existing expression. - /// - /// ``` - /// # use datafusion::logical_plan::col; - /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST - /// ``` - pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { - Expr::Sort { - expr: Box::new(self), - asc, - nulls_first, - } - } - - /// Performs a depth first walk of an expression and - /// its children, calling [`ExpressionVisitor::pre_visit`] and - /// `visitor.post_visit`. - /// - /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to - /// separate expression algorithms from the structure of the - /// `Expr` tree and make it easier to add new types of expressions - /// and algorithms that walk the tree. 
- /// - /// For an expression tree such as - /// ```text - /// BinaryExpr (GT) - /// left: Column("foo") - /// right: Column("bar") - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(BinaryExpr(GT)) - /// pre_visit(Column("foo")) - /// pre_visit(Column("bar")) - /// post_visit(Column("bar")) - /// post_visit(Column("bar")) - /// post_visit(BinaryExpr(GT)) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If `Recursion::Stop` is returned on a call to pre_visit, no - /// children of that expression are visited, nor is post_visit - /// called on that expression - /// - pub fn accept(&self, visitor: V) -> Result { - let visitor = match visitor.pre_visit(self)? { - Recursion::Continue(visitor) => visitor, - // If the recursion should stop, do not visit children - Recursion::Stop(visitor) => return Ok(visitor), - }; - - // recurse (and cover all expression types) - let visitor = match self { - Expr::Alias(expr, _) - | Expr::Not(expr) - | Expr::IsNotNull(expr) - | Expr::IsNull(expr) - | Expr::Negative(expr) - | Expr::Cast { expr, .. } - | Expr::TryCast { expr, .. } - | Expr::Sort { expr, .. } - | Expr::GetIndexedField { expr, .. } => expr.accept(visitor), - Expr::Column(_) - | Expr::ScalarVariable(_) - | Expr::Literal(_) - | Expr::Wildcard => Ok(visitor), - Expr::BinaryExpr { left, right, .. } => { - let visitor = left.accept(visitor)?; - right.accept(visitor) - } - Expr::Between { - expr, low, high, .. - } => { - let visitor = expr.accept(visitor)?; - let visitor = low.accept(visitor)?; - high.accept(visitor) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let visitor = if let Some(expr) = expr.as_ref() { - expr.accept(visitor) - } else { - Ok(visitor) - }?; - let visitor = when_then_expr.iter().try_fold( - visitor, - |visitor, (when, then)| { - let visitor = when.accept(visitor)?; - then.accept(visitor) - }, - )?; - if let Some(else_expr) = else_expr.as_ref() { - else_expr.accept(visitor) - } else { - Ok(visitor) - } - } - Expr::ScalarFunction { args, .. } - | Expr::ScalarUDF { args, .. } - | Expr::AggregateFunction { args, .. } - | Expr::AggregateUDF { args, .. } => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::WindowFunction { - args, - partition_by, - order_by, - .. - } => { - let visitor = args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - let visitor = partition_by - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - let visitor = order_by - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - Ok(visitor) - } - Expr::InList { expr, list, .. } => { - let visitor = expr.accept(visitor)?; - list.iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)) - } - }?; - - visitor.post_visit(self) - } - - /// Performs a depth first walk of an expression and its children - /// to rewrite an expression, consuming `self` producing a new - /// [`Expr`]. - /// - /// Implements a modified version of the [visitor - /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to - /// separate algorithms from the structure of the `Expr` tree and - /// make it easier to write new, efficient expression - /// transformation algorithms. 
- /// - /// For an expression tree such as - /// ```text - /// BinaryExpr (GT) - /// left: Column("foo") - /// right: Column("bar") - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(BinaryExpr(GT)) - /// pre_visit(Column("foo")) - /// mutatate(Column("foo")) - /// pre_visit(Column("bar")) - /// mutate(Column("bar")) - /// mutate(BinaryExpr(GT)) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that expression are visited, nor is mutate - /// called on that expression - /// - pub fn rewrite(self, rewriter: &mut R) -> Result - where - R: ExprRewriter, - { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - // recurse into all sub expressions(and cover all expression types) - let expr = match self { - Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), - Expr::Column(_) => self.clone(), - Expr::ScalarVariable(names) => Expr::ScalarVariable(names), - Expr::Literal(value) => Expr::Literal(value), - Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { - left: rewrite_boxed(left, rewriter)?, - op, - right: rewrite_boxed(right, rewriter)?, - }, - Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?), - Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?), - Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?), - Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?), - Expr::Between { - expr, - low, - high, - negated, - } => Expr::Between { - expr: rewrite_boxed(expr, rewriter)?, - low: rewrite_boxed(low, rewriter)?, - high: rewrite_boxed(high, rewriter)?, - negated, - }, - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let expr = rewrite_option_box(expr, rewriter)?; - let when_then_expr = when_then_expr - .into_iter() - .map(|(when, then)| { - Ok(( - rewrite_boxed(when, rewriter)?, - rewrite_boxed(then, rewriter)?, - )) - }) - .collect::>>()?; - - let else_expr = rewrite_option_box(else_expr, rewriter)?; - - Expr::Case { - expr, - when_then_expr, - else_expr, - } - } - Expr::Cast { expr, data_type } => Expr::Cast { - expr: rewrite_boxed(expr, rewriter)?, - data_type, - }, - Expr::TryCast { expr, data_type } => Expr::TryCast { - expr: rewrite_boxed(expr, rewriter)?, - data_type, - }, - Expr::Sort { - expr, - asc, - nulls_first, - } => Expr::Sort { - expr: rewrite_boxed(expr, rewriter)?, - asc, - nulls_first, - }, - Expr::ScalarFunction { args, fun } => Expr::ScalarFunction { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::ScalarUDF { args, fun } => Expr::ScalarUDF { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::WindowFunction { - args, - fun, - partition_by, - order_by, - window_frame, - } => Expr::WindowFunction { - args: rewrite_vec(args, rewriter)?, - fun, - partition_by: rewrite_vec(partition_by, rewriter)?, - order_by: rewrite_vec(order_by, rewriter)?, - window_frame, - }, - Expr::AggregateFunction { - args, - fun, - distinct, - } => Expr::AggregateFunction { - args: rewrite_vec(args, rewriter)?, - fun, - distinct, - }, - Expr::AggregateUDF { args, fun } => Expr::AggregateUDF { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::InList { - expr, - list, - negated, - } => Expr::InList { - expr: 
rewrite_boxed(expr, rewriter)?, - list: rewrite_vec(list, rewriter)?, - negated, - }, - Expr::Wildcard => Expr::Wildcard, - Expr::GetIndexedField { expr, key } => Expr::GetIndexedField { - expr: rewrite_boxed(expr, rewriter)?, - key, - }, - }; - - // now rewrite this expression itself - if need_mutate { - rewriter.mutate(expr) - } else { - Ok(expr) - } - } - - /// Simplifies this [`Expr`]`s as much as possible, evaluating - /// constants and applying algebraic simplifications - /// - /// # Example: - /// `b > 2 AND b > 2` - /// can be written to - /// `b > 2` - /// - /// ``` - /// use datafusion::logical_plan::*; - /// use datafusion::error::Result; - /// use datafusion::execution::context::ExecutionProps; - /// - /// /// Simple implementation that provides `Simplifier` the information it needs - /// #[derive(Default)] - /// struct Info { - /// execution_props: ExecutionProps, - /// }; - /// - /// impl SimplifyInfo for Info { - /// fn is_boolean_type(&self, expr: &Expr) -> Result { - /// Ok(false) - /// } - /// fn nullable(&self, expr: &Expr) -> Result { - /// Ok(true) - /// } - /// fn execution_props(&self) -> &ExecutionProps { - /// &self.execution_props - /// } - /// } - /// - /// // b < 2 - /// let b_lt_2 = col("b").gt(lit(2)); - /// - /// // (b < 2) OR (b < 2) - /// let expr = b_lt_2.clone().or(b_lt_2.clone()); - /// - /// // (b < 2) OR (b < 2) --> (b < 2) - /// let expr = expr.simplify(&Info::default()).unwrap(); - /// assert_eq!(expr, b_lt_2); - /// ``` - pub fn simplify(self, info: &S) -> Result { - let mut rewriter = Simplifier::new(info); - let mut const_evaluator = ConstEvaluator::new(info.execution_props()); - - // TODO iterate until no changes are made during rewrite - // (evaluating constants can enable new simplifications and - // simplifications can enable new constant evaluation) - // https://github.com/apache/arrow-datafusion/issues/1160 - self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter) - } -} - -impl Not for Expr { - type Output = Self; - - fn not(self) -> Self::Output { - Expr::Not(Box::new(self)) - } -} - -impl std::fmt::Display for Expr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Expr::BinaryExpr { - ref left, - ref right, - ref op, - } => write!(f, "{} {} {}", left, op, right), - Expr::AggregateFunction { - /// Name of the function - ref fun, - /// List of expressions to feed to the functions as arguments - ref args, - /// Whether this is a DISTINCT aggregation or not - ref distinct, - } => fmt_function(f, &fun.to_string(), *distinct, args, true), - Expr::ScalarFunction { - /// Name of the function - ref fun, - /// List of expressions to feed to the functions as arguments - ref args, - } => fmt_function(f, &fun.to_string(), false, args, true), - _ => write!(f, "{:?}", self), - } - } -} - -#[allow(clippy::boxed_local)] -fn rewrite_boxed(boxed_expr: Box, rewriter: &mut R) -> Result> -where - R: ExprRewriter, -{ - // TODO: It might be possible to avoid an allocation (the - // Box::new) below by reusing the box. 
- let expr: Expr = *boxed_expr; - let rewritten_expr = expr.rewrite(rewriter)?; - Ok(Box::new(rewritten_expr)) -} - -fn rewrite_option_box( - option_box: Option>, - rewriter: &mut R, -) -> Result>> -where - R: ExprRewriter, -{ - option_box - .map(|expr| rewrite_boxed(expr, rewriter)) - .transpose() -} - -/// rewrite a `Vec` of `Expr`s with the rewriter -fn rewrite_vec(v: Vec, rewriter: &mut R) -> Result> -where - R: ExprRewriter, -{ - v.into_iter().map(|expr| expr.rewrite(rewriter)).collect() -} - -/// Controls how the visitor recursion should proceed. -pub enum Recursion { - /// Attempt to visit all the children, recursively, of this expression. - Continue(V), - /// Do not visit the children of this expression, though the walk - /// of parents of this expression will not be affected - Stop(V), -} - -/// Encode the traversal of an expression tree. When passed to -/// `Expr::accept`, `ExpressionVisitor::visit` is invoked -/// recursively on all nodes of an expression tree. See the comments -/// on `Expr::accept` for details on its use -pub trait ExpressionVisitor: Sized { - /// Invoked before any children of `expr` are visisted. - fn pre_visit(self, expr: &Expr) -> Result>; - - /// Invoked after all children of `expr` are visited. Default - /// implementation does nothing. - fn post_visit(self, _expr: &Expr) -> Result { - Ok(self) - } -} - -/// Controls how the [ExprRewriter] recursion should proceed. -pub enum RewriteRecursion { - /// Continue rewrite / visit this expression. - Continue, - /// Call [mutate()] immediately and return. - Mutate, - /// Do not rewrite / visit the children of this expression. - Stop, - /// Keep recursive but skip mutate on this expression - Skip, -} - -/// Trait for potentially recursively rewriting an [`Expr`] expression -/// tree. When passed to `Expr::rewrite`, `ExpressionVisitor::mutate` is -/// invoked recursively on all nodes of an expression tree. See the -/// comments on `Expr::rewrite` for details on its use -pub trait ExprRewriter: Sized { - /// Invoked before any children of `expr` are rewritten / - /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } - - /// Invoked after all children of `expr` have been mutated and - /// returns a potentially modified expr. - fn mutate(&mut self, expr: Expr) -> Result; -} - -/// The information necessary to apply algebraic simplification to an -/// [Expr]. 
See [SimplifyContext] for one implementation
-pub trait SimplifyInfo {
- /// returns true if this Expr has boolean type
- fn is_boolean_type(&self, expr: &Expr) -> Result<bool>;
-
- /// returns true if this expr is nullable (could possibly be NULL)
- fn nullable(&self, expr: &Expr) -> Result<bool>;
-
- /// Returns details needed for partial expression evaluation
- fn execution_props(&self) -> &ExecutionProps;
-}
-
 /// Helper struct for building [Expr::Case]
 pub struct CaseBuilder {
 expr: Option<Box<Expr>>,
@@ -1251,15 +129,6 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder {
 }
 }
 
-/// return a new expression `l <op> r`
-pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr {
- Expr::BinaryExpr {
- left: Box::new(l),
- op,
- right: Box::new(r),
- }
-}
-
 /// return a new expression with a logical AND
 pub fn and(left: Expr, right: Expr) -> Expr {
 Expr::BinaryExpr {
@@ -1292,11 +161,6 @@ pub fn or(left: Expr, right: Expr) -> Expr {
 }
 }
 
-/// Create a column expression based on a qualified or unqualified column name
-pub fn col(ident: &str) -> Expr {
- Expr::Column(ident.into())
-}
-
 /// Convert an expression into Column expression if it's already provided as input plan.
 ///
 /// For example, it rewrites:
@@ -1329,183 +193,6 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
 }
 }
 
-/// Recursively replace all Column expressions in a given expression tree with Column expressions
-/// provided by the hash map argument.
-pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
- struct ColumnReplacer<'a> {
- replace_map: &'a HashMap<&'a Column, &'a Column>,
- }
-
- impl<'a> ExprRewriter for ColumnReplacer<'a> {
- fn mutate(&mut self, expr: Expr) -> Result<Expr> {
- if let Expr::Column(c) = &expr {
- match self.replace_map.get(c) {
- Some(new_c) => Ok(Expr::Column((*new_c).to_owned())),
- None => Ok(expr),
- }
- } else {
- Ok(expr)
- }
- }
- }
-
- e.rewrite(&mut ColumnReplacer { replace_map })
-}
-
-/// Recursively call [`Column::normalize`] on all Column expressions
-/// in the `expr` expression tree.
-pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
- normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?)
-}
-
-/// Recursively call [`Column::normalize`] on all Column expressions
-/// in the `expr` expression tree.
-fn normalize_col_with_schemas( - expr: Expr, - schemas: &[&Arc], - using_columns: &[HashSet], -) -> Result { - struct ColumnNormalizer<'a> { - schemas: &'a [&'a Arc], - using_columns: &'a [HashSet], - } - - impl<'a> ExprRewriter for ColumnNormalizer<'a> { - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(c) = expr { - Ok(Expr::Column(c.normalize_with_schemas( - self.schemas, - self.using_columns, - )?)) - } else { - Ok(expr) - } - } - } - - expr.rewrite(&mut ColumnNormalizer { - schemas, - using_columns, - }) -} - -/// Recursively normalize all Column expressions in a list of expression trees -pub fn normalize_cols( - exprs: impl IntoIterator>, - plan: &LogicalPlan, -) -> Result> { - exprs - .into_iter() - .map(|e| normalize_col(e.into(), plan)) - .collect() -} - -/// Rewrite sort on aggregate expressions to sort on the column of aggregate output -/// For example, `max(x)` is written to `col("MAX(x)")` -pub fn rewrite_sort_cols_by_aggs( - exprs: impl IntoIterator>, - plan: &LogicalPlan, -) -> Result> { - exprs - .into_iter() - .map(|e| { - let expr = e.into(); - match expr { - Expr::Sort { - expr, - asc, - nulls_first, - } => { - let sort = Expr::Sort { - expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), - asc, - nulls_first, - }; - Ok(sort) - } - expr => Ok(expr), - } - }) - .collect() -} - -fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { - match plan { - LogicalPlan::Aggregate(Aggregate { - input, aggr_expr, .. - }) => { - struct Rewriter<'a> { - plan: &'a LogicalPlan, - input: &'a LogicalPlan, - aggr_expr: &'a Vec, - } - - impl<'a> ExprRewriter for Rewriter<'a> { - fn mutate(&mut self, expr: Expr) -> Result { - let normalized_expr = normalize_col(expr.clone(), self.plan); - if normalized_expr.is_err() { - // The expr is not based on Aggregate plan output. Skip it. - return Ok(expr); - } - let normalized_expr = normalized_expr.unwrap(); - if let Some(found_agg) = - self.aggr_expr.iter().find(|a| (**a) == normalized_expr) - { - let agg = normalize_col(found_agg.clone(), self.plan)?; - let col = Expr::Column( - agg.to_field(self.input.schema()) - .map(|f| f.qualified_column())?, - ); - Ok(col) - } else { - Ok(expr) - } - } - } - - expr.rewrite(&mut Rewriter { - plan, - input, - aggr_expr, - }) - } - LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), - _ => Ok(expr), - } -} - -/// Recursively 'unnormalize' (remove all qualifiers) from an -/// expression tree. -/// -/// For example, if there were expressions like `foo.bar` this would -/// rewrite it to just `bar`. -pub fn unnormalize_col(expr: Expr) -> Expr { - struct RemoveQualifier {} - - impl ExprRewriter for RemoveQualifier { - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(col) = expr { - //let Column { relation: _, name } = col; - Ok(Expr::Column(Column { - relation: None, - name: col.name, - })) - } else { - Ok(expr) - } - } - } - - expr.rewrite(&mut RemoveQualifier {}) - .expect("Unnormalize is infallable") -} - -/// Recursively un-normalize all Column expressions in a list of expression trees -#[inline] -pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { - exprs.into_iter().map(unnormalize_col).collect() -} - /// Recursively un-alias an expressions #[inline] pub fn unalias(expr: Expr) -> Expr { @@ -1578,102 +265,6 @@ pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { } } -/// Trait for converting a type to a [`Literal`] literal expression. 
-pub trait Literal { - /// convert the value to a Literal expression - fn lit(&self) -> Expr; -} - -/// Trait for converting a type to a literal timestamp -pub trait TimestampLiteral { - fn lit_timestamp_nano(&self) -> Expr; -} - -impl Literal for &str { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) - } -} - -impl Literal for String { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) - } -} - -impl Literal for Vec { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) - } -} - -impl Literal for &[u8] { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) - } -} - -impl Literal for ScalarValue { - fn lit(&self) -> Expr { - Expr::Literal(self.clone()) - } -} - -macro_rules! make_literal { - ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { - #[doc = $DOC] - impl Literal for $TYPE { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) - } - } - }; -} - -macro_rules! make_timestamp_literal { - ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { - #[doc = $DOC] - impl TimestampLiteral for $TYPE { - fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond( - Some((self.clone()).into()), - None, - )) - } - } - }; -} - -make_literal!(bool, Boolean, "literal expression containing a bool"); -make_literal!(f32, Float32, "literal expression containing an f32"); -make_literal!(f64, Float64, "literal expression containing an f64"); -make_literal!(i8, Int8, "literal expression containing an i8"); -make_literal!(i16, Int16, "literal expression containing an i16"); -make_literal!(i32, Int32, "literal expression containing an i32"); -make_literal!(i64, Int64, "literal expression containing an i64"); -make_literal!(u8, UInt8, "literal expression containing a u8"); -make_literal!(u16, UInt16, "literal expression containing a u16"); -make_literal!(u32, UInt32, "literal expression containing a u32"); -make_literal!(u64, UInt64, "literal expression containing a u64"); - -make_timestamp_literal!(i8, Int8, "literal expression containing an i8"); -make_timestamp_literal!(i16, Int16, "literal expression containing an i16"); -make_timestamp_literal!(i32, Int32, "literal expression containing an i32"); -make_timestamp_literal!(i64, Int64, "literal expression containing an i64"); -make_timestamp_literal!(u8, UInt8, "literal expression containing a u8"); -make_timestamp_literal!(u16, UInt16, "literal expression containing a u16"); -make_timestamp_literal!(u32, UInt32, "literal expression containing a u32"); - -/// Create a literal expression -pub fn lit(n: T) -> Expr { - n.lit() -} - -/// Create a literal timestamp expression -pub fn lit_timestamp_nano(n: T) -> Expr { - n.lit_timestamp_nano() -} - /// Concatenates the text representations of all the arguments. NULL arguments are ignored. 
pub fn concat(args: &[Expr]) -> Expr { Expr::ScalarFunction { @@ -1878,311 +469,6 @@ pub fn create_udaf( ) } -fn fmt_function( - f: &mut fmt::Formatter, - fun: &str, - distinct: bool, - args: &[Expr], - display: bool, -) -> fmt::Result { - let args: Vec = match display { - true => args.iter().map(|arg| format!("{}", arg)).collect(), - false => args.iter().map(|arg| format!("{:?}", arg)).collect(), - }; - - // let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) -} - -impl fmt::Debug for Expr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias), - Expr::Column(c) => write!(f, "{}", c), - Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")), - Expr::Literal(v) => write!(f, "{:?}", v), - Expr::Case { - expr, - when_then_expr, - else_expr, - .. - } => { - write!(f, "CASE ")?; - if let Some(e) = expr { - write!(f, "{:?} ", e)?; - } - for (w, t) in when_then_expr { - write!(f, "WHEN {:?} THEN {:?} ", w, t)?; - } - if let Some(e) = else_expr { - write!(f, "ELSE {:?} ", e)?; - } - write!(f, "END") - } - Expr::Cast { expr, data_type } => { - write!(f, "CAST({:?} AS {:?})", expr, data_type) - } - Expr::TryCast { expr, data_type } => { - write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type) - } - Expr::Not(expr) => write!(f, "NOT {:?}", expr), - Expr::Negative(expr) => write!(f, "(- {:?})", expr), - Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr), - Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr), - Expr::BinaryExpr { left, op, right } => { - write!(f, "{:?} {} {:?}", left, op, right) - } - Expr::Sort { - expr, - asc, - nulls_first, - } => { - if *asc { - write!(f, "{:?} ASC", expr)?; - } else { - write!(f, "{:?} DESC", expr)?; - } - if *nulls_first { - write!(f, " NULLS FIRST") - } else { - write!(f, " NULLS LAST") - } - } - Expr::ScalarFunction { fun, args, .. } => { - fmt_function(f, &fun.to_string(), false, args, false) - } - Expr::ScalarUDF { fun, ref args, .. } => { - fmt_function(f, &fun.name, false, args, false) - } - Expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - } => { - fmt_function(f, &fun.to_string(), false, args, false)?; - if !partition_by.is_empty() { - write!(f, " PARTITION BY {:?}", partition_by)?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY {:?}", order_by)?; - } - if let Some(window_frame) = window_frame { - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, - window_frame.start_bound, - window_frame.end_bound - )?; - } - Ok(()) - } - Expr::AggregateFunction { - fun, - distinct, - ref args, - .. - } => fmt_function(f, &fun.to_string(), *distinct, args, true), - Expr::AggregateUDF { fun, ref args, .. 
} => { - fmt_function(f, &fun.name, false, args, false) - } - Expr::Between { - expr, - negated, - low, - high, - } => { - if *negated { - write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high) - } else { - write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high) - } - } - Expr::InList { - expr, - list, - negated, - } => { - if *negated { - write!(f, "{:?} NOT IN ({:?})", expr, list) - } else { - write!(f, "{:?} IN ({:?})", expr, list) - } - } - Expr::Wildcard => write!(f, "*"), - Expr::GetIndexedField { ref expr, key } => { - write!(f, "({:?})[{}]", expr, key) - } - } - } -} - -fn create_function_name( - fun: &str, - distinct: bool, - args: &[Expr], - input_schema: &DFSchema, -) -> Result { - let names: Vec = args - .iter() - .map(|e| create_name(e, input_schema)) - .collect::>()?; - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) -} - -/// Returns a readable name of an expression based on the input schema. -/// This function recursively transverses the expression for names such as "CAST(a > 2)". -fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { - match e { - Expr::Alias(_, name) => Ok(name.clone()), - Expr::Column(c) => Ok(c.flat_name()), - Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")), - Expr::Literal(value) => Ok(format!("{:?}", value)), - Expr::BinaryExpr { left, op, right } => { - let left = create_name(left, input_schema)?; - let right = create_name(right, input_schema)?; - Ok(format!("{} {} {}", left, op, right)) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let mut name = "CASE ".to_string(); - if let Some(e) = expr { - let e = create_name(e, input_schema)?; - name += &format!("{} ", e); - } - for (w, t) in when_then_expr { - let when = create_name(w, input_schema)?; - let then = create_name(t, input_schema)?; - name += &format!("WHEN {} THEN {} ", when, then); - } - if let Some(e) = else_expr { - let e = create_name(e, input_schema)?; - name += &format!("ELSE {} ", e); - } - name += "END"; - Ok(name) - } - Expr::Cast { expr, data_type } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("CAST({} AS {:?})", expr, data_type)) - } - Expr::TryCast { expr, data_type } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) - } - Expr::Not(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("NOT {}", expr)) - } - Expr::Negative(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("(- {})", expr)) - } - Expr::IsNull(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{} IS NULL", expr)) - } - Expr::IsNotNull(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{} IS NOT NULL", expr)) - } - Expr::GetIndexedField { expr, key } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{}[{}]", expr, key)) - } - Expr::ScalarFunction { fun, args, .. } => { - create_function_name(&fun.to_string(), false, args, input_schema) - } - Expr::ScalarUDF { fun, args, .. 
} => { - create_function_name(&fun.name, false, args, input_schema) - } - Expr::WindowFunction { - fun, - args, - window_frame, - partition_by, - order_by, - } => { - let mut parts: Vec = vec![create_function_name( - &fun.to_string(), - false, - args, - input_schema, - )?]; - if !partition_by.is_empty() { - parts.push(format!("PARTITION BY {:?}", partition_by)); - } - if !order_by.is_empty() { - parts.push(format!("ORDER BY {:?}", order_by)); - } - if let Some(window_frame) = window_frame { - parts.push(format!("{}", window_frame)); - } - Ok(parts.join(" ")) - } - Expr::AggregateFunction { - fun, - distinct, - args, - .. - } => create_function_name(&fun.to_string(), *distinct, args, input_schema), - Expr::AggregateUDF { fun, args } => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e, input_schema)?); - } - Ok(format!("{}({})", fun.name, names.join(","))) - } - Expr::InList { - expr, - list, - negated, - } => { - let expr = create_name(expr, input_schema)?; - let list = list.iter().map(|expr| create_name(expr, input_schema)); - if *negated { - Ok(format!("{} NOT IN ({:?})", expr, list)) - } else { - Ok(format!("{} IN ({:?})", expr, list)) - } - } - Expr::Between { - expr, - negated, - low, - high, - } => { - let expr = create_name(expr, input_schema)?; - let low = create_name(low, input_schema)?; - let high = create_name(high, input_schema)?; - if *negated { - Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high)) - } else { - Ok(format!("{} BETWEEN {} AND {}", expr, low, high)) - } - } - Expr::Sort { .. } => Err(DataFusionError::Internal( - "Create name does not support sort expression".to_string(), - )), - Expr::Wildcard => Err(DataFusionError::Internal( - "Create name does not support wildcard".to_string(), - )), - } -} - /// Create field meta-data from an expression, for use in a result set schema pub fn exprlist_to_fields<'a>( expr: impl IntoIterator, @@ -2191,10 +477,25 @@ pub fn exprlist_to_fields<'a>( expr.into_iter().map(|e| e.to_field(input_schema)).collect() } +/// Calls a named built in function +/// ``` +/// use datafusion::logical_plan::*; +/// +/// // create the expression sin(x) < 0.2 +/// let expr = call_fn("sin", vec![col("x")]).unwrap().lt(lit(0.2)); +/// ``` +pub fn call_fn(name: impl AsRef, args: Vec) -> Result { + match name.as_ref().parse::() { + Ok(fun) => Ok(Expr::ScalarFunction { fun, args }), + Err(e) => Err(e), + } +} + #[cfg(test)] mod tests { use super::super::{col, lit, when}; use super::*; + use datafusion_expr::expr_fn::binary_expr; #[test] fn case_when_same_literal_then_types() -> Result<()> { @@ -2212,40 +513,6 @@ mod tests { assert!(maybe_expr.is_err()); } - #[test] - fn test_lit_timestamp_nano() { - let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32 - let expected = - col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10), None))); - assert_eq!(expr, expected); - - let i: i64 = 10; - let expr = col("time").eq(lit_timestamp_nano(i)); - assert_eq!(expr, expected); - - let i: u32 = 10; - let expr = col("time").eq(lit_timestamp_nano(i)); - assert_eq!(expr, expected); - } - - #[test] - fn rewriter_visit() { - let mut rewriter = RecordingRewriter::default(); - col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); - - assert_eq!( - rewriter.v, - vec![ - "Previsited #state = Utf8(\"CO\")", - "Previsited #state", - "Mutated #state", - "Previsited Utf8(\"CO\")", - "Mutated Utf8(\"CO\")", - "Mutated #state = Utf8(\"CO\")" - ] - ) - } - #[test] fn filter_is_null_and_is_not_null() 
{ let col_null = col("col1"); @@ -2257,128 +524,6 @@ mod tests { ); } - #[derive(Default)] - struct RecordingRewriter { - v: Vec, - } - impl ExprRewriter for RecordingRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - self.v.push(format!("Mutated {:?}", expr)); - Ok(expr) - } - - fn pre_visit(&mut self, expr: &Expr) -> Result { - self.v.push(format!("Previsited {:?}", expr)); - Ok(RewriteRecursion::Continue) - } - } - - #[test] - fn rewriter_rewrite() { - let mut rewriter = FooBarRewriter {}; - - // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap(); - assert_eq!(rewritten, col("state").eq(lit("bar"))); - - // doesn't wrewrite - let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap(); - assert_eq!(rewritten, col("state").eq(lit("baz"))); - } - - /// rewrites all "foo" string literals to "bar" - struct FooBarRewriter {} - impl ExprRewriter for FooBarRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - match expr { - Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { - let utf8_val = if utf8_val == "foo" { - "bar".to_string() - } else { - utf8_val - }; - Ok(lit(utf8_val)) - } - // otherwise, return the expression unchanged - expr => Ok(expr), - } - } - } - - #[test] - fn normalize_cols() { - let expr = col("a") + col("b") + col("c"); - - // Schemas with some matching and some non matching cols - let schema_a = - DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")]) - .unwrap(); - let schema_c = - DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")]) - .unwrap(); - let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); - // non matching - let schema_f = - DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")]) - .unwrap(); - let schemas = vec![schema_c, schema_f, schema_b, schema_a] - .into_iter() - .map(Arc::new) - .collect::>(); - let schemas = schemas.iter().collect::>(); - - let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); - assert_eq!( - normalized_expr, - col("tableA.a") + col("tableB.b") + col("tableC.c") - ); - } - - #[test] - fn normalize_cols_priority() { - let expr = col("a") + col("b"); - // Schemas with multiple matches for column a, first takes priority - let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); - let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); - let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap(); - let schemas = vec![schema_a2, schema_b, schema_a] - .into_iter() - .map(Arc::new) - .collect::>(); - let schemas = schemas.iter().collect::>(); - - let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); - assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b")); - } - - #[test] - fn normalize_cols_non_exist() { - // test normalizing columns when the name doesn't exist - let expr = col("a") + col("b"); - let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); - let schemas = vec![schema_a].into_iter().map(Arc::new).collect::>(); - let schemas = schemas.iter().collect::>(); - - let error = normalize_col_with_schemas(expr, &schemas, &[]) - .unwrap_err() - .to_string(); - assert_eq!( - error, - "Error during planning: Column #b not found in provided schemas" - ); - } - - #[test] - fn unnormalize_cols() { - let expr = col("tableA.a") + col("tableB.b"); - let unnormalized_expr = unnormalize_col(expr); - assert_eq!(unnormalized_expr, col("a") + col("b")); - } - - fn 
make_field(relation: &str, column: &str) -> DFField {
- DFField::new(Some(relation), column, DataType::Int8, false)
- }
-
 #[test]
 fn test_not() {
 assert_eq!(lit(1).not(), !lit(1));
@@ -2559,4 +704,57 @@ mod tests {
 combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
 assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
 }
+
+ #[test]
+ fn expr_schema_nullability() {
+ let expr = col("foo").eq(lit(1));
+ assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
+ assert!(expr
+ .nullable(&MockExprSchema::new().with_nullable(true))
+ .unwrap());
+ }
+
+ #[test]
+ fn expr_schema_data_type() {
+ let expr = col("foo");
+ assert_eq!(
+ DataType::Utf8,
+ expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
+ .unwrap()
+ );
+ }
+
+ struct MockExprSchema {
+ nullable: bool,
+ data_type: DataType,
+ }
+
+ impl MockExprSchema {
+ fn new() -> Self {
+ Self {
+ nullable: false,
+ data_type: DataType::Null,
+ }
+ }
+
+ fn with_nullable(mut self, nullable: bool) -> Self {
+ self.nullable = nullable;
+ self
+ }
+
+ fn with_data_type(mut self, data_type: DataType) -> Self {
+ self.data_type = data_type;
+ self
+ }
+ }
+
+ impl ExprSchema for MockExprSchema {
+ fn nullable(&self, _col: &Column) -> Result<bool> {
+ Ok(self.nullable)
+ }
+
+ fn data_type(&self, _col: &Column) -> Result<&DataType> {
+ Ok(&self.data_type)
+ }
+ }
 }
diff --git a/datafusion/src/logical_plan/expr_rewriter.rs b/datafusion/src/logical_plan/expr_rewriter.rs
new file mode 100644
index 000000000000..5062d5fce7ad
--- /dev/null
+++ b/datafusion/src/logical_plan/expr_rewriter.rs
@@ -0,0 +1,592 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Expression rewriter
+
+use super::Expr;
+use crate::logical_plan::plan::Aggregate;
+use crate::logical_plan::DFSchema;
+use crate::logical_plan::ExprSchemable;
+use crate::logical_plan::LogicalPlan;
+use datafusion_common::Column;
+use datafusion_common::Result;
+use std::collections::HashMap;
+use std::collections::HashSet;
+use std::sync::Arc;
+
+/// Controls how the [ExprRewriter] recursion should proceed.
+pub enum RewriteRecursion {
+ /// Continue rewrite / visit this expression.
+ Continue,
+ /// Call [mutate()] immediately and return.
+ Mutate,
+ /// Do not rewrite / visit the children of this expression.
+ Stop,
+ /// Keep recursing into children, but skip calling `mutate` on this expression
+ Skip,
+}
+
+/// Trait for potentially recursively rewriting an [`Expr`] expression
+/// tree. When passed to `Expr::rewrite`, `ExprRewriter::mutate` is
+/// invoked recursively on all nodes of an expression tree. See the
+/// comments on `Expr::rewrite` for details on its use
+pub trait ExprRewriter<E: ExprRewritable = Expr>: Sized {
+ /// Invoked before any children of `expr` are rewritten /
+ /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)`
+ fn pre_visit(&mut self, _expr: &E) -> Result<RewriteRecursion> {
+ Ok(RewriteRecursion::Continue)
+ }
+
+ /// Invoked after all children of `expr` have been mutated and
+ /// returns a potentially modified expr.
+ fn mutate(&mut self, expr: E) -> Result<E>;
+}
+
+/// a trait for marking types that are rewritable by [ExprRewriter]
+pub trait ExprRewritable: Sized {
+ /// rewrite the expression tree using the given [ExprRewriter]
+ fn rewrite<R: ExprRewriter<Self>>(self, rewriter: &mut R) -> Result<Self>;
+}
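+
+// A minimal sketch of a custom rewriter (the column name "two" is
+// illustrative, not part of the API): it inlines a column named "two" as the
+// literal 2. `mutate` is the only method a rewriter must implement, since
+// `pre_visit` defaults to `RewriteRecursion::Continue`.
+#[cfg(test)]
+mod rewriter_sketch {
+ use super::*;
+ use datafusion_common::ScalarValue;
+
+ struct InlineTwo;
+
+ impl ExprRewriter for InlineTwo {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ Ok(match expr {
+ // replace the assumed column "two" with a literal
+ Expr::Column(ref c) if c.name == "two" => {
+ Expr::Literal(ScalarValue::Int32(Some(2)))
+ }
+ // leave every other node untouched
+ other => other,
+ })
+ }
+ }
+}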
+
+impl ExprRewritable for Expr {
+ /// Performs a depth first walk of an expression and its children
+ /// to rewrite an expression, consuming `self` producing a new
+ /// [`Expr`].
+ ///
+ /// Implements a modified version of the [visitor
+ /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
+ /// separate algorithms from the structure of the `Expr` tree and
+ /// make it easier to write new, efficient expression
+ /// transformation algorithms.
+ ///
+ /// For an expression tree such as
+ /// ```text
+ /// BinaryExpr (GT)
+ /// left: Column("foo")
+ /// right: Column("bar")
+ /// ```
+ ///
+ /// The nodes are visited using the following order
+ /// ```text
+ /// pre_visit(BinaryExpr(GT))
+ /// pre_visit(Column("foo"))
+ /// mutate(Column("foo"))
+ /// pre_visit(Column("bar"))
+ /// mutate(Column("bar"))
+ /// mutate(BinaryExpr(GT))
+ /// ```
+ ///
+ /// If an Err result is returned, recursion is stopped immediately
+ ///
+ /// If `RewriteRecursion::Stop` is returned on a call to pre_visit, no
+ /// children of that expression are visited, nor is mutate
+ /// called on that expression
+ ///
+ fn rewrite<R>(self, rewriter: &mut R) -> Result<Self>
+ where
+ R: ExprRewriter<Self>,
+ {
+ let need_mutate = match rewriter.pre_visit(&self)? {
+ RewriteRecursion::Mutate => return rewriter.mutate(self),
+ RewriteRecursion::Stop => return Ok(self),
+ RewriteRecursion::Continue => true,
+ RewriteRecursion::Skip => false,
+ };
+
+ // recurse into all sub expressions (and cover all expression types)
+ let expr = match self {
+ Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name),
+ Expr::Column(_) => self.clone(),
+ Expr::ScalarVariable(names) => Expr::ScalarVariable(names),
+ Expr::Literal(value) => Expr::Literal(value),
+ Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr {
+ left: rewrite_boxed(left, rewriter)?,
+ op,
+ right: rewrite_boxed(right, rewriter)?,
+ },
+ Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?),
+ Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?),
+ Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?),
+ Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?),
+ Expr::Between {
+ expr,
+ low,
+ high,
+ negated,
+ } => Expr::Between {
+ expr: rewrite_boxed(expr, rewriter)?,
+ low: rewrite_boxed(low, rewriter)?,
+ high: rewrite_boxed(high, rewriter)?,
+ negated,
+ },
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ } => {
+ let expr = rewrite_option_box(expr, rewriter)?;
+ let when_then_expr = when_then_expr
+ .into_iter()
+ .map(|(when, then)| {
+ Ok((
+ rewrite_boxed(when, rewriter)?,
+ rewrite_boxed(then, rewriter)?,
+ ))
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let else_expr = rewrite_option_box(else_expr, rewriter)?;
+
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ }
+ }
+ Expr::Cast { expr, data_type } => Expr::Cast {
+ expr: rewrite_boxed(expr, rewriter)?,
+ data_type,
+ },
+ Expr::TryCast { expr, data_type } => Expr::TryCast {
+ expr: rewrite_boxed(expr, rewriter)?,
+ data_type,
+ },
+ Expr::Sort {
+ expr,
+ asc,
+ nulls_first,
+ } => Expr::Sort {
+ expr: rewrite_boxed(expr, rewriter)?,
+ asc,
+ nulls_first,
+ },
+ Expr::ScalarFunction { args, fun } => Expr::ScalarFunction {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ },
+ Expr::ScalarUDF { args, fun } => Expr::ScalarUDF {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ },
+ Expr::WindowFunction {
+ args,
+ fun,
+ partition_by,
+ order_by,
+ window_frame,
+ } => Expr::WindowFunction {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ partition_by: rewrite_vec(partition_by, rewriter)?,
+ order_by: rewrite_vec(order_by, rewriter)?,
+ window_frame,
+ },
+ Expr::AggregateFunction {
+ args,
+ fun,
+ distinct,
+ } => Expr::AggregateFunction {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ distinct,
+ },
+ Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ },
+ Expr::InList {
+ expr,
+ list,
+ negated,
+ } => Expr::InList {
+ expr: rewrite_boxed(expr, rewriter)?,
+ list: rewrite_vec(list, rewriter)?,
+ negated,
+ },
+ Expr::Wildcard => Expr::Wildcard,
+ Expr::GetIndexedField { expr, key } => Expr::GetIndexedField {
+ expr: rewrite_boxed(expr, rewriter)?,
+ key,
+ },
+ };
+
+ // now rewrite this expression itself
+ if need_mutate {
+ rewriter.mutate(expr)
+ } else {
+ Ok(expr)
+ }
+ }
+}
+
+#[allow(clippy::boxed_local)]
+fn rewrite_boxed<R>(boxed_expr: Box<Expr>, rewriter: &mut R) -> Result<Box<Expr>>
+where
+ R: ExprRewriter,
+{
+ // TODO: It might be possible to avoid an allocation (the
+ // Box::new) below by reusing the box.
+ let expr: Expr = *boxed_expr;
+ let rewritten_expr = expr.rewrite(rewriter)?;
+ Ok(Box::new(rewritten_expr))
+}
+
+fn rewrite_option_box<R>(
+ option_box: Option<Box<Expr>>,
+ rewriter: &mut R,
+) -> Result<Option<Box<Expr>>>
+where
+ R: ExprRewriter,
+{
+ option_box
+ .map(|expr| rewrite_boxed(expr, rewriter))
+ .transpose()
+}
+
+/// rewrite a `Vec` of `Expr`s with the rewriter
+fn rewrite_vec<R>(v: Vec<Expr>, rewriter: &mut R) -> Result<Vec<Expr>>
+where
+ R: ExprRewriter,
+{
+ v.into_iter().map(|expr| expr.rewrite(rewriter)).collect()
+}
+
+/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
+/// For example, `max(x)` is written to `col("MAX(x)")`
+pub fn rewrite_sort_cols_by_aggs(
+ exprs: impl IntoIterator<Item = impl Into<Expr>>,
+ plan: &LogicalPlan,
+) -> Result<Vec<Expr>> {
+ exprs
+ .into_iter()
+ .map(|e| {
+ let expr = e.into();
+ match expr {
+ Expr::Sort {
+ expr,
+ asc,
+ nulls_first,
+ } => {
+ let sort = Expr::Sort {
+ expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?),
+ asc,
+ nulls_first,
+ };
+ Ok(sort)
+ }
+ expr => Ok(expr),
+ }
+ })
+ .collect()
+}
+
+fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
+ match plan {
+ LogicalPlan::Aggregate(Aggregate {
+ input, aggr_expr, ..
+ }) => {
+ struct Rewriter<'a> {
+ plan: &'a LogicalPlan,
+ input: &'a LogicalPlan,
+ aggr_expr: &'a Vec<Expr>,
+ }
+
+ impl<'a> ExprRewriter for Rewriter<'a> {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ let normalized_expr = normalize_col(expr.clone(), self.plan);
+ if normalized_expr.is_err() {
+ // The expr is not based on Aggregate plan output. Skip it.
+ return Ok(expr);
+ }
+ let normalized_expr = normalized_expr.unwrap();
+ if let Some(found_agg) =
+ self.aggr_expr.iter().find(|a| (**a) == normalized_expr)
+ {
+ let agg = normalize_col(found_agg.clone(), self.plan)?;
+ let col = Expr::Column(
+ agg.to_field(self.input.schema())
+ .map(|f| f.qualified_column())?,
+ );
+ Ok(col)
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ expr.rewrite(&mut Rewriter {
+ plan,
+ input,
+ aggr_expr,
+ })
+ }
+ LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
+ _ => Ok(expr),
+ }
+}
+
+/// Recursively call [`Column::normalize`] on all Column expressions
+/// in the `expr` expression tree.
+pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
+ normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?)
+}
+
+/// Recursively call [`Column::normalize`] on all Column expressions
+/// in the `expr` expression tree.
+fn normalize_col_with_schemas(
+ expr: Expr,
+ schemas: &[&Arc<DFSchema>],
+ using_columns: &[HashSet<Column>],
+) -> Result<Expr> {
+ struct ColumnNormalizer<'a> {
+ schemas: &'a [&'a Arc<DFSchema>],
+ using_columns: &'a [HashSet<Column>],
+ }
+
+ impl<'a> ExprRewriter for ColumnNormalizer<'a> {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if let Expr::Column(c) = expr {
+ Ok(Expr::Column(c.normalize_with_schemas(
+ self.schemas,
+ self.using_columns,
+ )?))
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ expr.rewrite(&mut ColumnNormalizer {
+ schemas,
+ using_columns,
+ })
+}
+
+/// Recursively normalize all Column expressions in a list of expression trees
+pub fn normalize_cols(
+ exprs: impl IntoIterator<Item = impl Into<Expr>>,
+ plan: &LogicalPlan,
+) -> Result<Vec<Expr>> {
+ exprs
+ .into_iter()
+ .map(|e| normalize_col(e.into(), plan))
+ .collect()
+}
+
+/// Recursively replace all Column expressions in a given expression tree with Column expressions
+/// provided by the hash map argument.
+pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
+ struct ColumnReplacer<'a> {
+ replace_map: &'a HashMap<&'a Column, &'a Column>,
+ }
+
+ impl<'a> ExprRewriter for ColumnReplacer<'a> {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if let Expr::Column(c) = &expr {
+ match self.replace_map.get(c) {
+ Some(new_c) => Ok(Expr::Column((*new_c).to_owned())),
+ None => Ok(expr),
+ }
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ e.rewrite(&mut ColumnReplacer { replace_map })
+}
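+
+// A minimal sketch of `replace_col` in action (the qualified names "t1.a" and
+// "t2.a" are illustrative): every column reference found in the map is swapped.
+#[cfg(test)]
+mod replace_col_sketch {
+ use super::*;
+ use crate::prelude::col;
+
+ #[test]
+ fn replaces_mapped_columns() -> Result<()> {
+ let from = Column::from_qualified_name("t1.a");
+ let to = Column::from_qualified_name("t2.a");
+ let mut map = HashMap::new();
+ map.insert(&from, &to);
+ // t1.a is rewritten to t2.a wherever it appears in the tree
+ assert_eq!(replace_col(col("t1.a"), &map)?, col("t2.a"));
+ Ok(())
+ }
+}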
+
+/// Recursively 'unnormalize' (remove all qualifiers) from an
+/// expression tree.
+///
+/// For example, if there were expressions like `foo.bar` this would
+/// rewrite it to just `bar`.
+pub fn unnormalize_col(expr: Expr) -> Expr {
+ struct RemoveQualifier {}
+
+ impl ExprRewriter for RemoveQualifier {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if let Expr::Column(col) = expr {
+ //let Column { relation: _, name } = col;
+ Ok(Expr::Column(Column {
+ relation: None,
+ name: col.name,
+ }))
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ expr.rewrite(&mut RemoveQualifier {})
+ .expect("Unnormalize is infallible")
+}
+
+/// Recursively un-normalize all Column expressions in a list of expression trees
+#[inline]
+pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
+ exprs.into_iter().map(unnormalize_col).collect()
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use crate::logical_plan::DFField;
+ use crate::prelude::{col, lit};
+ use arrow::datatypes::DataType;
+ use datafusion_common::ScalarValue;
+
+ #[derive(Default)]
+ struct RecordingRewriter {
+ v: Vec<String>,
+ }
+ impl ExprRewriter for RecordingRewriter {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ self.v.push(format!("Mutated {:?}", expr));
+ Ok(expr)
+ }
+
+ fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
+ self.v.push(format!("Previsited {:?}", expr));
+ Ok(RewriteRecursion::Continue)
+ }
+ }
+
+ #[test]
+ fn rewriter_rewrite() {
+ let mut rewriter = FooBarRewriter {};
+
+ // rewrites "foo" --> "bar"
+ let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap();
+ assert_eq!(rewritten, col("state").eq(lit("bar")));
+
+ // doesn't rewrite
+ let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap();
+ assert_eq!(rewritten, col("state").eq(lit("baz")));
+ }
+
+ /// rewrites all "foo" string literals to "bar"
+ struct FooBarRewriter {}
+ impl ExprRewriter for FooBarRewriter {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ match expr {
+ Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => {
+ let utf8_val = if utf8_val == "foo" {
+ "bar".to_string()
+ } else {
+ utf8_val
+ };
+ Ok(lit(utf8_val))
+ }
+ // otherwise, return the expression unchanged
+ expr => Ok(expr),
+ }
+ }
+ }
+
+ #[test]
+ fn normalize_cols() {
+ let expr = col("a") + col("b") + col("c");
+
+ // Schemas with some matching and some non matching cols
+ let schema_a =
+ DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")])
+ .unwrap();
+ let schema_c =
+ DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")])
+ .unwrap();
+ let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
+ // non matching
+ let schema_f =
+ DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")])
+ .unwrap();
+ let schemas = vec![schema_c, schema_f, schema_b, schema_a]
+ .into_iter()
+ .map(Arc::new)
+ .collect::<Vec<_>>();
+ let schemas = schemas.iter().collect::<Vec<_>>();
+
+ let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
+ assert_eq!(
+ normalized_expr,
+ col("tableA.a") + col("tableB.b") + col("tableC.c")
+ );
+ }
+
+ #[test]
+ fn normalize_cols_priority() {
+ let expr = col("a") + col("b");
+ // Schemas with multiple matches for column a, first takes priority
+ let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
+ let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
+ let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap();
+ let schemas = vec![schema_a2, schema_b, schema_a]
+ .into_iter()
+ .map(Arc::new)
+ .collect::<Vec<_>>();
+ let schemas = schemas.iter().collect::<Vec<_>>();
+
+ let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
+ assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b"));
+ }
+
+ #[test]
+ fn normalize_cols_non_exist() {
+ // test normalizing columns when the name doesn't exist
+ let expr = col("a") + col("b");
+ let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
+ let schemas = vec![schema_a].into_iter().map(Arc::new).collect::<Vec<_>>();
+ let schemas = schemas.iter().collect::<Vec<_>>();
+
+ let error = normalize_col_with_schemas(expr, &schemas, &[])
+ .unwrap_err()
+ .to_string();
+ assert_eq!(
+ error,
+ "Error during planning: Column #b not found in provided schemas"
+ );
+ }
+
+ #[test]
+ fn unnormalize_cols() {
+ let expr = col("tableA.a") + col("tableB.b");
+ let unnormalized_expr = unnormalize_col(expr);
+ assert_eq!(unnormalized_expr, col("a") + col("b"));
+ }
+
+ fn make_field(relation: &str, column: &str) -> DFField {
+ DFField::new(Some(relation), column, DataType::Int8, false)
+ }
+
+ #[test]
+ fn rewriter_visit() {
+ let mut rewriter = RecordingRewriter::default();
+ col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();
+
+ assert_eq!(
+ rewriter.v,
+ vec![
+ "Previsited #state = Utf8(\"CO\")",
+ "Previsited #state",
+ "Mutated #state",
+ "Previsited Utf8(\"CO\")",
+ "Mutated Utf8(\"CO\")",
+ "Mutated #state = Utf8(\"CO\")"
+ ]
+ )
+ }
+}
col("tableA2.a") + col("tableB.b")); + } + + #[test] + fn normalize_cols_non_exist() { + // test normalizing columns when the name doesn't exist + let expr = col("a") + col("b"); + let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); + let schemas = vec![schema_a].into_iter().map(Arc::new).collect::>(); + let schemas = schemas.iter().collect::>(); + + let error = normalize_col_with_schemas(expr, &schemas, &[]) + .unwrap_err() + .to_string(); + assert_eq!( + error, + "Error during planning: Column #b not found in provided schemas" + ); + } + + #[test] + fn unnormalize_cols() { + let expr = col("tableA.a") + col("tableB.b"); + let unnormalized_expr = unnormalize_col(expr); + assert_eq!(unnormalized_expr, col("a") + col("b")); + } + + fn make_field(relation: &str, column: &str) -> DFField { + DFField::new(Some(relation), column, DataType::Int8, false) + } + + #[test] + fn rewriter_visit() { + let mut rewriter = RecordingRewriter::default(); + col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); + + assert_eq!( + rewriter.v, + vec![ + "Previsited #state = Utf8(\"CO\")", + "Previsited #state", + "Mutated #state", + "Previsited Utf8(\"CO\")", + "Mutated Utf8(\"CO\")", + "Mutated #state = Utf8(\"CO\")" + ] + ) + } +} diff --git a/datafusion/src/logical_plan/expr_schema.rs b/datafusion/src/logical_plan/expr_schema.rs new file mode 100644 index 000000000000..7bad353deaa7 --- /dev/null +++ b/datafusion/src/logical_plan/expr_schema.rs @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::Expr; +use crate::field_util::get_indexed_field; +use crate::physical_plan::{ + aggregates, expressions::binary_operator_data_type, functions, window_functions, +}; +use arrow::compute::cast::can_cast_types; +use arrow::datatypes::DataType; +use datafusion_common::field_util::FieldExt; +use datafusion_common::{DFField, DFSchema, DataFusionError, ExprSchema, Result}; + +/// trait to allow expr to typable with respect to a schema +pub trait ExprSchemable { + /// given a schema, return the type of the expr + fn get_type(&self, schema: &S) -> Result; + + /// given a schema, return the nullability of the expr + fn nullable(&self, input_schema: &S) -> Result; + + /// convert to a field with respect to a schema + fn to_field(&self, input_schema: &DFSchema) -> Result; + + /// cast to a type with respect to a schema + fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result; +} + +impl ExprSchemable for Expr { + /// Returns the [arrow::datatypes::DataType] of the expression + /// based on [ExprSchema] + /// + /// Note: [DFSchema] implements [ExprSchema]. + /// + /// # Errors + /// + /// This function errors when it is not possible to compute its + /// [arrow::datatypes::DataType]. This happens when e.g. 
+
+impl ExprSchemable for Expr {
+ /// Returns the [arrow::datatypes::DataType] of the expression
+ /// based on [ExprSchema]
+ ///
+ /// Note: [DFSchema] implements [ExprSchema].
+ ///
+ /// # Errors
+ ///
+ /// This function errors when it is not possible to compute its
+ /// [arrow::datatypes::DataType]. This happens when e.g. the
+ /// expression refers to a column that does not exist in the
+ /// schema, or when the expression is incorrectly typed
+ /// (e.g. `[utf8] + [bool]`).
+ fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
+ match self {
+ Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => {
+ expr.get_type(schema)
+ }
+ Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
+ Expr::ScalarVariable(_) => Ok(DataType::Utf8),
+ Expr::Literal(l) => Ok(l.get_datatype()),
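+ // The CASE arm below assumes the planner has already coerced every
+ // WHEN/THEN branch to a common type, so the type of the first THEN
+ // branch stands in for the whole expression.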
+ Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
+ Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
+ Ok(data_type.clone())
+ }
+ Expr::ScalarUDF { fun, args } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ Ok((fun.return_type)(&data_types)?.as_ref().clone())
+ }
+ Expr::ScalarFunction { fun, args } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ functions::return_type(fun, &data_types)
+ }
+ Expr::WindowFunction { fun, args, .. } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ window_functions::return_type(fun, &data_types)
+ }
+ Expr::AggregateFunction { fun, args, .. } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ aggregates::return_type(fun, &data_types)
+ }
+ Expr::AggregateUDF { fun, args, .. } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ Ok((fun.return_type)(&data_types)?.as_ref().clone())
+ }
+ Expr::Not(_)
+ | Expr::IsNull(_)
+ | Expr::Between { .. }
+ | Expr::InList { .. }
+ | Expr::IsNotNull(_) => Ok(DataType::Boolean),
+ Expr::BinaryExpr {
+ ref left,
+ ref right,
+ ref op,
+ } => binary_operator_data_type(
+ &left.get_type(schema)?,
+ op,
+ &right.get_type(schema)?,
+ ),
+ Expr::Wildcard => Err(DataFusionError::Internal(
+ "Wildcard expressions are not valid in a logical query plan".to_owned(),
+ )),
+ Expr::GetIndexedField { ref expr, key } => {
+ let data_type = expr.get_type(schema)?;
+
+ get_indexed_field(&data_type, key).map(|x| x.data_type().clone())
+ }
+ }
+ }
+
+ /// Returns the nullability of the expression based on [ExprSchema].
+ ///
+ /// Note: [DFSchema] implements [ExprSchema].
+ ///
+ /// # Errors
+ ///
+ /// This function errors when it is not possible to compute its
+ /// nullability. This happens when the expression refers to a
+ /// column that does not exist in the schema.
+ fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
+ match self {
+ Expr::Alias(expr, _)
+ | Expr::Not(expr)
+ | Expr::Negative(expr)
+ | Expr::Sort { expr, .. }
+ | Expr::Between { expr, .. }
+ | Expr::InList { expr, .. } => expr.nullable(input_schema),
+ Expr::Column(c) => input_schema.nullable(c),
+ Expr::Literal(value) => Ok(value.is_null()),
+ Expr::Case {
+ when_then_expr,
+ else_expr,
+ ..
+ } => {
+ // this expression is nullable if any of the input expressions are nullable
+ let then_nullable = when_then_expr
+ .iter()
+ .map(|(_, t)| t.nullable(input_schema))
+ .collect::<Result<Vec<_>>>()?;
+ if then_nullable.contains(&true) {
+ Ok(true)
+ } else if let Some(e) = else_expr {
+ e.nullable(input_schema)
+ } else {
+ Ok(false)
+ }
+ }
+ Expr::Cast { expr, .. } => expr.nullable(input_schema),
+ Expr::ScalarVariable(_)
+ | Expr::TryCast { .. }
+ | Expr::ScalarFunction { .. }
+ | Expr::ScalarUDF { .. }
+ | Expr::WindowFunction { .. }
+ | Expr::AggregateFunction { .. }
+ | Expr::AggregateUDF { .. } => Ok(true),
+ Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false),
+ Expr::BinaryExpr {
+ ref left,
+ ref right,
+ ..
+ } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
+ Expr::Wildcard => Err(DataFusionError::Internal(
+ "Wildcard expressions are not valid in a logical query plan".to_owned(),
+ )),
+ Expr::GetIndexedField { ref expr, key } => {
+ let data_type = expr.get_type(input_schema)?;
+ get_indexed_field(&data_type, key).map(|x| x.is_nullable())
+ }
+ }
+ }
+
+ /// Returns a [arrow::datatypes::Field] compatible with this expression.
+ fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
+ match self {
+ Expr::Column(c) => Ok(DFField::new(
+ c.relation.as_deref(),
+ &c.name,
+ self.get_type(input_schema)?,
+ self.nullable(input_schema)?,
+ )),
+ _ => Ok(DFField::new(
+ None,
+ &self.name(input_schema)?,
+ self.get_type(input_schema)?,
+ self.nullable(input_schema)?,
+ )),
+ }
+ }
+
+ /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
+ ///
+ /// # Errors
+ ///
+ /// This function errors when it is impossible to cast the
+ /// expression to the target [arrow::datatypes::DataType].
+ fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr> {
+ // TODO(kszucs): most of the operations do not validate the type correctness
+ // like all of the binary expressions below. Perhaps Expr should track the
+ // type of the expression?
+ let this_type = self.get_type(schema)?;
+ if this_type == *cast_to_type {
+ Ok(self)
+ } else if can_cast_types(&this_type, cast_to_type) {
+ Ok(Expr::Cast {
+ expr: Box::new(self),
+ data_type: cast_to_type.clone(),
+ })
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "Cannot automatically convert {:?} to {:?}",
+ this_type, cast_to_type
+ )))
+ }
+ }
+}
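+
+// A minimal sketch of `cast_to` (the column "x" is illustrative): casting to
+// the same type is a no-op, a cast supported by arrow's cast kernel wraps the
+// expression in `Expr::Cast`, and anything else is a plan error.
+#[cfg(test)]
+mod cast_sketch {
+ use super::*;
+ use crate::prelude::col;
+
+ #[test]
+ fn casts_only_when_needed() -> Result<()> {
+ let schema =
+ DFSchema::new(vec![DFField::new(None, "x", DataType::Int32, false)])?;
+ // same type: the expression is returned unchanged
+ assert_eq!(col("x").cast_to(&DataType::Int32, &schema)?, col("x"));
+ // supported cast: the wrapped expression reports the target type
+ let cast = col("x").cast_to(&DataType::Int64, &schema)?;
+ assert_eq!(cast.get_type(&schema)?, DataType::Int64);
+ Ok(())
+ }
+}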
diff --git a/datafusion/src/logical_plan/expr_simplier.rs b/datafusion/src/logical_plan/expr_simplier.rs new file mode 100644 index 000000000000..06e58566f8a2 --- /dev/null +++ b/datafusion/src/logical_plan/expr_simplier.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression simplifier + +use super::Expr; +use super::ExprRewritable; +use crate::execution::context::ExecutionProps; +use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier}; +use datafusion_common::Result; + +/// The information necessary to apply algebraic simplification to an +/// [Expr]. See [SimplifyContext] for one implementation +pub trait SimplifyInfo { + /// Returns true if this Expr has boolean type + fn is_boolean_type(&self, expr: &Expr) -> Result<bool>; + + /// Returns true if this Expr is nullable (could possibly be NULL) + fn nullable(&self, expr: &Expr) -> Result<bool>; + + /// Returns details needed for partial expression evaluation + fn execution_props(&self) -> &ExecutionProps; +} + +/// Trait for types that can be simplified +pub trait ExprSimplifiable: Sized { + /// Simplify this trait object using the given SimplifyInfo + fn simplify<S: SimplifyInfo>(self, info: &S) -> Result<Self>; +} + +impl ExprSimplifiable for Expr { + /// Simplifies this [`Expr`] as much as possible, evaluating + /// constants and applying algebraic simplifications + /// + /// # Example: + /// `b > 2 AND b > 2` + /// can be simplified to + /// `b > 2` + /// + /// ``` + /// use datafusion::logical_plan::*; + /// use datafusion::error::Result; + /// use datafusion::execution::context::ExecutionProps; + /// + /// /// Simple implementation that provides `Simplifier` the information it needs + /// #[derive(Default)] + /// struct Info { + /// execution_props: ExecutionProps, + /// }; + /// + /// impl SimplifyInfo for Info { + /// fn is_boolean_type(&self, expr: &Expr) -> Result<bool> { + /// Ok(false) + /// } + /// fn nullable(&self, expr: &Expr) -> Result<bool> { + /// Ok(true) + /// } + /// fn execution_props(&self) -> &ExecutionProps { + /// &self.execution_props + /// } + /// } + /// + /// // b > 2 + /// let b_gt_2 = col("b").gt(lit(2)); + /// + /// // (b > 2) OR (b > 2) + /// let expr = b_gt_2.clone().or(b_gt_2.clone()); + /// + /// // (b > 2) OR (b > 2) --> (b > 2) + /// let expr = expr.simplify(&Info::default()).unwrap(); + /// assert_eq!(expr, b_gt_2); + /// ``` + fn simplify<S: SimplifyInfo>(self, info: &S) -> Result<Self> { + let mut rewriter = Simplifier::new(info); + let mut const_evaluator = ConstEvaluator::new(info.execution_props()); + + // TODO iterate until no changes are made during rewrite + // (evaluating constants can enable new simplifications and + // simplifications can enable new constant evaluation) + // https://github.com/apache/arrow-datafusion/issues/1160 + self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter) + } +}
diff --git a/datafusion/src/logical_plan/expr_visitor.rs b/datafusion/src/logical_plan/expr_visitor.rs new file mode 100644 index 000000000000..26084fb95f0b --- /dev/null +++ b/datafusion/src/logical_plan/expr_visitor.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression visitor + +use super::Expr; +use datafusion_common::Result; + +/// Controls how the visitor recursion should proceed. +pub enum Recursion<V: ExpressionVisitor> { + /// Attempt to visit all the children, recursively, of this expression. + Continue(V), + /// Do not visit the children of this expression, though the walk + /// of parents of this expression will not be affected + Stop(V), +} + +/// Encode the traversal of an expression tree. When passed to +/// `Expr::accept`, `ExpressionVisitor::pre_visit` and +/// `ExpressionVisitor::post_visit` are invoked recursively +/// on all nodes of an expression tree. See the comments +/// on `Expr::accept` for details on its use +pub trait ExpressionVisitor<E: ExprVisitable = Expr>: Sized { + /// Invoked before any children of `expr` are visited. + fn pre_visit(self, expr: &E) -> Result<Recursion<Self>> + where + Self: ExpressionVisitor<E>; + + /// Invoked after all children of `expr` are visited. Default + /// implementation does nothing. + fn post_visit(self, _expr: &E) -> Result<Self> { + Ok(self) + } +} + +/// Trait for types that can be visited by [`ExpressionVisitor`] + pub trait ExprVisitable: Sized { + /// Accept a visitor, calling `visit` on all children of this expression + fn accept<V: ExpressionVisitor<Self>>(&self, visitor: V) -> Result<V>; +} + +impl ExprVisitable for Expr { + /// Performs a depth first walk of an expression and + /// its children, calling [`ExpressionVisitor::pre_visit`] and + /// [`ExpressionVisitor::post_visit`]. + /// + /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to + /// separate expression algorithms from the structure of the + /// `Expr` tree and make it easier to add new types of expressions + /// and algorithms that walk the tree. + /// + /// For an expression tree such as + /// ```text + /// BinaryExpr (GT) + /// left: Column("foo") + /// right: Column("bar") + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(BinaryExpr(GT)) + /// pre_visit(Column("foo")) + /// post_visit(Column("foo")) + /// pre_visit(Column("bar")) + /// post_visit(Column("bar")) + /// post_visit(BinaryExpr(GT)) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If `Recursion::Stop` is returned on a call to pre_visit, no + /// children of that expression are visited, nor is post_visit + /// called on that expression + /// + fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> { + let visitor = match visitor.pre_visit(self)? { + Recursion::Continue(visitor) => visitor, + // If the recursion should stop, do not visit children + Recursion::Stop(visitor) => return Ok(visitor), + }; + + // recurse (and cover all expression types) + let visitor = match self { + Expr::Alias(expr, _) + | Expr::Not(expr) + | Expr::IsNotNull(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) + | Expr::Cast { expr, .. } + | Expr::TryCast { expr, .. } + | Expr::Sort { expr, .. } + | Expr::GetIndexedField { expr, .. } => expr.accept(visitor), + Expr::Column(_) + | Expr::ScalarVariable(_) + | Expr::Literal(_) + | Expr::Wildcard => Ok(visitor), + Expr::BinaryExpr { left, right, .. } => { + let visitor = left.accept(visitor)?; + right.accept(visitor) + } + Expr::Between { + expr, low, high, .. + } => { + let visitor = expr.accept(visitor)?; + let visitor = low.accept(visitor)?; + high.accept(visitor) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let visitor = if let Some(expr) = expr.as_ref() { + expr.accept(visitor) + } else { + Ok(visitor) + }?; + let visitor = when_then_expr.iter().try_fold( + visitor, + |visitor, (when, then)| { + let visitor = when.accept(visitor)?; + then.accept(visitor) + }, + )?; + if let Some(else_expr) = else_expr.as_ref() { + else_expr.accept(visitor) + } else { + Ok(visitor) + } + } + Expr::ScalarFunction { args, .. } + | Expr::ScalarUDF { args, ..
} + | Expr::AggregateFunction { args, .. } + | Expr::AggregateUDF { args, .. } => args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + let visitor = args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = partition_by + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = order_by + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + Ok(visitor) + } + Expr::InList { expr, list, .. } => { + let visitor = expr.accept(visitor)?; + list.iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)) + } + }?; + + visitor.post_visit(self) + } +} diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 25714514d78a..24d6723210c7 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -25,6 +25,10 @@ pub(crate) mod builder; mod dfschema; mod display; mod expr; +mod expr_rewriter; +mod expr_schema; +mod expr_simplier; +mod expr_visitor; mod extension; mod operators; pub mod plan; @@ -33,22 +37,28 @@ pub mod window_frames; pub use builder::{ build_join_schema, union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE, }; +pub use datafusion_expr::expr_fn::binary_expr; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, - avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col, + avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, - lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, - or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse, - rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, - signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, - translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, - Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion, - SimplifyInfo, + lower, lpad, ltrim, max, md5, min, now, octet_length, or, random, regexp_match, + regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, + sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, + to_hex, translate, trim, trunc, unalias, upper, when, Column, Expr, ExprSchema, + Literal, }; +pub use expr_rewriter::{ + normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs, + unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, RewriteRecursion, +}; +pub use expr_schema::ExprSchemable; +pub use expr_simplier::{ExprSimplifiable, SimplifyInfo}; +pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; pub use plan::{ diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs index 14ccab0537bd..2f129284fa71 100644 --- a/datafusion/src/logical_plan/operators.rs +++ b/datafusion/src/logical_plan/operators.rs @@ -15,128 +15,7 @@ // specific language governing permissions and limitations // under the 
License. -use std::{fmt, ops}; - -use super::{binary_expr, Expr}; - -/// Operators applied to expressions -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] -pub enum Operator { - /// Expressions are equal - Eq, - /// Expressions are not equal - NotEq, - /// Left side is smaller than right side - Lt, - /// Left side is smaller or equal to right side - LtEq, - /// Left side is greater than right side - Gt, - /// Left side is greater or equal to right side - GtEq, - /// Addition - Plus, - /// Subtraction - Minus, - /// Multiplication operator, like `*` - Multiply, - /// Division operator, like `/` - Divide, - /// Remainder operator, like `%` - Modulo, - /// Logical AND, like `&&` - And, - /// Logical OR, like `||` - Or, - /// Matches a wildcard pattern - Like, - /// Does not match a wildcard pattern - NotLike, - /// IS DISTINCT FROM - IsDistinctFrom, - /// IS NOT DISTINCT FROM - IsNotDistinctFrom, - /// Case sensitive regex match - RegexMatch, - /// Case insensitive regex match - RegexIMatch, - /// Case sensitive regex not match - RegexNotMatch, - /// Case insensitive regex not match - RegexNotIMatch, - /// Bitwise and, like `&` - BitwiseAnd, -} - -impl fmt::Display for Operator { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let display = match &self { - Operator::Eq => "=", - Operator::NotEq => "!=", - Operator::Lt => "<", - Operator::LtEq => "<=", - Operator::Gt => ">", - Operator::GtEq => ">=", - Operator::Plus => "+", - Operator::Minus => "-", - Operator::Multiply => "*", - Operator::Divide => "/", - Operator::Modulo => "%", - Operator::And => "AND", - Operator::Or => "OR", - Operator::Like => "LIKE", - Operator::NotLike => "NOT LIKE", - Operator::RegexMatch => "~", - Operator::RegexIMatch => "~*", - Operator::RegexNotMatch => "!~", - Operator::RegexNotIMatch => "!~*", - Operator::IsDistinctFrom => "IS DISTINCT FROM", - Operator::IsNotDistinctFrom => "IS NOT DISTINCT FROM", - Operator::BitwiseAnd => "&", - }; - write!(f, "{}", display) - } -} - -impl ops::Add for Expr { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - binary_expr(self, Operator::Plus, rhs) - } -} - -impl ops::Sub for Expr { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - binary_expr(self, Operator::Minus, rhs) - } -} - -impl ops::Mul for Expr { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - binary_expr(self, Operator::Multiply, rhs) - } -} - -impl ops::Div for Expr { - type Output = Self; - - fn div(self, rhs: Self) -> Self { - binary_expr(self, Operator::Divide, rhs) - } -} - -impl ops::Rem for Expr { - type Output = Self; - - fn rem(self, rhs: Self) -> Self { - binary_expr(self, Operator::Modulo, rhs) - } -} +pub use datafusion_expr::Operator; #[cfg(test)] mod tests { diff --git a/datafusion/src/logical_plan/window_frames.rs b/datafusion/src/logical_plan/window_frames.rs index 50e2ee7f8a04..519582089db4 100644 --- a/datafusion/src/logical_plan/window_frames.rs +++ b/datafusion/src/logical_plan/window_frames.rs @@ -15,365 +15,6 @@ // specific language governing permissions and limitations // under the License. -//! Window frame -//! -//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts: -//! - A frame type - either ROWS, RANGE or GROUPS, -//! - A starting frame boundary, -//! - An ending frame boundary, -//! - An EXCLUDE clause. +//! 
Window frame types, reimported from datafusion_expr -use crate::error::{DataFusionError, Result}; -use sqlparser::ast; -use std::cmp::Ordering; -use std::convert::{From, TryFrom}; -use std::fmt; -use std::hash::{Hash, Hasher}; - -/// The frame-spec determines which output rows are read by an aggregate window function. -/// -/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the -/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to -/// CURRENT ROW. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] -pub struct WindowFrame { - /// A frame type - either ROWS, RANGE or GROUPS - pub units: WindowFrameUnits, - /// A starting frame boundary - pub start_bound: WindowFrameBound, - /// An ending frame boundary - pub end_bound: WindowFrameBound, -} - -impl fmt::Display for WindowFrame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{} BETWEEN {} AND {}", - self.units, self.start_bound, self.end_bound - )?; - Ok(()) - } -} - -impl TryFrom for WindowFrame { - type Error = DataFusionError; - - fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.into(); - let end_bound = value - .end_bound - .map(WindowFrameBound::from) - .unwrap_or(WindowFrameBound::CurrentRow); - - if let WindowFrameBound::Following(None) = start_bound { - Err(DataFusionError::Execution( - "Invalid window frame: start bound cannot be unbounded following" - .to_owned(), - )) - } else if let WindowFrameBound::Preceding(None) = end_bound { - Err(DataFusionError::Execution( - "Invalid window frame: end bound cannot be unbounded preceding" - .to_owned(), - )) - } else if start_bound > end_bound { - Err(DataFusionError::Execution(format!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - start_bound, end_bound - ))) - } else { - let units = value.units.into(); - if units == WindowFrameUnits::Range { - for bound in &[start_bound, end_bound] { - match bound { - WindowFrameBound::Preceding(Some(v)) - | WindowFrameBound::Following(Some(v)) - if *v > 0 => - { - Err(DataFusionError::NotImplemented(format!( - "With WindowFrameUnits={}, the bound cannot be {} PRECEDING or FOLLOWING at the moment", - units, v - ))) - } - _ => Ok(()), - }?; - } - } - Ok(Self { - units, - start_bound, - end_bound, - }) - } - } -} - -impl Default for WindowFrame { - fn default() -> Self { - WindowFrame { - units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(None), - end_bound: WindowFrameBound::CurrentRow, - } - } -} - -/// There are five ways to describe starting and ending frame boundaries: -/// -/// 1. UNBOUNDED PRECEDING -/// 2. PRECEDING -/// 3. CURRENT ROW -/// 4. FOLLOWING -/// 5. UNBOUNDED FOLLOWING -/// -/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) -#[derive(Debug, Clone, Copy, Eq)] -pub enum WindowFrameBound { - /// 1. UNBOUNDED PRECEDING - /// The frame boundary is the first row in the partition. - /// - /// 2. PRECEDING - /// must be a non-negative constant numeric expression. The boundary is a row that - /// is "units" prior to the current row. - Preceding(Option), - /// 3. The current row. - /// - /// For RANGE and GROUPS frame types, peers of the current row are also - /// included in the frame, unless specifically excluded by the EXCLUDE clause. - /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame - /// boundary. - CurrentRow, - /// 4. 
This is the same as " PRECEDING" except that the boundary is units after the - /// current rather than before the current row. - /// - /// 5. UNBOUNDED FOLLOWING - /// The frame boundary is the last row in the partition. - Following(Option), -} - -impl From for WindowFrameBound { - fn from(value: ast::WindowFrameBound) -> Self { - match value { - ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), - ast::WindowFrameBound::Following(v) => Self::Following(v), - ast::WindowFrameBound::CurrentRow => Self::CurrentRow, - } - } -} - -impl fmt::Display for WindowFrameBound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), - WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), - WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), - WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), - WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), - } - } -} - -impl Hash for WindowFrameBound { - fn hash(&self, state: &mut H) { - self.get_rank().hash(state) - } -} - -impl PartialEq for WindowFrameBound { - fn eq(&self, other: &Self) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl PartialOrd for WindowFrameBound { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for WindowFrameBound { - fn cmp(&self, other: &Self) -> Ordering { - self.get_rank().cmp(&other.get_rank()) - } -} - -impl WindowFrameBound { - /// get the rank of this window frame bound. - /// - /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value - /// which requires special handling e.g. with preceding the larger the value the smaller the - /// rank and also for 0 preceding / following it is the same as current row - fn get_rank(&self) -> (u8, u64) { - match self { - WindowFrameBound::Preceding(None) => (0, 0), - WindowFrameBound::Following(None) => (4, 0), - WindowFrameBound::Preceding(Some(0)) - | WindowFrameBound::CurrentRow - | WindowFrameBound::Following(Some(0)) => (2, 0), - WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), - WindowFrameBound::Following(Some(v)) => (3, *v), - } - } -} - -/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the -/// starting and ending boundaries of the frame are measured. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] -pub enum WindowFrameUnits { - /// The ROWS frame type means that the starting and ending boundaries for the frame are - /// determined by counting individual rows relative to the current row. - Rows, - /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one - /// term. Call that term "X". With the RANGE frame type, the elements of the frame are - /// determined by computing the value of expression X for all rows in the partition and framing - /// those rows for which the value of X is within a certain range of the value of X for the - /// current row. - Range, - /// The GROUPS frame type means that the starting and ending boundaries are determine - /// by counting "groups" relative to the current group. A "group" is a set of rows that all have - /// equivalent values for all all terms of the window ORDER BY clause. 
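(For intuition, a standalone sketch, not part of the patch, of how the `(u8, u64)` rank tuples produced by `get_rank` above order the bounds; plain tuple comparison is lexicographic, which yields exactly the frame-bound ordering:)

```rust
fn main() {
    // Rank tuples per get_rank(): the u8 encodes the bound kind, the u64 the offset.
    let unbounded_preceding = (0u8, 0u64);
    let two_preceding = (1u8, u64::MAX - 2); // larger PRECEDING offsets map to smaller values
    let one_preceding = (1u8, u64::MAX - 1);
    let current_row = (2u8, 0u64); // shared with 0 PRECEDING and 0 FOLLOWING
    let one_following = (3u8, 1u64);
    let unbounded_following = (4u8, 0u64);

    assert!(unbounded_preceding < two_preceding);
    assert!(two_preceding < one_preceding); // 2 PRECEDING sorts before 1 PRECEDING
    assert!(one_preceding < current_row);
    assert!(current_row < one_following);
    assert!(one_following < unbounded_following);
}
```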
- Groups, -} - -impl fmt::Display for WindowFrameUnits { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(match self { - WindowFrameUnits::Rows => "ROWS", - WindowFrameUnits::Range => "RANGE", - WindowFrameUnits::Groups => "GROUPS", - }) - } -} - -impl From for WindowFrameUnits { - fn from(value: ast::WindowFrameUnits) -> Self { - match value { - ast::WindowFrameUnits::Range => Self::Range, - ast::WindowFrameUnits::Groups => Self::Groups, - ast::WindowFrameUnits::Rows => Self::Rows, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_window_frame_creation() -> Result<()> { - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Following(None), - end_bound: None, - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(None), - end_bound: Some(ast::WindowFrameBound::Preceding(None)), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(Some(1)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(Some(2)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "This feature is not implemented: With WindowFrameUnits=RANGE, the bound cannot be 2 PRECEDING or FOLLOWING at the moment".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Rows, - start_bound: ast::WindowFrameBound::Preceding(Some(2)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), - }; - let result = WindowFrame::try_from(window_frame); - assert!(result.is_ok()); - Ok(()) - } - - #[test] - fn test_eq() { - assert_eq!( - WindowFrameBound::Preceding(Some(0)), - WindowFrameBound::CurrentRow - ); - assert_eq!( - WindowFrameBound::CurrentRow, - WindowFrameBound::Following(Some(0)) - ); - assert_eq!( - WindowFrameBound::Following(Some(2)), - WindowFrameBound::Following(Some(2)) - ); - assert_eq!( - WindowFrameBound::Following(None), - WindowFrameBound::Following(None) - ); - assert_eq!( - WindowFrameBound::Preceding(Some(2)), - WindowFrameBound::Preceding(Some(2)) - ); - assert_eq!( - WindowFrameBound::Preceding(None), - WindowFrameBound::Preceding(None) - ); - } - - #[test] - fn test_ord() { - assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); - // ! yes this is correct! 
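- // (a larger PRECEDING offset is further from the current row, so it ranks lower)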
- assert!( - WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) - ); - assert!( - WindowFrameBound::Preceding(Some(u64::MAX)) - < WindowFrameBound::Preceding(Some(u64::MAX - 1)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(1000000)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(u64::MAX)) - ); - assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); - assert!( - WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) - ); - assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); - assert!( - WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) - ); - assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); - assert!( - WindowFrameBound::Following(Some(u64::MAX)) - < WindowFrameBound::Following(None) - ); - } -} +pub use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index 947073409d05..2ed45be25bc1 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -23,8 +23,8 @@ use crate::logical_plan::plan::{Filter, Projection, Window}; use crate::logical_plan::{ col, plan::{Aggregate, Sort}, - DFField, DFSchema, Expr, ExprRewriter, ExpressionVisitor, LogicalPlan, Recursion, - RewriteRecursion, + DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprSchemable, ExprVisitable, + ExpressionVisitor, LogicalPlan, Recursion, RewriteRecursion, }; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index 7f631d37018f..a07d24c56aeb 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -16,9 +16,9 @@ use crate::datasource::datasource::TableProviderFilterPushDown; use crate::execution::context::ExecutionProps; -use crate::logical_plan::plan::{Aggregate, Filter, Join, Projection}; +use crate::logical_plan::plan::{Aggregate, Filter, Join, Projection, Union}; use crate::logical_plan::{ - and, replace_col, Column, CrossJoin, JoinType, Limit, LogicalPlan, TableScan, + and, col, replace_col, Column, CrossJoin, JoinType, Limit, LogicalPlan, TableScan, }; use crate::logical_plan::{DFSchema, Expr}; use crate::optimizer::optimizer::OptimizerRule; @@ -393,8 +393,29 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> { // sort is filter-commutable push_down(&state, plan) } - LogicalPlan::Union(_) => { - // union all is filter-commutable + LogicalPlan::Union(Union { + inputs: _, + schema, + alias: _, + }) => { + // union changes all qualifiers while building the logical plan, so we + // need to rewrite filters to push unqualified columns down to the inputs + let projection = schema + .fields() + .iter() + .map(|field| (field.qualified_name(), col(field.name()))) + .collect::<HashMap<_, _>>(); + + // rewrite predicate expressions to use the unqualified names as replacements + if !projection.is_empty() { + for (predicate, columns) in state.filters.iter_mut() { + *predicate = rewrite(predicate, &projection)?; + + columns.clear(); + utils::expr_to_columns(predicate, columns)?; + } + } + push_down(&state, plan) } LogicalPlan::Limit(Limit { input, ..
}) => { @@ -574,7 +595,9 @@ mod tests { use super::*; use crate::datasource::TableProvider; use crate::field_util::SchemaExt; - use crate::logical_plan::{lit, sum, DFSchema, Expr, LogicalPlanBuilder, Operator}; + use crate::logical_plan::{ + lit, sum, union_with_alias, DFSchema, Expr, LogicalPlanBuilder, Operator, + }; use crate::physical_plan::ExecutionPlan; use crate::test::*; use crate::{logical_plan::col, prelude::JoinType}; @@ -901,6 +924,27 @@ mod tests { Ok(()) } + #[test] + fn union_all_with_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let union = + union_with_alias(table_scan.clone(), table_scan, Some("t".to_string()))?; + + let plan = LogicalPlanBuilder::from(union) + .filter(col("t.a").eq(lit(1i64)))? + .build()?; + + // filter appears below Union without relation qualifier + let expected = "\ + Union\ + \n Filter: #a = Int64(1)\ + \n TableScan: test projection=None\ + \n Filter: #a = Int64(1)\ + \n TableScan: test projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + /// verifies that filters with the same columns are correctly placed #[test] fn filter_2_breaks_limits() -> Result<()> { diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index e03babef49ef..a3da9d191e5d 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -17,23 +17,23 @@ //! Simplify expressions optimizer rule -use crate::record_batch::RecordBatch; -use arrow::array::new_null_array; -use arrow::datatypes::{DataType, Field, Schema}; - use crate::error::DataFusionError; use crate::execution::context::ExecutionProps; -use crate::field_util::SchemaExt; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::{ - lit, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan, RewriteRecursion, - SimplifyInfo, + lit, DFSchema, DFSchemaRef, Expr, ExprRewritable, ExprRewriter, ExprSimplifiable, + LogicalPlan, RewriteRecursion, SimplifyInfo, }; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::physical_plan::functions::Volatility; use crate::physical_plan::planner::create_physical_expr; +use crate::record_batch::RecordBatch; use crate::scalar::ScalarValue; use crate::{error::Result, logical_plan::Operator}; +use arrow::array::new_null_array; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::field_util::SchemaExt; /// Provides simplification information based on schema and properties struct SimplifyContext<'a, 'b> { @@ -253,6 +253,7 @@ impl SimplifyExpressions { /// /// ``` /// # use datafusion::prelude::*; +/// # use datafusion::logical_plan::ExprRewritable; /// # use datafusion::optimizer::simplify_expressions::ConstEvaluator; /// # use datafusion::execution::context::ExecutionProps; /// @@ -403,7 +404,7 @@ impl<'a> ConstEvaluator<'a> { let phys_expr = create_physical_expr( &expr, &self.input_schema, - self.input_batch.schema(), + &self.input_batch.schema(), self.execution_props, )?; let col_val = phys_expr.evaluate(&self.input_batch)?; @@ -736,8 +737,8 @@ mod tests { use super::*; use crate::assert_contains; use crate::logical_plan::{ - and, binary_expr, col, create_udf, lit, lit_timestamp_nano, DFField, Expr, - LogicalPlanBuilder, + and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano, DFField, + Expr, LogicalPlanBuilder, }; use crate::physical_plan::functions::{make_scalar_function, BuiltinScalarFunction}; use crate::physical_plan::udf::ScalarUDF; @@ -1011,46 
+1012,29 @@ mod tests { #[test] fn test_const_evaluator_scalar_functions() { // concat("foo", "bar") --> "foobar" - let expr = Expr::ScalarFunction { - args: vec![lit("foo"), lit("bar")], - fun: BuiltinScalarFunction::Concat, - }; + let expr = call_fn("concat", vec![lit("foo"), lit("bar")]).unwrap(); test_evaluate(expr, lit("foobar")); // ensure arguments are also constant folded // concat("foo", concat("bar", "baz")) --> "foobarbaz" - let concat1 = Expr::ScalarFunction { - args: vec![lit("bar"), lit("baz")], - fun: BuiltinScalarFunction::Concat, - }; - let expr = Expr::ScalarFunction { - args: vec![lit("foo"), concat1], - fun: BuiltinScalarFunction::Concat, - }; + let concat1 = call_fn("concat", vec![lit("bar"), lit("baz")]).unwrap(); + let expr = call_fn("concat", vec![lit("foo"), concat1]).unwrap(); test_evaluate(expr, lit("foobarbaz")); // Check non string arguments // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400000000000i64) - let expr = Expr::ScalarFunction { - args: vec![lit("2020-09-08T12:00:00+00:00")], - fun: BuiltinScalarFunction::ToTimestamp, - }; + let expr = + call_fn("to_timestamp", vec![lit("2020-09-08T12:00:00+00:00")]).unwrap(); test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64)); // check that non-foldable arguments are not folded // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] - let expr = Expr::ScalarFunction { - args: vec![col("a")], - fun: BuiltinScalarFunction::ToTimestamp, - }; + let expr = call_fn("to_timestamp", vec![col("a")]).unwrap(); test_evaluate(expr.clone(), expr); // check that non-foldable arguments are not folded // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] - let expr = Expr::ScalarFunction { - args: vec![col("a")], - fun: BuiltinScalarFunction::ToTimestamp, - }; + let expr = call_fn("to_timestamp", vec![col("a")]).unwrap(); test_evaluate(expr.clone(), expr); // volatile / stable functions should not be evaluated @@ -1091,10 +1075,7 @@ mod tests { } fn now_expr() -> Expr { - Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::Now, - } + call_fn("now", vec![]).unwrap() } fn cast_to_int64_expr(expr: Expr) -> Expr { @@ -1105,10 +1086,7 @@ } fn to_timestamp_expr(arg: impl Into<String>) -> Expr { - Expr::ScalarFunction { - args: vec![lit(arg.into())], - fun: BuiltinScalarFunction::ToTimestamp, - } + call_fn("to_timestamp", vec![lit(arg.into())]).unwrap() } #[test] diff --git a/datafusion/src/optimizer/single_distinct_to_groupby.rs b/datafusion/src/optimizer/single_distinct_to_groupby.rs index 02a24e214495..2e0bd5ff0549 100644 --- a/datafusion/src/optimizer/single_distinct_to_groupby.rs +++ b/datafusion/src/optimizer/single_distinct_to_groupby.rs @@ -20,6 +20,7 @@ use crate::error::Result; use crate::execution::context::ExecutionProps; use crate::logical_plan::plan::{Aggregate, Projection}; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::{col, columnize_expr, DFSchema, Expr, LogicalPlan}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index f7ab836b398c..41d1e4bca03b 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -22,9 +22,11 @@ use crate::execution::context::ExecutionProps; use crate::logical_plan::plan::{ Aggregate, Analyze, Extension, Filter, Join, Projection, Sort, Window, }; + use crate::logical_plan::{ - build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, Limit, LogicalPlan, - 
LogicalPlanBuilder, Operator, Partitioning, Recursion, Repartition, Union, Values, + build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, ExprVisitable, + Limit, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion, + Repartition, Union, Values, }; use crate::prelude::lit; use crate::scalar::ScalarValue; diff --git a/datafusion/src/physical_optimizer/repartition.rs b/datafusion/src/physical_optimizer/repartition.rs index 461d19445ea4..ed7c7b2c14f8 100644 --- a/datafusion/src/physical_optimizer/repartition.rs +++ b/datafusion/src/physical_optimizer/repartition.rs @@ -19,10 +19,10 @@ use std::sync::Arc; use super::optimizer::PhysicalOptimizerRule; +use crate::physical_plan::Partitioning::*; use crate::physical_plan::{ empty::EmptyExec, repartition::RepartitionExec, ExecutionPlan, }; -use crate::physical_plan::{Distribution, Partitioning::*}; use crate::{error::Result, execution::context::ExecutionConfig}; /// Optimizer that introduces repartition to introduce more parallelism in the plan @@ -38,8 +38,8 @@ impl Repartition { fn optimize_partitions( target_partitions: usize, - requires_single_partition: bool, plan: Arc, + should_repartition: bool, ) -> Result> { // Recurse into children bottom-up (added nodes should be as deep as possible) @@ -47,17 +47,15 @@ fn optimize_partitions( // leaf node - don't replace children plan.clone() } else { + let should_repartition_children = plan.should_repartition_children(); let children = plan .children() .iter() .map(|child| { optimize_partitions( target_partitions, - matches!( - plan.required_child_distribution(), - Distribution::SinglePartition - ), child.clone(), + should_repartition_children, ) }) .collect::>()?; @@ -77,7 +75,7 @@ fn optimize_partitions( // But also not very useful to inlude let is_empty_exec = plan.as_any().downcast_ref::().is_some(); - if perform_repartition && !requires_single_partition && !is_empty_exec { + if perform_repartition && should_repartition && !is_empty_exec { Ok(Arc::new(RepartitionExec::try_new( new_plan, RoundRobinBatch(target_partitions), @@ -97,7 +95,7 @@ impl PhysicalOptimizerRule for Repartition { if config.target_partitions == 1 { Ok(plan) } else { - optimize_partitions(config.target_partitions, true, plan) + optimize_partitions(config.target_partitions, plan, false) } } @@ -107,94 +105,176 @@ impl PhysicalOptimizerRule for Repartition { } #[cfg(test)] mod tests { - use arrow::datatypes::Schema; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use super::*; use crate::datasource::PartitionedFile; use crate::field_util::SchemaExt; + use crate::physical_plan::expressions::col; use crate::physical_plan::file_format::{FileScanConfig, ParquetExec}; - use crate::physical_plan::projection::ProjectionExec; - use crate::physical_plan::Statistics; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; + use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; + use crate::physical_plan::union::UnionExec; + use crate::physical_plan::{displayable, Statistics}; use crate::test::object_store::TestObjectStore; + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)])) + } + + fn parquet_exec() -> Arc { + Arc::new(ParquetExec::new( + FileScanConfig { + object_store: TestObjectStore::new_arc(&[("x", 100)]), + file_schema: schema(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::default(), + projection: None, + 
limit: None, + table_partition_cols: vec![], + }, + None, + )) + } + + fn filter_exec(input: Arc) -> Arc { + Arc::new(FilterExec::try_new(col("c1", &schema()).unwrap(), input).unwrap()) + } + + fn hash_aggregate(input: Arc) -> Arc { + let schema = schema(); + Arc::new( + HashAggregateExec::try_new( + AggregateMode::Final, + vec![], + vec![], + Arc::new( + HashAggregateExec::try_new( + AggregateMode::Partial, + vec![], + vec![], + input, + schema.clone(), + ) + .unwrap(), + ), + schema, + ) + .unwrap(), + ) + } + + fn limit_exec(input: Arc) -> Arc { + Arc::new(GlobalLimitExec::new( + Arc::new(LocalLimitExec::new(input, 100)), + 100, + )) + } + + fn trim_plan_display(plan: &str) -> Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() + } + #[test] fn added_repartition_to_single_partition() -> Result<()> { - let file_schema = Arc::new(Schema::empty()); - let parquet_project = ProjectionExec::try_new( - vec![], - Arc::new(ParquetExec::new( - FileScanConfig { - object_store: TestObjectStore::new_arc(&[("x", 100)]), - file_schema, - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - }, - None, - )), - )?; - let optimizer = Repartition {}; let optimized = optimizer.optimize( - Arc::new(parquet_project), + hash_aggregate(parquet_exec()), &ExecutionConfig::new().with_target_partitions(10), )?; - assert_eq!( - optimized.children()[0] - .output_partitioning() - .partition_count(), - 10 - ); + let plan = displayable(optimized.as_ref()).indent().to_string(); + + let expected = &[ + "HashAggregateExec: mode=Final, gby=[], aggr=[]", + "HashAggregateExec: mode=Partial, gby=[], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10)", + "ParquetExec: limit=None, partitions=[x]", + ]; + assert_eq!(&trim_plan_display(&plan), &expected); Ok(()) } #[test] fn repartition_deepest_node() -> Result<()> { - let file_schema = Arc::new(Schema::empty()); - let parquet_project = ProjectionExec::try_new( - vec![], - Arc::new(ProjectionExec::try_new( - vec![], - Arc::new(ParquetExec::new( - FileScanConfig { - object_store: TestObjectStore::new_arc(&[("x", 100)]), - file_schema, - file_groups: vec![vec![PartitionedFile::new( - "x".to_string(), - 100, - )]], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - }, - None, - )), - )?), + let optimizer = Repartition {}; + + let optimized = optimizer.optimize( + hash_aggregate(filter_exec(parquet_exec())), + &ExecutionConfig::new().with_target_partitions(10), )?; + let plan = displayable(optimized.as_ref()).indent().to_string(); + + let expected = &[ + "HashAggregateExec: mode=Final, gby=[], aggr=[]", + "HashAggregateExec: mode=Partial, gby=[], aggr=[]", + "FilterExec: c1@0", + "RepartitionExec: partitioning=RoundRobinBatch(10)", + "ParquetExec: limit=None, partitions=[x]", + ]; + + assert_eq!(&trim_plan_display(&plan), &expected); + Ok(()) + } + + #[test] + fn repartition_ignores_limit() -> Result<()> { let optimizer = Repartition {}; let optimized = optimizer.optimize( - Arc::new(parquet_project), + hash_aggregate(limit_exec(filter_exec(limit_exec(parquet_exec())))), &ExecutionConfig::new().with_target_partitions(10), )?; - // RepartitionExec is added to deepest node - assert!(optimized.children()[0] - .as_any() - .downcast_ref::() - .is_none()); - assert!(optimized.children()[0].children()[0] - .as_any() - .downcast_ref::() - .is_some()); + let plan = 
displayable(optimized.as_ref()).indent().to_string(); + + let expected = &[ + "HashAggregateExec: mode=Final, gby=[], aggr=[]", + "HashAggregateExec: mode=Partial, gby=[], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10)", + "GlobalLimitExec: limit=100", + "LocalLimitExec: limit=100", + "FilterExec: c1@0", + "RepartitionExec: partitioning=RoundRobinBatch(10)", + "GlobalLimitExec: limit=100", + "LocalLimitExec: limit=100", + // Expect no repartition to happen for local limit + "ParquetExec: limit=None, partitions=[x]", + ]; + + assert_eq!(&trim_plan_display(&plan), &expected); + Ok(()) + } + + #[test] + fn repartition_ignores_union() -> Result<()> { + let optimizer = Repartition {}; + + let optimized = optimizer.optimize( + Arc::new(UnionExec::new(vec![parquet_exec(); 5])), + &ExecutionConfig::new().with_target_partitions(5), + )?; + + let plan = displayable(optimized.as_ref()).indent().to_string(); + + let expected = &[ + "UnionExec", + // Expect no repartition of ParquetExec + "ParquetExec: limit=None, partitions=[x]", + "ParquetExec: limit=None, partitions=[x]", + "ParquetExec: limit=None, partitions=[x]", + "ParquetExec: limit=None, partitions=[x]", + "ParquetExec: limit=None, partitions=[x]", + ]; + assert_eq!(&trim_plan_display(&plan), &expected); Ok(()) } } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 59231fccfc65..3af912caf527 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -28,7 +28,7 @@ use super::{ functions::{Signature, TypeSignature, Volatility}, - Accumulator, AggregateExpr, PhysicalExpr, + AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types}; @@ -38,90 +38,9 @@ use expressions::{ avg_return_type, correlation_return_type, covariance_return_type, stddev_return_type, sum_return_type, variance_return_type, }; -use std::{fmt, str::FromStr, sync::Arc}; - -/// the implementation of an aggregate function -pub type AccumulatorFunctionImplementation = - Arc Result> + Send + Sync>; - -/// This signature corresponds to which types an aggregator serializes -/// its state, given its return datatype. -pub type StateTypeFunction = - Arc Result>> + Send + Sync>; - -/// Enum of all built-in aggregate functions -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] -pub enum AggregateFunction { - /// count - Count, - /// sum - Sum, - /// min - Min, - /// max - Max, - /// avg - Avg, - /// Approximate aggregate function - ApproxDistinct, - /// array_agg - ArrayAgg, - /// Variance (Sample) - Variance, - /// Variance (Population) - VariancePop, - /// Standard Deviation (Sample) - Stddev, - /// Standard Deviation (Population) - StddevPop, - /// Covariance (Sample) - Covariance, - /// Covariance (Population) - CovariancePop, - /// Correlation - Correlation, - /// Approximate continuous percentile function - ApproxPercentileCont, -} - -impl fmt::Display for AggregateFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // uppercase of the debug. 
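- // e.g. `format!("{}", AggregateFunction::Count)` yields "COUNT"; the FromStr impl below accepts lowercase names and aliases such as "var_samp".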
- write!(f, "{}", format!("{:?}", self).to_uppercase()) - } -} +use std::sync::Arc; -impl FromStr for AggregateFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name { - "min" => AggregateFunction::Min, - "max" => AggregateFunction::Max, - "count" => AggregateFunction::Count, - "avg" => AggregateFunction::Avg, - "sum" => AggregateFunction::Sum, - "approx_distinct" => AggregateFunction::ApproxDistinct, - "array_agg" => AggregateFunction::ArrayAgg, - "var" => AggregateFunction::Variance, - "var_samp" => AggregateFunction::Variance, - "var_pop" => AggregateFunction::VariancePop, - "stddev" => AggregateFunction::Stddev, - "stddev_samp" => AggregateFunction::Stddev, - "stddev_pop" => AggregateFunction::StddevPop, - "covar" => AggregateFunction::Covariance, - "covar_samp" => AggregateFunction::Covariance, - "covar_pop" => AggregateFunction::CovariancePop, - "corr" => AggregateFunction::Correlation, - "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, - _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in function named {}", - name - ))); - } - }) - } -} +pub use datafusion_expr::AggregateFunction; /// Returns the datatype of the aggregate function. /// This is used to get the returned data type for aggregate expr. diff --git a/datafusion/src/physical_plan/expressions/try_cast.rs b/datafusion/src/physical_plan/expressions/try_cast.rs index a2e74bbac798..0e5c5e81ea94 100644 --- a/datafusion/src/physical_plan/expressions/try_cast.rs +++ b/datafusion/src/physical_plan/expressions/try_cast.rs @@ -279,7 +279,7 @@ mod tests { // decimal to i8 generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal(10, 0), + DataType::Decimal(10, 3), Int8Array, DataType::Int8, vec![ diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index 063023b5488d..525e202d0c06 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -478,9 +478,9 @@ fn read_partition( let file_schema = fetch_schema(object_reader)?; let adapted_projections = schema_adapter.map_projections(&file_schema.clone(), projection)?; - let mut record_reader = read::RecordReader::try_new( + let mut record_reader = read::FileReader::try_new( reader, - Some(adapted_projections.clone()), + Some(&adapted_projections), limit, None, None, @@ -501,7 +501,7 @@ fn read_partition( total_rows += chunk.len(); let batch = RecordBatch::try_new( - read_schema.clone(), + Arc::new(read_schema.clone()), chunk.columns().to_vec(), )?; @@ -549,15 +549,15 @@ mod tests { use super::*; use crate::field_util::FieldExt; use crate::physical_plan::collect; + use ::parquet::statistics::Statistics as ParquetStatistics; use arrow::datatypes::{DataType, Field}; - use arrow::io::parquet::write::{to_parquet_schema, write_file, RowGroupIterator}; - use arrow::io::parquet::write::{ColumnDescriptor, SchemaDescriptor}; + use arrow::io::parquet; + use arrow::io::parquet::read::ColumnChunkMetaData; + use arrow::io::parquet::write::{ + to_parquet_schema, ColumnDescriptor, Compression, Encoding, FileWriter, + RowGroupIterator, SchemaDescriptor, Version, WriteOptions, + }; use futures::StreamExt; - use parquet::compression::Compression; - use parquet::encoding::Encoding; - use parquet::metadata::ColumnChunkMetaData; - use parquet::statistics::Statistics as ParquetStatistics; - use parquet::write::{Version, WriteOptions}; use parquet_format_async_temp::RowGroup; /// writes 
each RecordBatch as an individual parquet file and then @@ -573,7 +573,7 @@ mod tests { .map(|batch| { let output = tempfile::NamedTempFile::new().expect("creating temp file"); - let mut file: std::fs::File = (*output.as_file()) + let file: std::fs::File = (*output.as_file()) .try_clone() .expect("cloning file descriptor"); let options = WriteOptions { @@ -582,7 +582,6 @@ mod tests { version: Version::V2, }; let schema_ref = &batch.schema().clone(); - let parquet_schema = to_parquet_schema(schema_ref).unwrap(); let iter = vec![Ok(batch.into())]; let row_groups = RowGroupIterator::try_new( @@ -593,15 +592,15 @@ mod tests { ) .unwrap(); - write_file( - &mut file, - row_groups, - schema_ref, - parquet_schema, - options, - None, - ) - .expect("Writing batch"); + let mut writer = + FileWriter::try_new(file, schema_ref.as_ref().clone(), options) + .unwrap(); + writer.start().unwrap(); + for rg in row_groups { + let (group, len) = rg.unwrap(); + writer.write(group, len).unwrap(); + } + writer.end(None).unwrap(); output }) .collect(); @@ -965,7 +964,7 @@ mod tests { ParquetFileMetrics::new(0, "file.parquet", &metrics) } - fn parquet_primitive_column_stats( + fn parquet_primitive_column_stats( column_descr: ColumnDescriptor, min: Option, max: Option, @@ -1271,7 +1270,6 @@ mod tests { schema_descr: &SchemaDescriptor, column_statistics: Vec<&dyn ParquetStatistics>, ) -> RowGroupMetaData { - use parquet::schema::types::{physical_type_to_type, ParquetType}; use parquet_format_async_temp::{ColumnChunk, ColumnMetaData}; let mut chunks = vec![]; @@ -1279,8 +1277,8 @@ mod tests { for (i, s) in column_statistics.into_iter().enumerate() { let column_descr = schema_descr.column(i); let type_ = match column_descr.type_() { - ParquetType::PrimitiveType { physical_type, .. } => { - physical_type_to_type(physical_type).0 + parquet::write::ParquetType::PrimitiveType { physical_type, .. } => { + ::parquet::schema::types::physical_type_to_type(physical_type).0 } _ => { panic!("Trying to write a row group of a non-physical type") @@ -1293,7 +1291,7 @@ mod tests { type_, Vec::new(), column_descr.path_in_schema().to_vec(), - parquet::compression::Compression::Uncompressed.into(), + Compression::Uncompressed.into(), 0, 0, 0, @@ -1301,7 +1299,7 @@ mod tests { 0, None, None, - Some(parquet::statistics::serialize_statistics(s)), + Some(::parquet::statistics::serialize_statistics(s)), None, None, )), diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 3a4d88fc3cbe..8a3a48cc27af 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -48,6 +48,7 @@ use crate::{ scalar::ScalarValue, }; use arrow::{ + array::ArrayRef, array::*, compute::length::length, datatypes::TimeUnit, @@ -55,107 +56,10 @@ use arrow::{ error::{ArrowError, Result as ArrowResult}, types::NativeType, }; +pub use datafusion_expr::NullColumnarValue; +pub use datafusion_expr::{BuiltinScalarFunction, Signature, TypeSignature, Volatility}; use fmt::{Debug, Formatter}; -use std::convert::From; -use std::{any::Any, fmt, str::FromStr, sync::Arc}; - -/// A function's type signature, which defines the function's supported argument types. 
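(A minimal sketch, illustrative and not part of the patch, of building one of these signatures through the constructors shown below, using the types as re-exported from datafusion_expr:)

```rust
use arrow::datatypes::DataType;
use datafusion::physical_plan::functions::{Signature, TypeSignature, Volatility};

fn main() {
    // Per the docs below: `concat` accepts any number of arguments of a
    // common string type, i.e. Variadic([Utf8, LargeUtf8]).
    let sig = Signature::variadic(
        vec![DataType::Utf8, DataType::LargeUtf8],
        Volatility::Immutable,
    );
    assert_eq!(
        sig.type_signature,
        TypeSignature::Variadic(vec![DataType::Utf8, DataType::LargeUtf8])
    );
}
```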
-#[derive(Debug, Clone, PartialEq, Hash)] -pub enum TypeSignature { - /// arbitrary number of arguments of an common type out of a list of valid types - // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` - Variadic(Vec), - /// arbitrary number of arguments of an arbitrary but equal type - // A function such as `array` is `VariadicEqual` - // The first argument decides the type used for coercion - VariadicEqual, - /// fixed number of arguments of an arbitrary but equal type out of a list of valid types - // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` - // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` - Uniform(usize, Vec), - /// exact number of arguments of an exact type - Exact(Vec), - /// fixed number of arguments of arbitrary types - Any(usize), - /// One of a list of signatures - OneOf(Vec), -} - -///The Signature of a function defines its supported input types as well as its volatility. -#[derive(Debug, Clone, PartialEq, Hash)] -pub struct Signature { - /// type_signature - The types that the function accepts. See [TypeSignature] for more information. - pub type_signature: TypeSignature, - /// volatility - The volatility of the function. See [Volatility] for more information. - pub volatility: Volatility, -} - -impl Signature { - /// new - Creates a new Signature from any type signature and the volatility. - pub fn new(type_signature: TypeSignature, volatility: Volatility) -> Self { - Signature { - type_signature, - volatility, - } - } - /// variadic - Creates a variadic signature that represents an arbitrary number of arguments all from a type in common_types. - pub fn variadic(common_types: Vec, volatility: Volatility) -> Self { - Self { - type_signature: TypeSignature::Variadic(common_types), - volatility, - } - } - /// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type. - pub fn variadic_equal(volatility: Volatility) -> Self { - Self { - type_signature: TypeSignature::VariadicEqual, - volatility, - } - } - /// uniform - Creates a function with a fixed number of arguments of the same type, which must be from valid_types. - pub fn uniform( - arg_count: usize, - valid_types: Vec, - volatility: Volatility, - ) -> Self { - Self { - type_signature: TypeSignature::Uniform(arg_count, valid_types), - volatility, - } - } - /// exact - Creates a signture which must match the types in exact_types in order. - pub fn exact(exact_types: Vec, volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::Exact(exact_types), - volatility, - } - } - /// any - Creates a signature which can a be made of any type but of a specified number - pub fn any(arg_count: usize, volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::Any(arg_count), - volatility, - } - } - /// one_of Creates a signature which can match any of the [TypeSignature]s which are passed in. - pub fn one_of(type_signatures: Vec, volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::OneOf(type_signatures), - volatility, - } - } -} - -///A function's volatility, which defines the functions eligibility for certain optimizations -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum Volatility { - /// Immutable - An immutable function will always return the same output when given the same input. An example of this is [BuiltinScalarFunction::Cos]. 
- Immutable, - /// Stable - A stable function may return different values given the same input accross different queries but must return the same value for a given input within a query. An example of this is [BuiltinScalarFunction::Now]. - Stable, - /// Volatile - A volatile function may change the return value from evaluation to evaluation. Mutiple invocations of a volatile function may return different results when used in the same query. An example of this is [BuiltinScalarFunction::Random]. - Volatile, -} +use std::{any::Any, fmt, sync::Arc}; /// Scalar function /// @@ -172,313 +76,6 @@ pub type ScalarFunctionImplementation = pub type ReturnTypeFunction = Arc Result> + Send + Sync>; -/// Enum of all built-in scalar functions -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum BuiltinScalarFunction { - // math functions - /// abs - Abs, - /// acos - Acos, - /// asin - Asin, - /// atan - Atan, - /// ceil - Ceil, - /// cos - Cos, - /// Digest - Digest, - /// exp - Exp, - /// floor - Floor, - /// ln, Natural logarithm - Ln, - /// log, same as log10 - Log, - /// log10 - Log10, - /// log2 - Log2, - /// round - Round, - /// signum - Signum, - /// sin - Sin, - /// sqrt - Sqrt, - /// tan - Tan, - /// trunc - Trunc, - - // string functions - /// construct an array from columns - Array, - /// ascii - Ascii, - /// bit_length - BitLength, - /// btrim - Btrim, - /// character_length - CharacterLength, - /// chr - Chr, - /// concat - Concat, - /// concat_ws - ConcatWithSeparator, - /// date_part - DatePart, - /// date_trunc - DateTrunc, - /// initcap - InitCap, - /// left - Left, - /// lpad - Lpad, - /// lower - Lower, - /// ltrim - Ltrim, - /// md5 - MD5, - /// nullif - NullIf, - /// octet_length - OctetLength, - /// random - Random, - /// regexp_replace - RegexpReplace, - /// repeat - Repeat, - /// replace - Replace, - /// reverse - Reverse, - /// right - Right, - /// rpad - Rpad, - /// rtrim - Rtrim, - /// sha224 - SHA224, - /// sha256 - SHA256, - /// sha384 - SHA384, - /// Sha512 - SHA512, - /// split_part - SplitPart, - /// starts_with - StartsWith, - /// strpos - Strpos, - /// substr - Substr, - /// to_hex - ToHex, - /// to_timestamp - ToTimestamp, - /// to_timestamp_millis - ToTimestampMillis, - /// to_timestamp_micros - ToTimestampMicros, - /// to_timestamp_seconds - ToTimestampSeconds, - ///now - Now, - /// translate - Translate, - /// trim - Trim, - /// upper - Upper, - /// regexp_match - RegexpMatch, -} - -impl BuiltinScalarFunction { - /// an allowlist of functions to take zero arguments, so that they will get special treatment - /// while executing. - fn supports_zero_argument(&self) -> bool { - matches!( - self, - BuiltinScalarFunction::Random | BuiltinScalarFunction::Now - ) - } - /// Returns the [Volatility] of the builtin function. 
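(As context for why this matters: the ConstEvaluator earlier in this patch only folds calls whose volatility is Immutable. A small sketch, assuming the re-exported enum:)

```rust
use datafusion::physical_plan::functions::{BuiltinScalarFunction, Volatility};

fn main() {
    // concat("foo", "bar") is Immutable and therefore safe to constant-fold
    assert_eq!(BuiltinScalarFunction::Concat.volatility(), Volatility::Immutable);
    // now() is Stable: fixed within one query, but not across queries
    assert_eq!(BuiltinScalarFunction::Now.volatility(), Volatility::Stable);
    // random() is Volatile and must never be folded
    assert_eq!(BuiltinScalarFunction::Random.volatility(), Volatility::Volatile);
}
```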
- pub fn volatility(&self) -> Volatility { - match self { - //Immutable scalar builtins - BuiltinScalarFunction::Abs => Volatility::Immutable, - BuiltinScalarFunction::Acos => Volatility::Immutable, - BuiltinScalarFunction::Asin => Volatility::Immutable, - BuiltinScalarFunction::Atan => Volatility::Immutable, - BuiltinScalarFunction::Ceil => Volatility::Immutable, - BuiltinScalarFunction::Cos => Volatility::Immutable, - BuiltinScalarFunction::Exp => Volatility::Immutable, - BuiltinScalarFunction::Floor => Volatility::Immutable, - BuiltinScalarFunction::Ln => Volatility::Immutable, - BuiltinScalarFunction::Log => Volatility::Immutable, - BuiltinScalarFunction::Log10 => Volatility::Immutable, - BuiltinScalarFunction::Log2 => Volatility::Immutable, - BuiltinScalarFunction::Round => Volatility::Immutable, - BuiltinScalarFunction::Signum => Volatility::Immutable, - BuiltinScalarFunction::Sin => Volatility::Immutable, - BuiltinScalarFunction::Sqrt => Volatility::Immutable, - BuiltinScalarFunction::Tan => Volatility::Immutable, - BuiltinScalarFunction::Trunc => Volatility::Immutable, - BuiltinScalarFunction::Array => Volatility::Immutable, - BuiltinScalarFunction::Ascii => Volatility::Immutable, - BuiltinScalarFunction::BitLength => Volatility::Immutable, - BuiltinScalarFunction::Btrim => Volatility::Immutable, - BuiltinScalarFunction::CharacterLength => Volatility::Immutable, - BuiltinScalarFunction::Chr => Volatility::Immutable, - BuiltinScalarFunction::Concat => Volatility::Immutable, - BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, - BuiltinScalarFunction::DatePart => Volatility::Immutable, - BuiltinScalarFunction::DateTrunc => Volatility::Immutable, - BuiltinScalarFunction::InitCap => Volatility::Immutable, - BuiltinScalarFunction::Left => Volatility::Immutable, - BuiltinScalarFunction::Lpad => Volatility::Immutable, - BuiltinScalarFunction::Lower => Volatility::Immutable, - BuiltinScalarFunction::Ltrim => Volatility::Immutable, - BuiltinScalarFunction::MD5 => Volatility::Immutable, - BuiltinScalarFunction::NullIf => Volatility::Immutable, - BuiltinScalarFunction::OctetLength => Volatility::Immutable, - BuiltinScalarFunction::RegexpReplace => Volatility::Immutable, - BuiltinScalarFunction::Repeat => Volatility::Immutable, - BuiltinScalarFunction::Replace => Volatility::Immutable, - BuiltinScalarFunction::Reverse => Volatility::Immutable, - BuiltinScalarFunction::Right => Volatility::Immutable, - BuiltinScalarFunction::Rpad => Volatility::Immutable, - BuiltinScalarFunction::Rtrim => Volatility::Immutable, - BuiltinScalarFunction::SHA224 => Volatility::Immutable, - BuiltinScalarFunction::SHA256 => Volatility::Immutable, - BuiltinScalarFunction::SHA384 => Volatility::Immutable, - BuiltinScalarFunction::SHA512 => Volatility::Immutable, - BuiltinScalarFunction::Digest => Volatility::Immutable, - BuiltinScalarFunction::SplitPart => Volatility::Immutable, - BuiltinScalarFunction::StartsWith => Volatility::Immutable, - BuiltinScalarFunction::Strpos => Volatility::Immutable, - BuiltinScalarFunction::Substr => Volatility::Immutable, - BuiltinScalarFunction::ToHex => Volatility::Immutable, - BuiltinScalarFunction::ToTimestamp => Volatility::Immutable, - BuiltinScalarFunction::ToTimestampMillis => Volatility::Immutable, - BuiltinScalarFunction::ToTimestampMicros => Volatility::Immutable, - BuiltinScalarFunction::ToTimestampSeconds => Volatility::Immutable, - BuiltinScalarFunction::Translate => Volatility::Immutable, - BuiltinScalarFunction::Trim => Volatility::Immutable, - 
BuiltinScalarFunction::Upper => Volatility::Immutable, - BuiltinScalarFunction::RegexpMatch => Volatility::Immutable, - - //Stable builtin functions - BuiltinScalarFunction::Now => Volatility::Stable, - - //Volatile builtin functions - BuiltinScalarFunction::Random => Volatility::Volatile, - } - } -} - -impl fmt::Display for BuiltinScalarFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // lowercase of the debug. - write!(f, "{}", format!("{:?}", self).to_lowercase()) - } -} - -impl FromStr for BuiltinScalarFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name { - // math functions - "abs" => BuiltinScalarFunction::Abs, - "acos" => BuiltinScalarFunction::Acos, - "asin" => BuiltinScalarFunction::Asin, - "atan" => BuiltinScalarFunction::Atan, - "ceil" => BuiltinScalarFunction::Ceil, - "cos" => BuiltinScalarFunction::Cos, - "exp" => BuiltinScalarFunction::Exp, - "floor" => BuiltinScalarFunction::Floor, - "ln" => BuiltinScalarFunction::Ln, - "log" => BuiltinScalarFunction::Log, - "log10" => BuiltinScalarFunction::Log10, - "log2" => BuiltinScalarFunction::Log2, - "round" => BuiltinScalarFunction::Round, - "signum" => BuiltinScalarFunction::Signum, - "sin" => BuiltinScalarFunction::Sin, - "sqrt" => BuiltinScalarFunction::Sqrt, - "tan" => BuiltinScalarFunction::Tan, - "trunc" => BuiltinScalarFunction::Trunc, - - // string functions - "array" => BuiltinScalarFunction::Array, - "ascii" => BuiltinScalarFunction::Ascii, - "bit_length" => BuiltinScalarFunction::BitLength, - "btrim" => BuiltinScalarFunction::Btrim, - "char_length" => BuiltinScalarFunction::CharacterLength, - "character_length" => BuiltinScalarFunction::CharacterLength, - "concat" => BuiltinScalarFunction::Concat, - "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator, - "chr" => BuiltinScalarFunction::Chr, - "date_part" | "datepart" => BuiltinScalarFunction::DatePart, - "date_trunc" | "datetrunc" => BuiltinScalarFunction::DateTrunc, - "initcap" => BuiltinScalarFunction::InitCap, - "left" => BuiltinScalarFunction::Left, - "length" => BuiltinScalarFunction::CharacterLength, - "lower" => BuiltinScalarFunction::Lower, - "lpad" => BuiltinScalarFunction::Lpad, - "ltrim" => BuiltinScalarFunction::Ltrim, - "md5" => BuiltinScalarFunction::MD5, - "nullif" => BuiltinScalarFunction::NullIf, - "octet_length" => BuiltinScalarFunction::OctetLength, - "random" => BuiltinScalarFunction::Random, - "regexp_replace" => BuiltinScalarFunction::RegexpReplace, - "repeat" => BuiltinScalarFunction::Repeat, - "replace" => BuiltinScalarFunction::Replace, - "reverse" => BuiltinScalarFunction::Reverse, - "right" => BuiltinScalarFunction::Right, - "rpad" => BuiltinScalarFunction::Rpad, - "rtrim" => BuiltinScalarFunction::Rtrim, - "sha224" => BuiltinScalarFunction::SHA224, - "sha256" => BuiltinScalarFunction::SHA256, - "sha384" => BuiltinScalarFunction::SHA384, - "sha512" => BuiltinScalarFunction::SHA512, - "digest" => BuiltinScalarFunction::Digest, - "split_part" => BuiltinScalarFunction::SplitPart, - "starts_with" => BuiltinScalarFunction::StartsWith, - "strpos" => BuiltinScalarFunction::Strpos, - "substr" => BuiltinScalarFunction::Substr, - "to_hex" => BuiltinScalarFunction::ToHex, - "to_timestamp" => BuiltinScalarFunction::ToTimestamp, - "to_timestamp_millis" => BuiltinScalarFunction::ToTimestampMillis, - "to_timestamp_micros" => BuiltinScalarFunction::ToTimestampMicros, - "to_timestamp_seconds" => BuiltinScalarFunction::ToTimestampSeconds, - "now" => BuiltinScalarFunction::Now, - "translate" => 
BuiltinScalarFunction::Translate, - "trim" => BuiltinScalarFunction::Trim, - "upper" => BuiltinScalarFunction::Upper, - "regexp_match" => BuiltinScalarFunction::RegexpMatch, - _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in function named {}", - name - ))) - } - }) - } -} - macro_rules! make_utf8_to_return_type { ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { fn $FUNC(arg_type: &DataType, name: &str) -> Result { @@ -1654,17 +1251,6 @@ impl fmt::Display for ScalarFunctionExpr { } } -/// null columnar values are implemented as a null array in order to pass batch -/// num_rows -type NullColumnarValue = ColumnarValue; - -impl From<&RecordBatch> for NullColumnarValue { - fn from(batch: &RecordBatch) -> Self { - let num_rows = batch.num_rows(); - ColumnarValue::Array(Arc::new(NullArray::from_data(DataType::Null, num_rows))) - } -} - impl PhysicalExpr for ScalarFunctionExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -3636,6 +3222,18 @@ mod tests { StringArray ); #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(-5))), + ], + Ok(Some("joséésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 2f063f3577f4..f8fd8fc09c31 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -381,6 +381,16 @@ mod noforce_hash_collisions { multi_col ); } + DataType::Timestamp(TimeUnit::Second, None) => { + hash_array_primitive!( + Int64Array, + col, + i64, + hashes_buffer, + random_state, + multi_col + ); + } DataType::Timestamp(TimeUnit::Millisecond, None) => { hash_array_primitive!( Int64Array, diff --git a/datafusion/src/physical_plan/limit.rs b/datafusion/src/physical_plan/limit.rs index 762c598d46c7..07fdbe642a2c 100644 --- a/datafusion/src/physical_plan/limit.rs +++ b/datafusion/src/physical_plan/limit.rs @@ -301,6 +301,11 @@ impl ExecutionPlan for LocalLimitExec { _ => Statistics::default(), } } + + fn should_repartition_children(&self) -> bool { + // No reason to repartition children as this node is just limiting each input partition. + false + } } /// Truncate a RecordBatch to maximum of n rows diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 5d605a02abe9..d1fc0cc3b627 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -37,6 +37,8 @@ use arrow::compute::sort::SortColumn as ArrowSortColumn; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use async_trait::async_trait; +pub use datafusion_expr::Accumulator; +pub use datafusion_expr::ColumnarValue; pub use display::DisplayFormatType; use futures::stream::Stream; use sorts::SortColumn; @@ -138,14 +140,32 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// Returns the execution plan as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. 
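As a hedged sketch (not part of this patch) of the name round-trip through the `FromStr` and `Display` impls above, plus a volatility query; the import paths assume the post-move `datafusion_expr` crate and datafusion's `Result` alias:

use std::str::FromStr;
use datafusion::error::Result;
use datafusion_expr::{BuiltinScalarFunction, Volatility};

fn example_lookup() -> Result<()> {
    let f = BuiltinScalarFunction::from_str("sqrt")?;
    assert_eq!(f, BuiltinScalarFunction::Sqrt);
    // Display is the lowercased Debug name, so the name round-trips
    assert_eq!(f.to_string(), "sqrt");
    assert_eq!(f.volatility(), Volatility::Immutable);
    Ok(())
}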
fn as_any(&self) -> &dyn Any; + /// Get the schema for this execution plan fn schema(&self) -> SchemaRef; + /// Specifies the output partitioning scheme of this plan fn output_partitioning(&self) -> Partitioning; + /// Specifies the data distribution requirements of all the children for this operator fn required_child_distribution(&self) -> Distribution { Distribution::UnspecifiedDistribution } + + /// Returns `true` if the direct children of this `ExecutionPlan` should be repartitioned + /// to introduce greater concurrency to the plan + /// + /// The default implementation returns `true` unless `Self::required_child_distribution` + /// returns `Distribution::SinglePartition` + /// + /// Operators that do not benefit from additional partitioning may want to return `false` + fn should_repartition_children(&self) -> bool { + !matches!( + self.required_child_distribution(), + Distribution::SinglePartition + ) + } + /// Get a list of child execution plans that provide the input for this plan. The returned list /// will be empty for leaf nodes, will contain a single value for unary nodes, or two /// values for binary nodes (such as joins). @@ -404,32 +424,6 @@ pub enum Distribution { HashPartitioned(Vec>), } -/// Represents the result from an expression -#[derive(Clone, Debug)] -pub enum ColumnarValue { - /// Array of values - Array(ArrayRef), - /// A single value - Scalar(ScalarValue), -} - -impl ColumnarValue { - fn data_type(&self) -> DataType { - match self { - ColumnarValue::Array(array_value) => array_value.data_type().clone(), - ColumnarValue::Scalar(scalar_value) => scalar_value.get_datatype(), - } - } - - /// Convert a columnar value into an ArrayRef - pub fn into_array(self, num_rows: usize) -> ArrayRef { - match self { - ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), - } - } -} - /// Expression that can be evaluated against a RecordBatch /// A Physical expression knows its type, nullability and how to evaluate itself. pub trait PhysicalExpr: Send + Sync + Display + Debug { @@ -567,30 +561,6 @@ pub trait WindowExpr: Send + Sync + Debug { } } -/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and -/// generically accumulates values. -/// -/// An accumulator knows how to: -/// * update its state from inputs via `update_batch` -/// * convert its internal state to a vector of scalar values -/// * update its state from multiple accumulators' states via `merge_batch` -/// * compute the final value from its internal state via `evaluate` -pub trait Accumulator: Send + Sync + Debug { - /// Returns the state of the accumulator at the end of the accumulation. - // in the case of an average on which we track `sum` and `n`, this function should return a vector - // of two values, sum and n. - fn state(&self) -> Result>; - - /// updates the accumulator's state from a vector of arrays. - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; - - /// updates the accumulator's state from a vector of states. - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>; - - /// returns its value based on its current state. 
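The `Accumulator` contract described above (now re-exported from `datafusion_expr`) in a minimal COUNT-style sketch, not from this patch; the `UInt64Array` alias and import paths are assumptions about this arrow2-based crate:

use arrow::array::{ArrayRef, UInt64Array};
use datafusion::error::{DataFusionError, Result};
use datafusion::physical_plan::Accumulator;
use datafusion::scalar::ScalarValue;

#[derive(Debug)]
struct CountAccumulator {
    count: u64,
}

impl Accumulator for CountAccumulator {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        // a single partial state: the running count
        Ok(vec![ScalarValue::UInt64(Some(self.count))])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // count the non-null values in the input column
        let arr = &values[0];
        self.count += (arr.len() - arr.null_count()) as u64;
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // each state array holds partial counts from other accumulators
        let partial = states[0]
            .as_any()
            .downcast_ref::<UInt64Array>()
            .ok_or_else(|| DataFusionError::Internal("unexpected state type".to_string()))?;
        for i in 0..partial.len() {
            self.count += partial.value(i);
        }
        Ok(())
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::UInt64(Some(self.count)))
    }
}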
- fn evaluate(&self) -> Result; -} - /// Applies an optional projection to a [`SchemaRef`], returning the /// projected schema /// diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs index 0de696d61172..71e7e0657596 100644 --- a/datafusion/src/physical_plan/udaf.rs +++ b/datafusion/src/physical_plan/udaf.rs @@ -17,7 +17,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. -use fmt::{Debug, Formatter}; +use fmt::Debug; use std::any::Any; use std::fmt; @@ -26,85 +26,14 @@ use arrow::{ datatypes::{DataType, Schema}, }; -use crate::physical_plan::PhysicalExpr; -use crate::{error::Result, logical_plan::Expr}; - use super::{ - aggregates::AccumulatorFunctionImplementation, - aggregates::StateTypeFunction, - expressions::format_state_name, - functions::{ReturnTypeFunction, Signature}, - type_coercion::coerce, - Accumulator, AggregateExpr, + expressions::format_state_name, type_coercion::coerce, Accumulator, AggregateExpr, }; -use std::sync::Arc; - -/// Logical representation of a user-defined aggregate function (UDAF) -/// A UDAF is different from a UDF in that it is stateful across batches. -#[derive(Clone)] -pub struct AggregateUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - pub accumulator: AccumulatorFunctionImplementation, - /// the accumulator's state's description as a function of the return type - pub state_type: StateTypeFunction, -} - -impl Debug for AggregateUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("AggregateUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } -} - -impl PartialEq for AggregateUDF { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature - } -} - -impl std::hash::Hash for AggregateUDF { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); - } -} - -impl AggregateUDF { - /// Create a new AggregateUDF - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - accumulator: &AccumulatorFunctionImplementation, - state_type: &StateTypeFunction, - ) -> Self { - Self { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - accumulator: accumulator.clone(), - state_type: state_type.clone(), - } - } +use crate::error::Result; +use crate::physical_plan::PhysicalExpr; +pub use datafusion_expr::AggregateUDF; - /// creates a logical expression with a call of the UDAF - /// This utility allows using the UDAF without requiring access to the registry. - pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF { - fun: Arc::new(self.clone()), - args, - } - } -} +use std::sync::Arc; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 7355746a368b..58e66da48a7d 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -17,91 +17,16 @@ //! 
UDF support -use fmt::{Debug, Formatter}; -use std::fmt; - +use super::type_coercion::coerce; +use crate::error::Result; +use crate::physical_plan::functions::ScalarFunctionExpr; +use crate::physical_plan::PhysicalExpr; use arrow::datatypes::Schema; -use crate::error::Result; -use crate::{logical_plan::Expr, physical_plan::PhysicalExpr}; +pub use datafusion_expr::ScalarUDF; -use super::{ - functions::{ - ReturnTypeFunction, ScalarFunctionExpr, ScalarFunctionImplementation, Signature, - }, - type_coercion::coerce, -}; use std::sync::Arc; -/// Logical representation of a UDF. -#[derive(Clone)] -pub struct ScalarUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed. In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). - pub fun: ScalarFunctionImplementation, -} - -impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } -} - -impl PartialEq for ScalarUDF { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature - } -} - -impl std::hash::Hash for ScalarUDF { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); - } -} - -impl ScalarUDF { - /// Create a new ScalarUDF - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - fun: &ScalarFunctionImplementation, - ) -> Self { - Self { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - fun: fun.clone(), - } - } - - /// creates a logical expression with a call of the UDF - /// This utility allows using the UDF without requiring access to the registry. - pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarUDF { - fun: Arc::new(self.clone()), - args, - } - } -} - /// Create a physical expression of the UDF. /// This function errors when `args`' can't be coerced to a valid argument type of the UDF. 
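A sketch (not part of this patch) of defining a `ScalarUDF` with the constructor and `call` helper shown above; the identity implementation and import paths are illustrative assumptions:

use std::sync::Arc;
use arrow::datatypes::DataType;
use datafusion::logical_plan::col;
use datafusion::physical_plan::functions::{
    ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
};
use datafusion::physical_plan::udf::ScalarUDF;

fn example_udf() {
    let return_type: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
    // a trivial implementation: pass the single input column through unchanged
    let fun: ScalarFunctionImplementation = Arc::new(|args| Ok(args[0].clone()));
    let identity = ScalarUDF::new(
        "identity",
        &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
        &return_type,
        &fun,
    );
    // build a logical expression without going through a function registry
    let _expr = identity.call(vec![col("a")]);
}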
pub fn create_physical_expr( diff --git a/datafusion/src/physical_plan/union.rs b/datafusion/src/physical_plan/union.rs index 96dbc2eb448c..2d79a19dd8b6 100644 --- a/datafusion/src/physical_plan/union.rs +++ b/datafusion/src/physical_plan/union.rs @@ -144,6 +144,10 @@ impl ExecutionPlan for UnionExec { .reduce(stats_union) .unwrap_or_default() } + + fn should_repartition_children(&self) -> bool { + false + } } /// Stream wrapper that records `BaselineMetrics` for a particular diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index f281d7d01837..0b1fd65cb677 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -23,129 +23,16 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::functions::{TypeSignature, Volatility}; use crate::physical_plan::{ - aggregates, aggregates::AggregateFunction, functions::Signature, - type_coercion::data_types, windows::find_ranges_in_range, PhysicalExpr, + aggregates, functions::Signature, type_coercion::data_types, + windows::find_ranges_in_range, PhysicalExpr, }; use crate::record_batch::RecordBatch; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; +pub use datafusion_expr::{BuiltInWindowFunction, WindowFunction}; use std::any::Any; use std::ops::Range; use std::sync::Arc; -use std::{fmt, str::FromStr}; - -/// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum WindowFunction { - /// window function that leverages an aggregate function - AggregateFunction(AggregateFunction), - /// window function that leverages a built-in window function - BuiltInWindowFunction(BuiltInWindowFunction), -} - -impl FromStr for WindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - let name = name.to_lowercase(); - if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Ok(WindowFunction::AggregateFunction(aggregate)) - } else if let Ok(built_in_function) = - BuiltInWindowFunction::from_str(name.as_str()) - { - Ok(WindowFunction::BuiltInWindowFunction(built_in_function)) - } else { - Err(DataFusionError::Plan(format!( - "There is no window function named {}", - name - ))) - } - } -} - -impl fmt::Display for BuiltInWindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - BuiltInWindowFunction::RowNumber => write!(f, "ROW_NUMBER"), - BuiltInWindowFunction::Rank => write!(f, "RANK"), - BuiltInWindowFunction::DenseRank => write!(f, "DENSE_RANK"), - BuiltInWindowFunction::PercentRank => write!(f, "PERCENT_RANK"), - BuiltInWindowFunction::CumeDist => write!(f, "CUME_DIST"), - BuiltInWindowFunction::Ntile => write!(f, "NTILE"), - BuiltInWindowFunction::Lag => write!(f, "LAG"), - BuiltInWindowFunction::Lead => write!(f, "LEAD"), - BuiltInWindowFunction::FirstValue => write!(f, "FIRST_VALUE"), - BuiltInWindowFunction::LastValue => write!(f, "LAST_VALUE"), - BuiltInWindowFunction::NthValue => write!(f, "NTH_VALUE"), - } - } -} - -impl fmt::Display for WindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.fmt(f), - WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), - } - } -} - -/// An aggregate function that is part of a built-in window function -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum BuiltInWindowFunction { - /// number of the current row within its partition, counting from 1 - RowNumber, - /// rank of the current row 
with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lead, - /// returns value evaluated at the row that is the first row of the window frame - FirstValue, - /// returns value evaluated at the row that is the last row of the window frame - LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row - NthValue, -} - -impl FromStr for BuiltInWindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result<Self> { - Ok(match name.to_uppercase().as_str() { - "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, - "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, - "LAST_VALUE" => BuiltInWindowFunction::LastValue, - "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in window function named {}", - name - ))) - } - }) - } -} /// Returns the datatype of the window function pub fn return_type( @@ -302,72 +189,7 @@ pub(crate) trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { #[cfg(test)] mod tests { use super::*; - - #[test] - fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - "min", - "max", - "count", - "avg", - "sum", - ]; - for name in names { - let fun = WindowFunction::from_str(name)?; - let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?; - assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); - } - Ok(()) - } - - #[test] - fn test_window_function_from_str() -> Result<()> { - assert_eq!( - WindowFunction::from_str("max")?, - WindowFunction::AggregateFunction(AggregateFunction::Max) - ); - assert_eq!( - WindowFunction::from_str("min")?, - WindowFunction::AggregateFunction(AggregateFunction::Min) - ); - assert_eq!( - WindowFunction::from_str("avg")?, - WindowFunction::AggregateFunction(AggregateFunction::Avg) - ); - assert_eq!( - WindowFunction::from_str("cume_dist")?, -
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist) - ); - assert_eq!( - WindowFunction::from_str("first_value")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue) - ); - assert_eq!( - WindowFunction::from_str("LAST_value")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue) - ); - assert_eq!( - WindowFunction::from_str("LAG")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag) - ); - assert_eq!( - WindowFunction::from_str("LEAD")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead) - ); - Ok(()) - } + use std::str::FromStr; #[test] fn test_count_return_type() -> Result<()> { diff --git a/datafusion/src/pyarrow.rs b/datafusion/src/pyarrow.rs deleted file mode 100644 index 88ab2e4dade5..000000000000 --- a/datafusion/src/pyarrow.rs +++ /dev/null @@ -1,96 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::Array; -use arrow::error::ArrowError; -use pyo3::exceptions::{PyException, PyNotImplementedError}; -use pyo3::ffi::Py_uintptr_t; -use pyo3::prelude::*; -use pyo3::types::PyList; -use std::sync::Arc; - -use crate::error::DataFusionError; -use crate::scalar::ScalarValue; - -impl From for PyErr { - fn from(err: DataFusionError) -> PyErr { - PyException::new_err(err.to_string()) - } -} - -impl From for PyErr { - fn from(err: PyO3ArrowError) -> PyErr { - PyException::new_err(format!("{:?}", err)) - } -} - -#[derive(Debug)] -enum PyO3ArrowError { - ArrowError(ArrowError), -} - -fn to_rust_array(ob: PyObject, py: Python) -> PyResult> { - // prepare a pointer to receive the Array struct - let array = Box::new(arrow::ffi::Ffi_ArrowArray::empty()); - let schema = Box::new(arrow::ffi::Ffi_ArrowSchema::empty()); - - let array_ptr = &*array as *const arrow::ffi::Ffi_ArrowArray; - let schema_ptr = &*schema as *const arrow::ffi::Ffi_ArrowSchema; - - // make the conversion through PyArrow's private API - // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds - ob.call_method1( - py, - "_export_to_c", - (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), - )?; - - let field = unsafe { - arrow::ffi::import_field_from_c(schema.as_ref()) - .map_err(PyO3ArrowError::ArrowError)? - }; - let array = unsafe { - arrow::ffi::import_array_from_c(array, &field) - .map_err(PyO3ArrowError::ArrowError)? 
- }; - - Ok(array.into()) -} -impl<'source> FromPyObject<'source> for ScalarValue { - fn extract(value: &'source PyAny) -> PyResult { - let py = value.py(); - let typ = value.getattr("type")?; - let val = value.call_method0("as_py")?; - - // construct pyarrow array from the python value and pyarrow type - let factory = py.import("pyarrow")?.getattr("array")?; - let args = PyList::new(py, &[val]); - let array = factory.call1((args, typ))?; - - // convert the pyarrow array to rust array using C data interface] - let array = to_rust_array(array.to_object(py), py)?; - let scalar = ScalarValue::try_from_array(&array, 0)?; - - Ok(scalar) - } -} - -impl<'a> IntoPy for ScalarValue { - fn into_py(self, _py: Python) -> PyObject { - Err(PyNotImplementedError::new_err("Not implemented")).unwrap() - } -} diff --git a/datafusion/src/record_batch.rs b/datafusion/src/record_batch.rs index 8fba09e73e33..430904c58cac 100644 --- a/datafusion/src/record_batch.rs +++ b/datafusion/src/record_batch.rs @@ -1,432 +1,20 @@ -//! Contains [`RecordBatch`]. -use std::sync::Arc; - -use crate::field_util::SchemaExt; -use arrow::array::*; -use arrow::chunk::Chunk; -use arrow::compute::filter::{build_filter, filter}; -use arrow::datatypes::*; -use arrow::error::{ArrowError, Result}; - -/// A two-dimensional dataset with a number of -/// columns ([`Array`]) and rows and defined [`Schema`](crate::datatypes::Schema). -/// # Implementation -/// Cloning is `O(C)` where `C` is the number of columns. -#[derive(Clone, Debug, PartialEq)] -pub struct RecordBatch { - schema: Arc, - columns: Vec>, -} - -impl RecordBatch { - /// Creates a [`RecordBatch`] from a schema and columns. - /// # Errors - /// This function errors iff - /// * `columns` is empty - /// * the schema and column data types do not match - /// * `columns` have a different length - /// # Example - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow2::array::PrimitiveArray; - /// # use arrow2::datatypes::{Schema, Field, DataType}; - /// # use arrow2::record_batch::RecordBatch; - /// # fn main() -> arrow2::error::Result<()> { - /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]); - /// let schema = Arc::new(Schema::new(vec![ - /// Field::new("id", DataType::Int32, false) - /// ])); - /// - /// let batch = RecordBatch::try_new( - /// schema, - /// vec![Arc::new(id_array)] - /// )?; - /// # Ok(()) - /// # } - /// ``` - pub fn try_new(schema: Arc, columns: Vec>) -> Result { - let options = RecordBatchOptions::default(); - Self::validate_new_batch(&schema, columns.as_slice(), &options)?; - Ok(RecordBatch { schema, columns }) - } - - /// Creates a [`RecordBatch`] from a schema and columns, with additional options, - /// such as whether to strictly validate field names. - /// - /// See [`Self::try_new()`] for the expected conditions. - pub fn try_new_with_options( - schema: Arc, - columns: Vec>, - options: &RecordBatchOptions, - ) -> Result { - Self::validate_new_batch(&schema, &columns, options)?; - Ok(RecordBatch { schema, columns }) - } - - /// Creates a new empty [`RecordBatch`]. - pub fn new_empty(schema: Arc) -> Self { - let columns = schema - .fields() - .iter() - .map(|field| new_empty_array(field.data_type().clone()).into()) - .collect(); - RecordBatch { schema, columns } - } - - /// Creates a new [`RecordBatch`] from a [`arrow::chunk::Chunk`] - pub fn new_with_chunk(schema: &Arc, chunk: Chunk) -> Self { - Self { - schema: schema.clone(), - columns: chunk.into_arrays(), - } - } - - /// Validate the schema and columns using [`RecordBatchOptions`]. 
Returns an error - /// if any validation check fails. - fn validate_new_batch( - schema: &Schema, - columns: &[Arc], - options: &RecordBatchOptions, - ) -> Result<()> { - // check that there are some columns - if columns.is_empty() { - return Err(ArrowError::InvalidArgumentError( - "at least one column must be defined to create a record batch" - .to_string(), - )); - } - // check that number of fields in schema match column length - if schema.fields().len() != columns.len() { - return Err(ArrowError::InvalidArgumentError(format!( - "number of columns({}) must match number of fields({}) in schema", - columns.len(), - schema.fields().len(), - ))); - } - // check that all columns have the same row count, and match the schema - let len = columns[0].len(); - - // This is a bit repetitive, but it is better to check the condition outside the loop - if options.match_field_names { - for (i, column) in columns.iter().enumerate() { - if column.len() != len { - return Err(ArrowError::InvalidArgumentError( - "all columns in a record batch must have the same length" - .to_string(), - )); - } - if column.data_type() != schema.field(i).data_type() { - return Err(ArrowError::InvalidArgumentError(format!( - "column types must match schema types, expected {:?} but found {:?} at column index {}", - schema.field(i).data_type(), - column.data_type(), - i))); - } - } - } else { - for (i, column) in columns.iter().enumerate() { - if column.len() != len { - return Err(ArrowError::InvalidArgumentError( - "all columns in a record batch must have the same length" - .to_string(), - )); - } - if !column.data_type().eq(schema.field(i).data_type()) { - return Err(ArrowError::InvalidArgumentError(format!( - "column types must match schema types, expected {:?} but found {:?} at column index {}", - schema.field(i).data_type(), - column.data_type(), - i))); - } - } - } - - Ok(()) - } - - /// Returns the [`Schema`](crate::datatypes::Schema) of the record batch. - pub fn schema(&self) -> &Arc { - &self.schema - } - - /// Returns the number of columns in the record batch. - /// - /// # Example - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow2::array::PrimitiveArray; - /// # use arrow2::datatypes::{Schema, Field, DataType}; - /// # use arrow2::record_batch::RecordBatch; - /// # fn main() -> arrow2::error::Result<()> { - /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]); - /// let schema = Arc::new(Schema::new(vec![ - /// Field::new("id", DataType::Int32, false) - /// ])); - /// - /// let batch = RecordBatch::try_new(schema, vec![Arc::new(id_array)])?; - /// - /// assert_eq!(batch.num_columns(), 1); - /// # Ok(()) - /// # } - /// ``` - pub fn num_columns(&self) -> usize { - self.columns.len() - } - - /// Returns the number of rows in each column. - /// - /// # Panics - /// - /// Panics if the `RecordBatch` contains no columns. 
- /// - /// # Example - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow2::array::PrimitiveArray; - /// # use arrow2::datatypes::{Schema, Field, DataType}; - /// # use arrow2::record_batch::RecordBatch; - /// # fn main() -> arrow2::error::Result<()> { - /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]); - /// let schema = Arc::new(Schema::new(vec![ - /// Field::new("id", DataType::Int32, false) - /// ])); - /// - /// let batch = RecordBatch::try_new(schema, vec![Arc::new(id_array)])?; - /// - /// assert_eq!(batch.num_rows(), 5); - /// # Ok(()) - /// # } - /// ``` - pub fn num_rows(&self) -> usize { - self.columns[0].len() - } - - /// Get a reference to a column's array by index. - /// - /// # Panics - /// - /// Panics if `index` is outside of `0..num_columns`. - pub fn column(&self, index: usize) -> &Arc<dyn Array> { - &self.columns[index] - } - - /// Get a reference to all columns in the record batch. - pub fn columns(&self) -> &[Arc<dyn Array>] { - &self.columns[..] - } - - /// Create a `RecordBatch` from an iterable list of pairs of the - /// form `(field_name, array)`, with the same requirements on - /// fields and arrays as [`RecordBatch::try_new`]. This method is - /// often used to create a single `RecordBatch` from arrays, - /// e.g. for testing. - /// - /// The resulting schema is marked as nullable for each column if - /// the array for that column has any nulls. To explicitly - /// specify nullability, use [`RecordBatch::try_from_iter_with_nullable`] - /// - /// Example: - /// ``` - /// use std::sync::Arc; - /// use arrow::array::*; - /// use arrow::datatypes::DataType; - /// use datafusion::record_batch::RecordBatch; - /// - /// let a: Arc<dyn Array> = Arc::new(Int32Array::from_slice(&[1, 2])); - /// let b: Arc<dyn Array> = Arc::new(Utf8Array::<i32>::from_slice(&["a", "b"])); - /// - /// let record_batch = RecordBatch::try_from_iter(vec![ - /// ("a", a), - /// ("b", b), - /// ]); - /// ``` - pub fn try_from_iter<I, F>(value: I) -> Result<Self> - where - I: IntoIterator<Item = (F, Arc<dyn Array>)>, - F: AsRef<str>, - { - // TODO: implement `TryFrom` trait, once - // https://github.com/rust-lang/rust/issues/50133 is no longer an - // issue - let iter = value.into_iter().map(|(field_name, array)| { - let nullable = array.null_count() > 0; - (field_name, array, nullable) - }); - - Self::try_from_iter_with_nullable(iter) - } - - /// Create a `RecordBatch` from an iterable list of tuples of the - /// form `(field_name, array, nullable)`, with the same requirements on - /// fields and arrays as [`RecordBatch::try_new`]. This method is often - /// used to create a single `RecordBatch` from arrays, e.g. for - /// testing.
- /// - /// Example: - /// ``` - /// use std::sync::Arc; - /// use arrow::array::*; - /// use arrow::datatypes::DataType; - /// use datafusion::record_batch::RecordBatch; - /// - /// let a: Arc<dyn Array> = Arc::new(Int32Array::from_slice(&[1, 2])); - /// let b: Arc<dyn Array> = Arc::new(Utf8Array::<i32>::from_slice(&["a", "b"])); - /// - /// // Note neither `a` nor `b` has any actual nulls, but we mark - /// // b as nullable - /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![ - /// ("a", a, false), - /// ("b", b, true), - /// ]); - /// ``` - pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self> - where - I: IntoIterator<Item = (F, Arc<dyn Array>, bool)>, - F: AsRef<str>, - { - // TODO: implement `TryFrom` trait, once - // https://github.com/rust-lang/rust/issues/50133 is no longer an - // issue - let (fields, columns) = value - .into_iter() - .map(|(field_name, array, nullable)| { - let field_name = field_name.as_ref(); - let field = Field::new(field_name, array.data_type().clone(), nullable); - (field, array) - }) - .unzip(); - - let schema = Arc::new(Schema::new(fields)); - RecordBatch::try_new(schema, columns) - } - - /// Deconstructs itself into its internal components - pub fn into_inner(self) -> (Vec<Arc<dyn Array>>, Arc<Schema>) { - let Self { columns, schema } = self; - (columns, schema) - } - - /// Projects the schema onto the specified columns - pub fn project(&self, indices: &[usize]) -> Result<RecordBatch> { - let projected_schema = self.schema.project(indices)?; - let batch_fields = indices - .iter() - .map(|f| { - self.columns.get(*f).cloned().ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "project index {} out of bounds, max field {}", - f, - self.columns.len() - )) - }) - }) - .collect::<Result<Vec<_>>>()?; - - RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields) - } - - /// Return a new RecordBatch where each column is sliced - /// according to `offset` and `length` - /// - /// # Panics - /// - /// Panics if `offset + length` is greater than the column length. - pub fn slice(&self, offset: usize, length: usize) -> RecordBatch { - if self.schema.fields().is_empty() { - assert!((offset + length) == 0); - return RecordBatch::new_empty(self.schema.clone()); - } - assert!((offset + length) <= self.num_rows()); - - let columns = self - .columns() - .iter() - .map(|column| Arc::from(column.slice(offset, length))) - .collect(); - - Self { - schema: self.schema.clone(), - columns, - } - } -} - -/// Options that control the behaviour used when creating a [`RecordBatch`]. -#[derive(Debug)] -pub struct RecordBatchOptions { - /// Match field names of structs and lists. If set to `true`, the names must match. - pub match_field_names: bool, -} - -impl Default for RecordBatchOptions { - fn default() -> Self { - Self { - match_field_names: true, - } - } -} - -impl From<StructArray> for RecordBatch { - /// # Panics if the null count of the array is not 0.
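A short sketch (not part of this patch) exercising `project`, `slice`, and `filter_record_batch` from the module removed above; the array construction follows the doc examples and assumes the arrow2-style aliases used elsewhere in this crate:

use std::sync::Arc;
use arrow::array::{Array, BooleanArray, Int32Array, Utf8Array};
use datafusion::record_batch::{filter_record_batch, RecordBatch};

fn example() -> arrow::error::Result<()> {
    let a: Arc<dyn Array> = Arc::new(Int32Array::from_slice(&[1, 2, 3]));
    let b: Arc<dyn Array> = Arc::new(Utf8Array::<i32>::from_slice(&["x", "y", "z"]));
    let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;

    let only_a = batch.project(&[0])?; // keep column 0 only
    assert_eq!(only_a.num_columns(), 1);

    let middle = batch.slice(1, 1); // one row starting at offset 1
    assert_eq!(middle.num_rows(), 1);

    // keep rows where the mask is true
    let mask = BooleanArray::from_slice(&[true, false, true]);
    let filtered = filter_record_batch(&batch, &mask)?;
    assert_eq!(filtered.num_rows(), 2);
    Ok(())
}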
- fn from(array: StructArray) -> Self { - assert!(array.null_count() == 0); - let (fields, values, _) = array.into_data(); - RecordBatch { - schema: Arc::new(Schema::new(fields)), - columns: values, - } - } -} - -impl From for StructArray { - fn from(batch: RecordBatch) -> Self { - let (fields, values) = batch - .schema - .fields - .iter() - .zip(batch.columns.iter()) - .map(|t| (t.0.clone(), t.1.clone())) - .unzip(); - StructArray::from_data(DataType::Struct(fields), values, None) - } -} - -impl From for Chunk { - fn from(rb: RecordBatch) -> Self { - Chunk::new(rb.columns) - } -} - -impl From<&RecordBatch> for Chunk { - fn from(rb: &RecordBatch) -> Self { - Chunk::new(rb.columns.clone()) - } -} - -/// Returns a new [RecordBatch] with arrays containing only values matching the filter. -/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered. -/// Therefore, it is considered undefined behavior to pass `filter` with null values. -pub fn filter_record_batch( - record_batch: &RecordBatch, - filter_values: &BooleanArray, -) -> Result { - let num_colums = record_batch.columns().len(); - - let filtered_arrays = match num_colums { - 1 => { - vec![filter(record_batch.columns()[0].as_ref(), filter_values)?.into()] - } - _ => { - let filter = build_filter(filter_values)?; - record_batch - .columns() - .iter() - .map(|a| filter(a.as_ref()).into()) - .collect() - } - }; - RecordBatch::try_new(record_batch.schema().clone(), filtered_arrays) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! RecordBatch reimported from datafusion-common + +pub use datafusion_common::record_batch::*; diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 847a9ddd65fd..fffeaa1b58f7 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -15,1903 +15,25 @@ // specific language governing permissions and limitations // under the License. -//! 
This module provides ScalarValue, an enum that can be used for storage of single elements - -use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; - -use crate::error::{DataFusionError, Result}; -use crate::field_util::{FieldExt, StructArrayExt}; -use arrow::bitmap::Bitmap; -use arrow::buffer::Buffer; -use arrow::compute::concatenate; -use arrow::datatypes::DataType::Decimal; -use arrow::{ - array::*, - datatypes::{DataType, Field, IntegerType, IntervalUnit, TimeUnit}, - scalar::{PrimitiveScalar, Scalar}, - types::{days_ms, NativeType}, -}; -use ordered_float::OrderedFloat; -use std::cmp::Ordering; -use std::convert::{Infallible, TryInto}; -use std::str::FromStr; - -type StringArray = Utf8Array; -type LargeStringArray = Utf8Array; -type SmallBinaryArray = BinaryArray; -type LargeBinaryArray = BinaryArray; -type MutableStringArray = MutableUtf8Array; -type MutableLargeStringArray = MutableUtf8Array; - -// TODO may need to be moved to arrow-rs -/// The max precision and scale for decimal128 -pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38; -pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38; - -/// Represents a dynamically typed, nullable single value. -/// This is the single-valued counter-part of arrow’s `Array`. -#[derive(Clone)] -pub enum ScalarValue { - /// true or false value - Boolean(Option), - /// 32bit float - Float32(Option), - /// 64bit float - Float64(Option), - /// 128bit decimal, using the i128 to represent the decimal - Decimal128(Option, usize, usize), - /// signed 8bit int - Int8(Option), - /// signed 16bit int - Int16(Option), - /// signed 32bit int - Int32(Option), - /// signed 64bit int - Int64(Option), - /// unsigned 8bit int - UInt8(Option), - /// unsigned 16bit int - UInt16(Option), - /// unsigned 32bit int - UInt32(Option), - /// unsigned 64bit int - UInt64(Option), - /// utf-8 encoded string. - Utf8(Option), - /// utf-8 encoded string representing a LargeString's arrow type. 
- LargeUtf8(Option), - /// binary - Binary(Option>), - /// large binary - LargeBinary(Option>), - /// list of nested ScalarValue (boxed to reduce size_of(ScalarValue)) - #[allow(clippy::box_collection)] - List(Option>>, Box), - /// Date stored as a signed 32bit int - Date32(Option), - /// Date stored as a signed 64bit int - Date64(Option), - /// Timestamp Second - TimestampSecond(Option, Option), - /// Timestamp Milliseconds - TimestampMillisecond(Option, Option), - /// Timestamp Microseconds - TimestampMicrosecond(Option, Option), - /// Timestamp Nanoseconds - TimestampNanosecond(Option, Option), - /// Interval with YearMonth unit - IntervalYearMonth(Option), - /// Interval with DayTime unit - IntervalDayTime(Option), - /// Interval with MonthDayNano unit - IntervalMonthDayNano(Option), - /// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue)) - #[allow(clippy::box_collection)] - Struct(Option>>, Box>), -} - -// manual implementation of `PartialEq` that uses OrderedFloat to -// get defined behavior for floating point -impl PartialEq for ScalarValue { - fn eq(&self, other: &Self) -> bool { - use ScalarValue::*; - // This purposely doesn't have a catch-all "(_, _)" so that - // any newly added enum variant will require editing this list - // or else face a compile error - match (self, other) { - (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { - v1.eq(v2) && p1.eq(p2) && s1.eq(s2) - } - (Decimal128(_, _, _), _) => false, - (Boolean(v1), Boolean(v2)) => v1.eq(v2), - (Boolean(_), _) => false, - (Float32(v1), Float32(v2)) => { - let v1 = v1.map(OrderedFloat); - let v2 = v2.map(OrderedFloat); - v1.eq(&v2) - } - (Float32(_), _) => false, - (Float64(v1), Float64(v2)) => { - let v1 = v1.map(OrderedFloat); - let v2 = v2.map(OrderedFloat); - v1.eq(&v2) - } - (Float64(_), _) => false, - (Int8(v1), Int8(v2)) => v1.eq(v2), - (Int8(_), _) => false, - (Int16(v1), Int16(v2)) => v1.eq(v2), - (Int16(_), _) => false, - (Int32(v1), Int32(v2)) => v1.eq(v2), - (Int32(_), _) => false, - (Int64(v1), Int64(v2)) => v1.eq(v2), - (Int64(_), _) => false, - (UInt8(v1), UInt8(v2)) => v1.eq(v2), - (UInt8(_), _) => false, - (UInt16(v1), UInt16(v2)) => v1.eq(v2), - (UInt16(_), _) => false, - (UInt32(v1), UInt32(v2)) => v1.eq(v2), - (UInt32(_), _) => false, - (UInt64(v1), UInt64(v2)) => v1.eq(v2), - (UInt64(_), _) => false, - (Utf8(v1), Utf8(v2)) => v1.eq(v2), - (Utf8(_), _) => false, - (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2), - (LargeUtf8(_), _) => false, - (Binary(v1), Binary(v2)) => v1.eq(v2), - (Binary(_), _) => false, - (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), - (LargeBinary(_), _) => false, - (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2), - (List(_, _), _) => false, - (Date32(v1), Date32(v2)) => v1.eq(v2), - (Date32(_), _) => false, - (Date64(v1), Date64(v2)) => v1.eq(v2), - (Date64(_), _) => false, - (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.eq(v2), - (TimestampSecond(_, _), _) => false, - (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => v1.eq(v2), - (TimestampMillisecond(_, _), _) => false, - (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => v1.eq(v2), - (TimestampMicrosecond(_, _), _) => false, - (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), - (TimestampNanosecond(_, _), _) => false, - (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), - (IntervalYearMonth(_), _) => false, - (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), - (IntervalDayTime(_), _) => false, - (IntervalMonthDayNano(v1), 
IntervalMonthDayNano(v2)) => v1.eq(v2), - (IntervalMonthDayNano(_), _) => false, - (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), - (Struct(_, _), _) => false, - } - } -} - -// manual implementation of `PartialOrd` that uses OrderedFloat to -// get defined behavior for floating point -impl PartialOrd for ScalarValue { - fn partial_cmp(&self, other: &Self) -> Option { - use ScalarValue::*; - // This purposely doesn't have a catch-all "(_, _)" so that - // any newly added enum variant will require editing this list - // or else face a compile error - match (self, other) { - (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { - if p1.eq(p2) && s1.eq(s2) { - v1.partial_cmp(v2) - } else { - // Two decimal values can be compared if they have the same precision and scale. - None - } - } - (Decimal128(_, _, _), _) => None, - (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), - (Boolean(_), _) => None, - (Float32(v1), Float32(v2)) => { - let v1 = v1.map(OrderedFloat); - let v2 = v2.map(OrderedFloat); - v1.partial_cmp(&v2) - } - (Float32(_), _) => None, - (Float64(v1), Float64(v2)) => { - let v1 = v1.map(OrderedFloat); - let v2 = v2.map(OrderedFloat); - v1.partial_cmp(&v2) - } - (Float64(_), _) => None, - (Int8(v1), Int8(v2)) => v1.partial_cmp(v2), - (Int8(_), _) => None, - (Int16(v1), Int16(v2)) => v1.partial_cmp(v2), - (Int16(_), _) => None, - (Int32(v1), Int32(v2)) => v1.partial_cmp(v2), - (Int32(_), _) => None, - (Int64(v1), Int64(v2)) => v1.partial_cmp(v2), - (Int64(_), _) => None, - (UInt8(v1), UInt8(v2)) => v1.partial_cmp(v2), - (UInt8(_), _) => None, - (UInt16(v1), UInt16(v2)) => v1.partial_cmp(v2), - (UInt16(_), _) => None, - (UInt32(v1), UInt32(v2)) => v1.partial_cmp(v2), - (UInt32(_), _) => None, - (UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2), - (UInt64(_), _) => None, - (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), - (Utf8(_), _) => None, - (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), - (LargeUtf8(_), _) => None, - (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), - (Binary(_), _) => None, - (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), - (LargeBinary(_), _) => None, - (List(v1, t1), List(v2, t2)) => { - if t1.eq(t2) { - v1.partial_cmp(v2) - } else { - None - } - } - (List(_, _), _) => None, - (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), - (Date32(_), _) => None, - (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), - (Date64(_), _) => None, - (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.partial_cmp(v2), - (TimestampSecond(_, _), _) => None, - (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => { - v1.partial_cmp(v2) - } - (TimestampMillisecond(_, _), _) => None, - (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => { - v1.partial_cmp(v2) - } - (TimestampMicrosecond(_, _), _) => None, - (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => { - v1.partial_cmp(v2) - } - (TimestampNanosecond(_, _), _) => None, - (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), - (IntervalYearMonth(_), _) => None, - (_, IntervalDayTime(_)) => None, - (IntervalDayTime(_), _) => None, - (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), - (IntervalMonthDayNano(_), _) => None, - (Struct(v1, t1), Struct(v2, t2)) => { - if t1.eq(t2) { - v1.partial_cmp(v2) - } else { - None - } - } - (Struct(_, _), _) => None, - } - } -} - -impl Eq for ScalarValue {} - -// manual implementation of `Hash` that uses OrderedFloat to -// get defined behavior for floating point -impl std::hash::Hash for ScalarValue { - fn 
hash(&self, state: &mut H) { - use ScalarValue::*; - match self { - Decimal128(v, p, s) => { - v.hash(state); - p.hash(state); - s.hash(state) - } - Boolean(v) => v.hash(state), - Float32(v) => { - let v = v.map(OrderedFloat); - v.hash(state) - } - Float64(v) => { - let v = v.map(OrderedFloat); - v.hash(state) - } - Int8(v) => v.hash(state), - Int16(v) => v.hash(state), - Int32(v) => v.hash(state), - Int64(v) => v.hash(state), - UInt8(v) => v.hash(state), - UInt16(v) => v.hash(state), - UInt32(v) => v.hash(state), - UInt64(v) => v.hash(state), - Utf8(v) => v.hash(state), - LargeUtf8(v) => v.hash(state), - Binary(v) => v.hash(state), - LargeBinary(v) => v.hash(state), - List(v, t) => { - v.hash(state); - t.hash(state); - } - Date32(v) => v.hash(state), - Date64(v) => v.hash(state), - TimestampSecond(v, _) => v.hash(state), - TimestampMillisecond(v, _) => v.hash(state), - TimestampMicrosecond(v, _) => v.hash(state), - TimestampNanosecond(v, _) => v.hash(state), - IntervalYearMonth(v) => v.hash(state), - IntervalDayTime(v) => v.hash(state), - IntervalMonthDayNano(v) => v.hash(state), - Struct(v, t) => { - v.hash(state); - t.hash(state); - } - } - } -} - -// return the index into the dictionary values for array@index as well -// as a reference to the dictionary values array. Returns None for the -// index if the array is NULL at index -#[inline] -fn get_dict_value( - array: &ArrayRef, - index: usize, -) -> Result<(&ArrayRef, Option)> { - let dict_array = array.as_any().downcast_ref::>().unwrap(); - - // look up the index in the values dictionary - let keys_col = dict_array.keys(); - if !keys_col.is_valid(index) { - return Ok((dict_array.values(), None)); - } - let values_index = keys_col.value(index).to_usize().ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert index to usize in dictionary of type creating group by value {:?}", - keys_col.data_type() - )) - })?; - - Ok((dict_array.values(), Some(values_index))) -} - -macro_rules! typed_cast_tz { - ($array:expr, $index:expr, $SCALAR:ident, $TZ:expr) => {{ - let array = $array.as_any().downcast_ref::().unwrap(); - ScalarValue::$SCALAR( - match array.is_null($index) { - true => None, - false => Some(array.value($index).into()), - }, - $TZ.clone(), - ) - }}; -} - -macro_rules! typed_cast { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR(match array.is_null($index) { - true => None, - false => Some(array.value($index).into()), - }) - }}; -} - -macro_rules! build_list { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - let dt = DataType::List(Box::new(Field::new("item", DataType::$SCALAR_TY, true))); - match $VALUES { - // the return on the macro is necessary, to short-circuit and return ArrayRef - None => { - return Arc::from(new_null_array(dt, $SIZE)); - } - Some(values) => { - let mut array = MutableListArray::::new_from( - <$VALUE_BUILDER_TY>::default(), - dt, - $SIZE, - ); - build_values_list!(array, $SCALAR_TY, values.as_ref(), $SIZE) - } - } - }}; -} - -macro_rules! 
build_timestamp_list { - ($TIME_UNIT:expr, $VALUES:expr, $SIZE:expr, $TZ:expr) => {{ - let child_dt = DataType::Timestamp($TIME_UNIT, $TZ.clone()); - match $VALUES { - // the return on the macro is necessary, to short-circuit and return ArrayRef - None => { - let null_array: ArrayRef = new_null_array( - DataType::List(Box::new(Field::new("item", child_dt, true))), - $SIZE, - ) - .into(); - null_array - } - Some(values) => { - let values = values.as_ref(); - let empty_arr = <MutablePrimitiveArray<i64>>::default().to(child_dt.clone()); - let mut array = MutableListArray::<i32, MutablePrimitiveArray<i64>>::new_from( - empty_arr, - DataType::List(Box::new(Field::new("item", child_dt, true))), - $SIZE, - ); - - match $TIME_UNIT { - TimeUnit::Second => { - build_values_list_tz!(array, TimestampSecond, values, $SIZE) - } - TimeUnit::Microsecond => { - build_values_list_tz!(array, TimestampMicrosecond, values, $SIZE) - } - TimeUnit::Millisecond => { - build_values_list_tz!(array, TimestampMillisecond, values, $SIZE) - } - TimeUnit::Nanosecond => { - build_values_list_tz!(array, TimestampNanosecond, values, $SIZE) - } - } - } - } - }}; -} - -macro_rules! build_values_list { - ($MUTABLE_ARR:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - for _ in 0..$SIZE { - let mut vec = vec![]; - for scalar_value in $VALUES { - match scalar_value { - ScalarValue::$SCALAR_TY(v) => { - vec.push(v.clone()); - } - _ => panic!("Incompatible ScalarValue for list"), - }; - } - $MUTABLE_ARR.try_push(Some(vec)).unwrap(); - } - - let array: ListArray<i32> = $MUTABLE_ARR.into(); - Arc::new(array) - }}; -} - -macro_rules! dyn_to_array { - ($self:expr, $value:expr, $size:expr, $ty:ty) => {{ - Arc::new(PrimitiveArray::<$ty>::from_data( - $self.get_datatype(), - Buffer::<$ty>::from_iter(repeat(*$value).take($size)), - None, - )) - }}; -} - -macro_rules! build_values_list_tz { - ($MUTABLE_ARR:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - for _ in 0..$SIZE { - let mut vec = vec![]; - for scalar_value in $VALUES { - match scalar_value { - ScalarValue::$SCALAR_TY(v, _) => { - vec.push(v.clone()); - } - _ => panic!("Incompatible ScalarValue for list"), - }; - } - $MUTABLE_ARR.try_push(Some(vec)).unwrap(); - } - - let array: ListArray<i32> = $MUTABLE_ARR.into(); - Arc::new(array) - }}; -} - -macro_rules! eq_array_primitive { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - let is_valid = array.is_valid($index); - match $VALUE { - Some(val) => is_valid && &array.value($index) == val, - None => !is_valid, - } - }}; -} - -impl ScalarValue { - /// Create a decimal Scalar from value/precision and scale.
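A tiny sketch (not part of this patch) of the decimal constructor's validation, using the precision/scale bounds described above:

use arrow::datatypes::DataType;
use datafusion::scalar::ScalarValue;

fn example_decimal() {
    // 123.45 stored as i128 digits with precision 10 and scale 2
    let d = ScalarValue::try_new_decimal128(12345, 10, 2).unwrap();
    assert_eq!(d.get_datatype(), DataType::Decimal(10, 2));
    // a scale larger than the precision is rejected
    assert!(ScalarValue::try_new_decimal128(1, 2, 3).is_err());
}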
-impl ScalarValue {
-    /// Create a decimal Scalar from value/precision and scale.
-    pub fn try_new_decimal128(
-        value: i128,
-        precision: usize,
-        scale: usize,
-    ) -> Result<Self> {
-        // make sure the precision and scale are valid
-        if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision {
-            return Ok(ScalarValue::Decimal128(Some(value), precision, scale));
-        }
-        Err(DataFusionError::Internal(format!(
-            "Cannot create a decimal type ScalarValue for precision {} and scale {}",
-            precision, scale
-        )))
-    }
-
-    /// Getter for the `DataType` of the value
-    pub fn get_datatype(&self) -> DataType {
-        match self {
-            ScalarValue::Boolean(_) => DataType::Boolean,
-            ScalarValue::UInt8(_) => DataType::UInt8,
-            ScalarValue::UInt16(_) => DataType::UInt16,
-            ScalarValue::UInt32(_) => DataType::UInt32,
-            ScalarValue::UInt64(_) => DataType::UInt64,
-            ScalarValue::Int8(_) => DataType::Int8,
-            ScalarValue::Int16(_) => DataType::Int16,
-            ScalarValue::Int32(_) => DataType::Int32,
-            ScalarValue::Int64(_) => DataType::Int64,
-            ScalarValue::Decimal128(_, precision, scale) => {
-                DataType::Decimal(*precision, *scale)
-            }
-            ScalarValue::TimestampSecond(_, tz_opt) => {
-                DataType::Timestamp(TimeUnit::Second, tz_opt.clone())
-            }
-            ScalarValue::TimestampMillisecond(_, tz_opt) => {
-                DataType::Timestamp(TimeUnit::Millisecond, tz_opt.clone())
-            }
-            ScalarValue::TimestampMicrosecond(_, tz_opt) => {
-                DataType::Timestamp(TimeUnit::Microsecond, tz_opt.clone())
-            }
-            ScalarValue::TimestampNanosecond(_, tz_opt) => {
-                DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone())
-            }
-            ScalarValue::Float32(_) => DataType::Float32,
-            ScalarValue::Float64(_) => DataType::Float64,
-            ScalarValue::Utf8(_) => DataType::Utf8,
-            ScalarValue::LargeUtf8(_) => DataType::LargeUtf8,
-            ScalarValue::Binary(_) => DataType::Binary,
-            ScalarValue::LargeBinary(_) => DataType::LargeBinary,
-            ScalarValue::List(_, data_type) => DataType::List(Box::new(Field::new(
-                "item",
-                data_type.as_ref().clone(),
-                true,
-            ))),
-            ScalarValue::Date32(_) => DataType::Date32,
-            ScalarValue::Date64(_) => DataType::Date64,
-            ScalarValue::IntervalYearMonth(_) => {
-                DataType::Interval(IntervalUnit::YearMonth)
-            }
-            ScalarValue::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime),
-            ScalarValue::IntervalMonthDayNano(_) => {
-                DataType::Interval(IntervalUnit::MonthDayNano)
-            }
-            ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()),
-        }
-    }
-
-    /// Calculate arithmetic negation for a scalar value
-    pub fn arithmetic_negate(&self) -> Self {
-        match self {
-            ScalarValue::Boolean(None)
-            | ScalarValue::Int8(None)
-            | ScalarValue::Int16(None)
-            | ScalarValue::Int32(None)
-            | ScalarValue::Int64(None)
-            | ScalarValue::Float32(None) => self.clone(),
-            ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(-v)),
-            ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(-v)),
-            ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(-v)),
-            ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(-v)),
-            ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(-v)),
-            ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(-v)),
-            ScalarValue::Decimal128(Some(v), precision, scale) => {
-                ScalarValue::Decimal128(Some(-v), *precision, *scale)
-            }
-            _ => panic!("Cannot run arithmetic negate on scalar value: {:?}", self),
-        }
-    }
-
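A short usage sketch of `try_new_decimal128` and `arithmetic_negate` above, assuming the `datafusion` crate as built from this branch:

```rust
use datafusion::error::Result;
use datafusion::scalar::ScalarValue;

fn main() -> Result<()> {
    // Valid: precision within MAX_PRECISION_FOR_DECIMAL128 and scale <= precision.
    let d = ScalarValue::try_new_decimal128(12345, 10, 2)?;
    // Negation preserves precision and scale.
    let neg = d.arithmetic_negate();
    assert_eq!(neg, ScalarValue::Decimal128(Some(-12345), 10, 2));
    // Invalid: a scale larger than the precision is rejected with an internal error.
    assert!(ScalarValue::try_new_decimal128(1, 3, 5).is_err());
    Ok(())
}
```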
-    /// Whether this value is null or not.
-    pub fn is_null(&self) -> bool {
-        matches!(
-            *self,
-            ScalarValue::Boolean(None)
-                | ScalarValue::UInt8(None)
-                | ScalarValue::UInt16(None)
-                | ScalarValue::UInt32(None)
-                | ScalarValue::UInt64(None)
-                | ScalarValue::Int8(None)
-                | ScalarValue::Int16(None)
-                | ScalarValue::Int32(None)
-                | ScalarValue::Int64(None)
-                | ScalarValue::Float32(None)
-                | ScalarValue::Float64(None)
-                | ScalarValue::Date32(None)
-                | ScalarValue::Date64(None)
-                | ScalarValue::Utf8(None)
-                | ScalarValue::LargeUtf8(None)
-                | ScalarValue::List(None, _)
-                | ScalarValue::TimestampSecond(None, _)
-                | ScalarValue::TimestampMillisecond(None, _)
-                | ScalarValue::TimestampMicrosecond(None, _)
-                | ScalarValue::TimestampNanosecond(None, _)
-                | ScalarValue::Struct(None, _)
-                | ScalarValue::Decimal128(None, _, _) // For the decimal type, a None value means the ScalarValue::Decimal128 is null.
-        )
-    }
-
-    /// Converts a scalar value into a 1-row array.
-    pub fn to_array(&self) -> ArrayRef {
-        self.to_array_of_size(1)
-    }
-
-    /// Converts an iterator of [`ScalarValue`]s into an [`ArrayRef`]
-    /// corresponding to those values.
-    ///
-    /// Returns an error if the iterator is empty or if the
-    /// [`ScalarValue`]s are not all the same type.
-    ///
-    /// Example
-    /// ```
-    /// use datafusion::scalar::ScalarValue;
-    /// use arrow::array::{BooleanArray, Array};
-    ///
-    /// let scalars = vec![
-    ///   ScalarValue::Boolean(Some(true)),
-    ///   ScalarValue::Boolean(None),
-    ///   ScalarValue::Boolean(Some(false)),
-    /// ];
-    ///
-    /// // Build an Array from the list of ScalarValues
-    /// let array = ScalarValue::iter_to_array(scalars.into_iter())
-    ///   .unwrap();
-    ///
-    /// let expected: Box<dyn Array> = Box::new(
-    ///   BooleanArray::from(vec![
-    ///     Some(true),
-    ///     None,
-    ///     Some(false)
-    ///   ]
-    /// ));
-    ///
-    /// assert_eq!(&array, &expected);
-    /// ```
-    pub fn iter_to_array(
-        scalars: impl IntoIterator<Item = ScalarValue>,
-    ) -> Result<ArrayRef> {
-        let mut scalars = scalars.into_iter().peekable();
-
-        // figure out the type based on the first element
-        let data_type = match scalars.peek() {
-            None => {
-                return Err(DataFusionError::Internal(
-                    "Empty iterator passed to ScalarValue::iter_to_array".to_string(),
-                ));
-            }
-            Some(sv) => sv.get_datatype(),
-        };
-
-        /// Creates an array of $ARRAY_TY by unpacking values of
-        /// SCALAR_TY for primitive types
-        macro_rules! build_array_primitive {
-            ($TY:ty, $SCALAR_TY:ident, $DT:ident) => {{
-                {
-                    Arc::new(scalars
-                        .map(|sv| {
-                            if let ScalarValue::$SCALAR_TY(v) = sv {
-                                Ok(v)
-                            } else {
-                                Err(DataFusionError::Internal(format!(
-                                    "Inconsistent types in ScalarValue::iter_to_array. \
-                                     Expected {:?}, got {:?}",
-                                    data_type, sv
-                                )))
-                            }
-                        }).collect::<Result<PrimitiveArray<$TY>>>()?.to($DT)
-                    ) as Arc<dyn Array>
-                }
-            }};
-        }
-
-        macro_rules! build_array_primitive_tz {
-            ($SCALAR_TY:ident) => {{
-                {
-                    let array = scalars
-                        .map(|sv| {
-                            if let ScalarValue::$SCALAR_TY(v, _) = sv {
-                                Ok(v)
-                            } else {
-                                Err(DataFusionError::Internal(format!(
-                                    "Inconsistent types in ScalarValue::iter_to_array. \
-                                     Expected {:?}, got {:?}",
-                                    data_type, sv
-                                )))
-                            }
-                        })
-                        .collect::<Result<PrimitiveArray<i64>>>()?;
-
-                    Arc::new(array)
-                }
-            }};
-        }
-
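As these helper macros suggest, `iter_to_array` requires every element to be the same ScalarValue variant; mixing variants surfaces the "Inconsistent types" internal error. A small sketch of that failure mode (same crate assumption as above):

```rust
use datafusion::scalar::ScalarValue;

fn main() {
    // Mixing variants trips the "Inconsistent types" check in the macros above.
    let mixed = vec![
        ScalarValue::Int32(Some(1)),
        ScalarValue::Utf8(Some("oops".to_string())),
    ];
    assert!(ScalarValue::iter_to_array(mixed).is_err());
}
```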
-        /// Creates an array of $ARRAY_TY by unpacking values of
-        /// SCALAR_TY for "string-like" types.
-        macro_rules! build_array_string {
-            ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
-                {
-                    let array = scalars
-                        .map(|sv| {
-                            if let ScalarValue::$SCALAR_TY(v) = sv {
-                                Ok(v)
-                            } else {
-                                Err(DataFusionError::Internal(format!(
-                                    "Inconsistent types in ScalarValue::iter_to_array. \
-                                     Expected {:?}, got {:?}",
-                                    data_type, sv
-                                )))
-                            }
-                        })
-                        .collect::<Result<$ARRAY_TY>>()?;
-                    Arc::new(array)
-                }
-            }};
-        }
-
-        macro_rules! build_array_list {
-            ($MUTABLE_TY:ty, $SCALAR_TY:ident) => {{
-                let mut array = MutableListArray::<i32, $MUTABLE_TY>::new();
-                for scalar in scalars.into_iter() {
-                    match scalar {
-                        ScalarValue::List(Some(xs), _) => {
-                            let xs = *xs;
-                            let mut vec = vec![];
-                            for s in xs {
-                                match s {
-                                    ScalarValue::$SCALAR_TY(o) => { vec.push(o) }
-                                    sv => return Err(DataFusionError::Internal(format!(
-                                        "Inconsistent types in ScalarValue::iter_to_array. \
-                                         Expected Utf8, got {:?}",
-                                        sv
-                                    ))),
-                                }
-                            }
-                            array.try_push(Some(vec))?;
-                        }
-                        ScalarValue::List(None, _) => {
-                            array.push_null();
-                        }
-                        sv => {
-                            return Err(DataFusionError::Internal(format!(
-                                "Inconsistent types in ScalarValue::iter_to_array. \
-                                 Expected List, got {:?}",
-                                sv
-                            )))
-                        }
-                    }
-                }
-
-                let array: ListArray<i32> = array.into();
-                Arc::new(array)
-            }}
-        }
-
-        use DataType::*;
-        let array: Arc<dyn Array> = match &data_type {
-            DataType::Decimal(precision, scale) => {
-                let decimal_array =
-                    ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
-                Arc::new(decimal_array)
-            }
-            DataType::Boolean => Arc::new(
-                scalars
-                    .map(|sv| {
-                        if let ScalarValue::Boolean(v) = sv {
-                            Ok(v)
-                        } else {
-                            Err(DataFusionError::Internal(format!(
-                                "Inconsistent types in ScalarValue::iter_to_array. \
-                                 Expected {:?}, got {:?}",
-                                data_type, sv
-                            )))
-                        }
-                    })
-                    .collect::<Result<BooleanArray>>()?,
-            ),
-            Float32 => {
-                build_array_primitive!(f32, Float32, Float32)
-            }
-            Float64 => {
-                build_array_primitive!(f64, Float64, Float64)
-            }
-            Int8 => build_array_primitive!(i8, Int8, Int8),
-            Int16 => build_array_primitive!(i16, Int16, Int16),
-            Int32 => build_array_primitive!(i32, Int32, Int32),
-            Int64 => build_array_primitive!(i64, Int64, Int64),
-            UInt8 => build_array_primitive!(u8, UInt8, UInt8),
-            UInt16 => build_array_primitive!(u16, UInt16, UInt16),
-            UInt32 => build_array_primitive!(u32, UInt32, UInt32),
-            UInt64 => build_array_primitive!(u64, UInt64, UInt64),
-            Utf8 => build_array_string!(StringArray, Utf8),
-            LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8),
-            Binary => build_array_string!(SmallBinaryArray, Binary),
-            LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary),
-            Date32 => build_array_primitive!(i32, Date32, Date32),
-            Date64 => build_array_primitive!(i64, Date64, Date64),
-            Timestamp(TimeUnit::Second, _) => {
-                build_array_primitive_tz!(TimestampSecond)
-            }
-            Timestamp(TimeUnit::Millisecond, _) => {
-                build_array_primitive_tz!(TimestampMillisecond)
-            }
-            Timestamp(TimeUnit::Microsecond, _) => {
-                build_array_primitive_tz!(TimestampMicrosecond)
-            }
-            Timestamp(TimeUnit::Nanosecond, _) => {
-                build_array_primitive_tz!(TimestampNanosecond)
-            }
-            Interval(IntervalUnit::DayTime) => {
-                build_array_primitive!(days_ms, IntervalDayTime, data_type)
-            }
-            Interval(IntervalUnit::YearMonth) => {
-                build_array_primitive!(i32, IntervalYearMonth, data_type)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::Int8 => {
-                build_array_list!(Int8Vec, Int8)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::Int16 => {
-                build_array_list!(Int16Vec, Int16)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::Int32 => {
-                build_array_list!(Int32Vec, Int32)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::Int64 => {
-                build_array_list!(Int64Vec, Int64)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::UInt8 => {
-                build_array_list!(UInt8Vec, UInt8)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::UInt16 => {
-                build_array_list!(UInt16Vec, UInt16)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::UInt32 => {
-                build_array_list!(UInt32Vec, UInt32)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::UInt64 => {
-                build_array_list!(UInt64Vec, UInt64)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::Float32 => {
-                build_array_list!(Float32Vec, Float32)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::Float64 => {
-                build_array_list!(Float64Vec, Float64)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::Utf8 => {
-                build_array_list!(MutableStringArray, Utf8)
-            }
-            DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => {
-                build_array_list!(MutableLargeStringArray, LargeUtf8)
-            }
-            DataType::List(_) => {
-                // Fallback case handling homogeneous lists with any ScalarValue element type
-                let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?;
-                Arc::new(list_array)
-            }
-            DataType::Struct(fields) => {
-                // Initialize a Vector to store the ScalarValues for each column
-                let mut columns: Vec<Vec<ScalarValue>> =
-                    (0..fields.len()).map(|_| Vec::new()).collect();
-
-                // Iterate over scalars to populate the column scalars for each row
-                for scalar in scalars {
-                    if let ScalarValue::Struct(values, fields) = scalar {
-                        match values {
-                            Some(values) => {
-                                // Push value for each field
-                                for c in 0..columns.len() {
-                                    let column = columns.get_mut(c).unwrap();
-                                    column.push(values[c].clone());
-                                }
-                            }
-                            None => {
-                                // Push NULL of the appropriate type for each field
-                                for c in 0..columns.len() {
-                                    let dtype = fields[c].data_type();
-                                    let column = columns.get_mut(c).unwrap();
-                                    column.push(ScalarValue::try_from(dtype)?);
-                                }
-                            }
-                        };
-                    } else {
-                        return Err(DataFusionError::Internal(format!(
-                            "Expected Struct but found: {}",
-                            scalar
-                        )));
-                    };
-                }
-
-                // Call iter_to_array recursively to convert the scalars for each column into Arrow arrays
-                let field_values = columns
-                    .iter()
-                    .map(|c| Self::iter_to_array(c.clone()).map(Arc::from))
-                    .collect::<Result<Vec<_>>>()?;
-
-                Arc::new(StructArray::from_data(data_type, field_values, None))
-            }
-            _ => {
-                return Err(DataFusionError::Internal(format!(
-                    "Unsupported creation of {:?} array from ScalarValue {:?}",
-                    data_type,
-                    scalars.peek()
-                )));
-            }
-        };
-
-        Ok(array)
-    }
-
-    fn iter_to_decimal_array(
-        scalars: impl IntoIterator<Item = ScalarValue>,
-        precision: &usize,
-        scale: &usize,
-    ) -> Result<PrimitiveArray<i128>> {
-        // collect the values as Option<i128>
-        let array = scalars
-            .into_iter()
-            .map(|element: ScalarValue| match element {
-                ScalarValue::Decimal128(v1, _, _) => v1,
-                _ => unreachable!(),
-            })
-            .collect::<Vec<Option<i128>>>();
-
-        // build the decimal array using the Decimal Builder
-        Ok(Int128Vec::from(array)
-            .to(Decimal(*precision, *scale))
-            .into())
-    }
-
-    fn iter_to_array_list(
-        scalars: impl IntoIterator<Item = ScalarValue>,
-        data_type: &DataType,
-    ) -> Result<ListArray<i32>> {
-        let mut offsets: Vec<i32> = vec![0];
-
-        let mut elements: Vec<ArrayRef> = Vec::new();
-        let mut valid: Vec<bool> = vec![];
-
-        let mut flat_len = 0i32;
-        for scalar in scalars {
-            if let ScalarValue::List(values, _) = scalar {
-                match values {
-                    Some(values) => {
-                        let element_array = ScalarValue::iter_to_array(*values)?;
-
-                        // Add new offset index
-                        flat_len += element_array.len() as i32;
-                        offsets.push(flat_len);
-
-                        elements.push(element_array);
-
-                        // Element is valid
-                        valid.push(true);
-                    }
-                    None => {
-                        // Repeat previous offset index
-                        offsets.push(flat_len);
-
-                        // Element is null
-                        valid.push(false);
-                    }
-                }
-            } else {
-                return Err(DataFusionError::Internal(format!(
-                    "Expected ScalarValue::List element. Received {:?}",
-                    scalar
-                )));
-            }
-        }
-
-        // Concatenate element arrays to create single flat array
-        let element_arrays: Vec<&dyn Array> =
-            elements.iter().map(|a| a.as_ref()).collect();
-        let flat_array = match concatenate::concatenate(&element_arrays) {
-            Ok(flat_array) => flat_array,
-            Err(err) => return Err(DataFusionError::ArrowError(err)),
-        };
-
-        let list_array = ListArray::<i32>::from_data(
-            data_type.clone(),
-            Buffer::from(offsets),
-            flat_array.into(),
-            Some(Bitmap::from(valid)),
-        );
-
-        Ok(list_array)
-    }
+//! ScalarValue reimported from datafusion-common
 
-    /// Converts a scalar value into an array of `size` rows.
-    pub fn to_array_of_size(&self, size: usize) -> ArrayRef {
-        match self {
-            ScalarValue::Decimal128(e, precision, scale) => {
-                Int128Vec::from_iter(repeat(e).take(size))
-                    .to(Decimal(*precision, *scale))
-                    .into_arc()
-            }
-            ScalarValue::Boolean(e) => {
-                Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef
-            }
-            ScalarValue::Float64(e) => match e {
-                Some(value) => {
-                    dyn_to_array!(self, value, size, f64)
-                }
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::Float32(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, f32),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::Int8(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, i8),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::Int16(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, i16),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::Int32(e)
-            | ScalarValue::Date32(e)
-            | ScalarValue::IntervalYearMonth(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, i32),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::IntervalMonthDayNano(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, i128),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::Int64(e) | ScalarValue::Date64(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, i64),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::UInt8(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, u8),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::UInt16(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, u16),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::UInt32(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, u32),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::UInt64(e) => match e {
-                Some(value) => dyn_to_array!(self, value, size, u64),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::TimestampSecond(e, _) => match e {
-                Some(value) => dyn_to_array!(self, value, size, i64),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::TimestampMillisecond(e, _) => match e {
-                Some(value) => dyn_to_array!(self, value, size, i64),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::TimestampMicrosecond(e, _) => match e {
-                Some(value) => dyn_to_array!(self, value, size, i64),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::TimestampNanosecond(e, _) => match e {
-                Some(value) => dyn_to_array!(self, value, size, i64),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::Utf8(e) => match e {
-                Some(value) => Arc::new(Utf8Array::<i32>::from_trusted_len_values_iter(
-                    repeat(&value).take(size),
-                )),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::LargeUtf8(e) => match e {
-                Some(value) => Arc::new(Utf8Array::<i64>::from_trusted_len_values_iter(
-                    repeat(&value).take(size),
-                )),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::Binary(e) => match e {
-                Some(value) => Arc::new(
-                    repeat(Some(value.as_slice()))
-                        .take(size)
-                        .collect::<SmallBinaryArray>(),
-                ),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::LargeBinary(e) => match e {
-                Some(value) => Arc::new(
-                    repeat(Some(value.as_slice()))
-                        .take(size)
-                        .collect::<LargeBinaryArray>(),
-                ),
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::List(values, data_type) => match data_type.as_ref() {
-                DataType::Boolean => {
-                    build_list!(MutableBooleanArray, Boolean, values, size)
-                }
-                DataType::Int8 => build_list!(Int8Vec, Int8, values, size),
-                DataType::Int16 => build_list!(Int16Vec, Int16, values, size),
-                DataType::Int32 => build_list!(Int32Vec, Int32, values, size),
-                DataType::Int64 => build_list!(Int64Vec, Int64, values, size),
-                DataType::UInt8 => build_list!(UInt8Vec, UInt8, values, size),
-                DataType::UInt16 => build_list!(UInt16Vec, UInt16, values, size),
-                DataType::UInt32 => build_list!(UInt32Vec, UInt32, values, size),
-                DataType::UInt64 => build_list!(UInt64Vec, UInt64, values, size),
-                DataType::Float32 => build_list!(Float32Vec, Float32, values, size),
-                DataType::Float64 => build_list!(Float64Vec, Float64, values, size),
-                DataType::Timestamp(unit, tz) => {
-                    build_timestamp_list!(*unit, values, size, tz.clone())
-                }
-                DataType::Utf8 => build_list!(MutableStringArray, Utf8, values, size),
-                DataType::LargeUtf8 => {
-                    build_list!(MutableLargeStringArray, LargeUtf8, values, size)
-                }
-                dt => panic!("Unexpected DataType for list {:?}", dt),
-            },
-            ScalarValue::IntervalDayTime(e) => match e {
-                Some(value) => {
-                    Arc::new(PrimitiveArray::<days_ms>::from_trusted_len_values_iter(
-                        std::iter::repeat(*value).take(size),
-                    ))
-                }
-                None => new_null_array(self.get_datatype(), size).into(),
-            },
-            ScalarValue::Struct(values, _) => match values {
-                Some(values) => {
-                    let field_values =
-                        values.iter().map(|v| v.to_array_of_size(size)).collect();
-                    Arc::new(StructArray::from_data(
-                        self.get_datatype(),
-                        field_values,
-                        None,
-                    ))
-                }
-                None => Arc::new(StructArray::new_null(self.get_datatype(), size)),
-            },
-        }
-    }
-
-    fn get_decimal_value_from_array(
-        array: &ArrayRef,
-        index: usize,
-        precision: &usize,
-        scale: &usize,
-    ) -> ScalarValue {
-        let array = array.as_any().downcast_ref::<Int128Array>().unwrap();
-        if array.is_null(index) {
-            ScalarValue::Decimal128(None, *precision, *scale)
-        } else {
-            ScalarValue::Decimal128(Some(array.value(index)), *precision, *scale)
-        }
-    }
-
-    /// Converts a value in `array` at `index` into a ScalarValue
-    pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
-        // handle NULL value
-        if !array.is_valid(index) {
-            return array.data_type().try_into();
-        }
-
-        Ok(match array.data_type() {
-            DataType::Decimal(precision, scale) => {
-                ScalarValue::get_decimal_value_from_array(array, index, precision, scale)
-            }
-            DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean),
-            DataType::Float64 => typed_cast!(array, index, Float64Array, Float64),
-            DataType::Float32 => typed_cast!(array, index, Float32Array, Float32),
-            DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64),
-            DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32),
-            DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16),
-            DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8),
-            DataType::Int64 => typed_cast!(array, index, Int64Array, Int64),
-            DataType::Int32 => typed_cast!(array, index, Int32Array, Int32),
-            DataType::Int16 => typed_cast!(array, index, Int16Array, Int16),
-            DataType::Int8 => typed_cast!(array, index, Int8Array, Int8),
-            DataType::Binary => typed_cast!(array, index, SmallBinaryArray, Binary),
-            DataType::LargeBinary => {
-                typed_cast!(array, index, LargeBinaryArray, LargeBinary)
-            }
-            DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
-            DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
-            DataType::List(nested_type) => {
-                let list_array = array
-                    .as_any()
-                    .downcast_ref::<ListArray<i32>>()
-                    .ok_or_else(|| {
-                        DataFusionError::Internal(
-                            "Failed to downcast ListArray".to_string(),
-                        )
-                    })?;
-                let value = match list_array.is_null(index) {
-                    true => None,
-                    false => {
-                        let nested_array = ArrayRef::from(list_array.value(index));
-                        let scalar_vec = (0..nested_array.len())
-                            .map(|i| ScalarValue::try_from_array(&nested_array, i))
-                            .collect::<Result<Vec<_>>>()?;
-                        Some(scalar_vec)
-                    }
-                };
-                let value = value.map(Box::new);
-                let data_type = Box::new(nested_type.data_type().clone());
-                ScalarValue::List(value, data_type)
-            }
-            DataType::Date32 => {
-                typed_cast!(array, index, Int32Array, Date32)
-            }
-            DataType::Date64 => {
-                typed_cast!(array, index, Int64Array, Date64)
-            }
-            DataType::Timestamp(TimeUnit::Second, tz_opt) => {
-                typed_cast_tz!(array, index, TimestampSecond, tz_opt)
-            }
-            DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
-                typed_cast_tz!(array, index, TimestampMillisecond, tz_opt)
-            }
-            DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
-                typed_cast_tz!(array, index, TimestampMicrosecond, tz_opt)
-            }
-            DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
-                typed_cast_tz!(array, index, TimestampNanosecond, tz_opt)
-            }
-            DataType::Dictionary(index_type, _, _) => {
-                let (values, values_index) = match index_type {
-                    IntegerType::Int8 => get_dict_value::<i8>(array, index)?,
-                    IntegerType::Int16 => get_dict_value::<i16>(array, index)?,
-                    IntegerType::Int32 => get_dict_value::<i32>(array, index)?,
-                    IntegerType::Int64 => get_dict_value::<i64>(array, index)?,
-                    IntegerType::UInt8 => get_dict_value::<u8>(array, index)?,
-                    IntegerType::UInt16 => get_dict_value::<u16>(array, index)?,
-                    IntegerType::UInt32 => get_dict_value::<u32>(array, index)?,
-                    IntegerType::UInt64 => get_dict_value::<u64>(array, index)?,
-                };
-
-                match values_index {
-                    Some(values_index) => Self::try_from_array(values, values_index)?,
-                    // was null
-                    None => values.data_type().try_into()?,
-                }
-            }
-            DataType::Struct(fields) => {
-                let array =
-                    array
-                        .as_any()
-                        .downcast_ref::<StructArray>()
-                        .ok_or_else(|| {
-                            DataFusionError::Internal(
-                                "Failed to downcast ArrayRef to StructArray".to_string(),
-                            )
-                        })?;
-                let mut field_values: Vec<ScalarValue> = Vec::new();
-                for col_index in 0..array.num_columns() {
-                    let col_array = &array.values()[col_index];
-                    let col_scalar = ScalarValue::try_from_array(col_array, index)?;
-                    field_values.push(col_scalar);
-                }
-                Self::Struct(Some(Box::new(field_values)), Box::new(fields.clone()))
-            }
-            other => {
-                return Err(DataFusionError::NotImplemented(format!(
-                    "Can't create a scalar from array of type \"{:?}\"",
-                    other
-                )));
-            }
-        })
-    }
-
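Taken together, `to_array_of_size` and `try_from_array` round-trip a scalar through an array. A hedged sketch, again assuming this branch's `datafusion` crate:

```rust
use datafusion::error::Result;
use datafusion::scalar::ScalarValue;

fn main() -> Result<()> {
    // Expand one scalar into a 3-row array, then read row 1 back out.
    let scalar = ScalarValue::Utf8(Some("x".to_string()));
    let array = scalar.to_array_of_size(3);
    let round_trip = ScalarValue::try_from_array(&array, 1)?;
    assert_eq!(round_trip, scalar);
    Ok(())
}
```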
-    fn eq_array_decimal(
-        array: &ArrayRef,
-        index: usize,
-        value: &Option<i128>,
-        precision: usize,
-        scale: usize,
-    ) -> bool {
-        let array = array.as_any().downcast_ref::<Int128Array>().unwrap();
-        match array.data_type() {
-            Decimal(pre, sca) => {
-                if *pre != precision || *sca != scale {
-                    return false;
-                }
-            }
-            _ => return false,
-        }
-        match value {
-            None => array.is_null(index),
-            Some(v) => !array.is_null(index) && array.value(index) == *v,
-        }
-    }
-
-    /// Compares a single row of array @ index for equality with self,
-    /// in an optimized fashion.
-    ///
-    /// This method implements an optimized version of:
-    ///
-    /// ```text
-    ///     let arr_scalar = Self::try_from_array(array, index).unwrap();
-    ///     arr_scalar.eq(self)
-    /// ```
-    ///
-    /// *Performance note*: the arrow compute kernels should be
-    /// preferred over this function if at all possible as they can be
-    /// vectorized and are generally much faster.
-    ///
-    /// This function has a few narrow use cases, such as hash table key
-    /// comparisons, where comparing a single row at a time is necessary.
-    #[inline]
-    pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool {
-        if let DataType::Dictionary(key_type, _, _) = array.data_type() {
-            return self.eq_array_dictionary(array, index, key_type);
-        }
-
-        match self {
-            ScalarValue::Decimal128(v, precision, scale) => {
-                ScalarValue::eq_array_decimal(array, index, v, *precision, *scale)
-            }
-            ScalarValue::Boolean(val) => {
-                eq_array_primitive!(array, index, BooleanArray, val)
-            }
-            ScalarValue::Float32(val) => {
-                eq_array_primitive!(array, index, Float32Array, val)
-            }
-            ScalarValue::Float64(val) => {
-                eq_array_primitive!(array, index, Float64Array, val)
-            }
-            ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val),
-            ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val),
-            ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val),
-            ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val),
-            ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val),
-            ScalarValue::UInt16(val) => {
-                eq_array_primitive!(array, index, UInt16Array, val)
-            }
-            ScalarValue::UInt32(val) => {
-                eq_array_primitive!(array, index, UInt32Array, val)
-            }
-            ScalarValue::UInt64(val) => {
-                eq_array_primitive!(array, index, UInt64Array, val)
-            }
-            ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val),
-            ScalarValue::LargeUtf8(val) => {
-                eq_array_primitive!(array, index, LargeStringArray, val)
-            }
-            ScalarValue::Binary(val) => {
-                eq_array_primitive!(array, index, SmallBinaryArray, val)
-            }
-            ScalarValue::LargeBinary(val) => {
-                eq_array_primitive!(array, index, LargeBinaryArray, val)
-            }
-            ScalarValue::List(_, _) => unimplemented!(),
-            ScalarValue::Date32(val) => {
-                eq_array_primitive!(array, index, Int32Array, val)
-            }
-            ScalarValue::Date64(val) => {
-                eq_array_primitive!(array, index, Int64Array, val)
-            }
-            ScalarValue::TimestampSecond(val, _) => {
-                eq_array_primitive!(array, index, Int64Array, val)
-            }
-            ScalarValue::TimestampMillisecond(val, _) => {
-                eq_array_primitive!(array, index, Int64Array, val)
-            }
-            ScalarValue::TimestampMicrosecond(val, _) => {
-                eq_array_primitive!(array, index, Int64Array, val)
-            }
-            ScalarValue::TimestampNanosecond(val, _) => {
-                eq_array_primitive!(array, index, Int64Array, val)
-            }
-            ScalarValue::IntervalYearMonth(val) => {
-                eq_array_primitive!(array, index, Int32Array, val)
-            }
-            ScalarValue::IntervalDayTime(val) => {
-                eq_array_primitive!(array, index, DaysMsArray, val)
-            }
-            ScalarValue::IntervalMonthDayNano(val) => {
-                eq_array_primitive!(array, index, Int128Array, val)
-            }
-            ScalarValue::Struct(_, _) => unimplemented!(),
-        }
-    }
-
-    /// Compares a dictionary array with indexes of type `key_type`
-    /// with the array @ index for equality with self
-    fn eq_array_dictionary(
-        &self,
-        array: &ArrayRef,
-        index: usize,
-        key_type: &IntegerType,
-    ) -> bool {
-        let (values, values_index) = match key_type {
-            IntegerType::Int8 => get_dict_value::<i8>(array, index).unwrap(),
-            IntegerType::Int16 => get_dict_value::<i16>(array, index).unwrap(),
-            IntegerType::Int32 => get_dict_value::<i32>(array, index).unwrap(),
-            IntegerType::Int64 => get_dict_value::<i64>(array, index).unwrap(),
-            IntegerType::UInt8 => get_dict_value::<u8>(array, index).unwrap(),
-            IntegerType::UInt16 => get_dict_value::<u16>(array, index).unwrap(),
-            IntegerType::UInt32 => get_dict_value::<u32>(array, index).unwrap(),
-            IntegerType::UInt64 => get_dict_value::<u64>(array, index).unwrap(),
-        };
-
-        match values_index {
-            Some(values_index) => self.eq_array(values, values_index),
-            None => self.is_null(),
-        }
-    }
-}
-
-macro_rules! impl_scalar {
-    ($ty:ty, $scalar:tt) => {
-        impl From<$ty> for ScalarValue {
-            fn from(value: $ty) -> Self {
-                ScalarValue::$scalar(Some(value))
-            }
-        }
-
-        impl From<Option<$ty>> for ScalarValue {
-            fn from(value: Option<$ty>) -> Self {
-                ScalarValue::$scalar(value)
-            }
-        }
-    };
-}
-
-impl_scalar!(f64, Float64);
-impl_scalar!(f32, Float32);
-impl_scalar!(i8, Int8);
-impl_scalar!(i16, Int16);
-impl_scalar!(i32, Int32);
-impl_scalar!(i64, Int64);
-impl_scalar!(bool, Boolean);
-impl_scalar!(u8, UInt8);
-impl_scalar!(u16, UInt16);
-impl_scalar!(u32, UInt32);
-impl_scalar!(u64, UInt64);
-
-impl From<&str> for ScalarValue {
-    fn from(value: &str) -> Self {
-        Some(value).into()
-    }
-}
-
-impl From<Option<&str>> for ScalarValue {
-    fn from(value: Option<&str>) -> Self {
-        let value = value.map(|s| s.to_string());
-        ScalarValue::Utf8(value)
-    }
-}
-
-impl FromStr for ScalarValue {
-    type Err = Infallible;
-
-    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
-        Ok(s.into())
-    }
-}
-
-impl From<Vec<(&str, ScalarValue)>> for ScalarValue {
-    fn from(value: Vec<(&str, ScalarValue)>) -> Self {
-        let (fields, scalars): (Vec<_>, Vec<_>) = value
-            .into_iter()
-            .map(|(name, scalar)| {
-                (Field::new(name, scalar.get_datatype(), false), scalar)
-            })
-            .unzip();
-
-        Self::Struct(Some(Box::new(scalars)), Box::new(fields))
-    }
-}
-
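The `impl_scalar!` expansions and the struct conversion above give ScalarValue a fairly rich `From` surface, and the `TryFrom` impls below unwrap in the other direction. A quick hedged sketch:

```rust
use datafusion::scalar::ScalarValue;
use std::convert::TryFrom;

fn main() {
    // Primitive From impls generated by `impl_scalar!`.
    let f = ScalarValue::from(1.5f64);
    assert_eq!(f, ScalarValue::Float64(Some(1.5)));
    // &str goes through the Utf8 variant.
    let s = ScalarValue::from("id");
    assert_eq!(s, ScalarValue::Utf8(Some("id".to_string())));
    // Named pairs build a Struct scalar with one Field per entry.
    let st = ScalarValue::from(vec![("a", ScalarValue::from(1i32))]);
    assert!(matches!(st, ScalarValue::Struct(Some(_), _)));
    // And TryFrom (defined just below) unwraps back to the native type.
    let x = i32::try_from(ScalarValue::Int32(Some(7))).unwrap();
    assert_eq!(x, 7);
}
```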
-macro_rules! impl_try_from {
-    ($SCALAR:ident, $NATIVE:ident) => {
-        impl TryFrom<ScalarValue> for $NATIVE {
-            type Error = DataFusionError;
-
-            fn try_from(value: ScalarValue) -> Result<Self> {
-                match value {
-                    ScalarValue::$SCALAR(Some(inner_value)) => Ok(inner_value),
-                    _ => Err(DataFusionError::Internal(format!(
-                        "Cannot convert {:?} to {}",
-                        value,
-                        std::any::type_name::<Self>()
-                    ))),
-                }
-            }
-        }
-    };
-}
-
-impl_try_from!(Int8, i8);
-impl_try_from!(Int16, i16);
-
-// special implementation for i32 because of Date32
-impl TryFrom<ScalarValue> for i32 {
-    type Error = DataFusionError;
-
-    fn try_from(value: ScalarValue) -> Result<Self> {
-        match value {
-            ScalarValue::Int32(Some(inner_value))
-            | ScalarValue::Date32(Some(inner_value)) => Ok(inner_value),
-            _ => Err(DataFusionError::Internal(format!(
-                "Cannot convert {:?} to {}",
-                value,
-                std::any::type_name::<Self>()
-            ))),
-        }
-    }
-}
-
-// special implementation for i64 because of TimestampNanosecond
-impl TryFrom<ScalarValue> for i64 {
-    type Error = DataFusionError;
-
-    fn try_from(value: ScalarValue) -> Result<Self> {
-        match value {
-            ScalarValue::Int64(Some(inner_value))
-            | ScalarValue::Date64(Some(inner_value))
-            | ScalarValue::TimestampNanosecond(Some(inner_value), _)
-            | ScalarValue::TimestampMicrosecond(Some(inner_value), _)
-            | ScalarValue::TimestampMillisecond(Some(inner_value), _)
-            | ScalarValue::TimestampSecond(Some(inner_value), _) => Ok(inner_value),
-            _ => Err(DataFusionError::Internal(format!(
-                "Cannot convert {:?} to {}",
-                value,
-                std::any::type_name::<Self>()
-            ))),
-        }
-    }
-}
-
-// special implementation for i128 because of Decimal128
-impl TryFrom<ScalarValue> for i128 {
-    type Error = DataFusionError;
-
-    fn try_from(value: ScalarValue) -> Result<Self> {
-        match value {
-            ScalarValue::Decimal128(Some(inner_value), _, _) => Ok(inner_value),
-            _ => Err(DataFusionError::Internal(format!(
-                "Cannot convert {:?} to {}",
-                value,
-                std::any::type_name::<Self>()
-            ))),
-        }
-    }
-}
-
-impl_try_from!(UInt8, u8);
-impl_try_from!(UInt16, u16);
-impl_try_from!(UInt32, u32);
-impl_try_from!(UInt64, u64);
-impl_try_from!(Float32, f32);
-impl_try_from!(Float64, f64);
-impl_try_from!(Boolean, bool);
-
-impl TryInto<Box<dyn Scalar>> for &ScalarValue {
-    type Error = DataFusionError;
-
-    fn try_into(self) -> Result<Box<dyn Scalar>> {
-        use arrow::scalar::*;
-        match self {
-            ScalarValue::Boolean(b) => Ok(Box::new(BooleanScalar::new(*b))),
-            ScalarValue::Float32(f) => {
-                Ok(Box::new(PrimitiveScalar::<f32>::new(DataType::Float32, *f)))
-            }
-            ScalarValue::Float64(f) => {
-                Ok(Box::new(PrimitiveScalar::<f64>::new(DataType::Float64, *f)))
-            }
-            ScalarValue::Int8(i) => {
-                Ok(Box::new(PrimitiveScalar::<i8>::new(DataType::Int8, *i)))
-            }
-            ScalarValue::Int16(i) => {
-                Ok(Box::new(PrimitiveScalar::<i16>::new(DataType::Int16, *i)))
-            }
-            ScalarValue::Int32(i) => {
-                Ok(Box::new(PrimitiveScalar::<i32>::new(DataType::Int32, *i)))
-            }
-            ScalarValue::Int64(i) => {
-                Ok(Box::new(PrimitiveScalar::<i64>::new(DataType::Int64, *i)))
-            }
-            ScalarValue::UInt8(u) => {
-                Ok(Box::new(PrimitiveScalar::<u8>::new(DataType::UInt8, *u)))
-            }
-            ScalarValue::UInt16(u) => {
-                Ok(Box::new(PrimitiveScalar::<u16>::new(DataType::UInt16, *u)))
-            }
-            ScalarValue::UInt32(u) => {
-                Ok(Box::new(PrimitiveScalar::<u32>::new(DataType::UInt32, *u)))
-            }
-            ScalarValue::UInt64(u) => {
-                Ok(Box::new(PrimitiveScalar::<u64>::new(DataType::UInt64, *u)))
-            }
-            ScalarValue::Utf8(s) => Ok(Box::new(Utf8Scalar::<i32>::new(s.clone()))),
-            ScalarValue::LargeUtf8(s) => Ok(Box::new(Utf8Scalar::<i64>::new(s.clone()))),
-            ScalarValue::Binary(b) => Ok(Box::new(BinaryScalar::<i32>::new(b.clone()))),
-            ScalarValue::LargeBinary(b) => {
-                Ok(Box::new(BinaryScalar::<i64>::new(b.clone())))
-            }
-            ScalarValue::Date32(i) => {
-                Ok(Box::new(PrimitiveScalar::<i32>::new(DataType::Date32, *i)))
-            }
-            ScalarValue::Date64(i) => {
-                Ok(Box::new(PrimitiveScalar::<i64>::new(DataType::Date64, *i)))
-            }
-            ScalarValue::TimestampSecond(i, tz) => {
-                Ok(Box::new(PrimitiveScalar::<i64>::new(
-                    DataType::Timestamp(TimeUnit::Second, tz.clone()),
-                    *i,
-                )))
-            }
-            ScalarValue::TimestampMillisecond(i, tz) => {
-                Ok(Box::new(PrimitiveScalar::<i64>::new(
-                    DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
-                    *i,
-                )))
-            }
-            ScalarValue::TimestampMicrosecond(i, tz) => {
-                Ok(Box::new(PrimitiveScalar::<i64>::new(
-                    DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
-                    *i,
-                )))
-            }
-            ScalarValue::TimestampNanosecond(i, tz) => {
-                Ok(Box::new(PrimitiveScalar::<i64>::new(
-                    DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
-                    *i,
-                )))
-            }
-            ScalarValue::IntervalYearMonth(i) => {
-                Ok(Box::new(PrimitiveScalar::<i32>::new(
-                    DataType::Interval(IntervalUnit::YearMonth),
-                    *i,
-                )))
-            }
-
-            // List and IntervalDayTime comparison not possible in arrow2
-            _ => Err(DataFusionError::Internal(
-                "Conversion not possible in arrow2".to_owned(),
-            )),
-        }
-    }
-}
-
-impl TryFrom<PrimitiveScalar<i64>> for ScalarValue {
-    type Error = DataFusionError;
-
-    fn try_from(s: PrimitiveScalar<i64>) -> Result<Self> {
-        match s.data_type() {
-            DataType::Timestamp(TimeUnit::Second, tz) => {
-                let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
-                Ok(ScalarValue::TimestampSecond(s.value(), tz.clone()))
-            }
-            DataType::Timestamp(TimeUnit::Microsecond, tz) => {
-                let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
-                Ok(ScalarValue::TimestampMicrosecond(s.value(), tz.clone()))
-            }
-            DataType::Timestamp(TimeUnit::Millisecond, tz) => {
-                let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
-                Ok(ScalarValue::TimestampMillisecond(s.value(), tz.clone()))
-            }
-            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
-                let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
-                Ok(ScalarValue::TimestampNanosecond(s.value(), tz.clone()))
-            }
-            _ => Err(DataFusionError::Internal(
-                format!(
-                    "Conversion from arrow Scalar to Datafusion ScalarValue not implemented for: {:?}", s))
-            ),
-        }
-    }
-}
-
-impl TryFrom<&DataType> for ScalarValue {
-    type Error = DataFusionError;
-
-    /// Create a Null instance of ScalarValue for this datatype
-    fn try_from(datatype: &DataType) -> Result<Self> {
-        Ok(match datatype {
-            DataType::Boolean => ScalarValue::Boolean(None),
-            DataType::Float64 => ScalarValue::Float64(None),
-            DataType::Float32 => ScalarValue::Float32(None),
-            DataType::Int8 => ScalarValue::Int8(None),
-            DataType::Int16 => ScalarValue::Int16(None),
-            DataType::Int32 => ScalarValue::Int32(None),
-            DataType::Int64 => ScalarValue::Int64(None),
-            DataType::UInt8 => ScalarValue::UInt8(None),
-            DataType::UInt16 => ScalarValue::UInt16(None),
-            DataType::UInt32 => ScalarValue::UInt32(None),
-            DataType::UInt64 => ScalarValue::UInt64(None),
-            DataType::Decimal(precision, scale) => {
-                ScalarValue::Decimal128(None, *precision, *scale)
-            }
-            DataType::Utf8 => ScalarValue::Utf8(None),
-            DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
-            DataType::Date32 => ScalarValue::Date32(None),
-            DataType::Date64 => ScalarValue::Date64(None),
-            DataType::Timestamp(TimeUnit::Second, tz_opt) => {
-                ScalarValue::TimestampSecond(None, tz_opt.clone())
-            }
-            DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
-                ScalarValue::TimestampMillisecond(None, tz_opt.clone())
-            }
-            DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
-                ScalarValue::TimestampMicrosecond(None, tz_opt.clone())
-            }
-            DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
-                ScalarValue::TimestampNanosecond(None, tz_opt.clone())
-            }
-            DataType::Dictionary(_index_type, value_type, _) => {
-                value_type.as_ref().try_into()?
-            }
-            DataType::List(ref nested_type) => {
-                ScalarValue::List(None, Box::new(nested_type.data_type().clone()))
-            }
-            DataType::Struct(fields) => {
-                ScalarValue::Struct(None, Box::new(fields.clone()))
-            }
-            _ => {
-                return Err(DataFusionError::NotImplemented(format!(
-                    "Can't create a scalar from data_type \"{:?}\"",
-                    datatype
-                )));
-            }
-        })
-    }
-}
-
-macro_rules! format_option {
-    ($F:expr, $EXPR:expr) => {{
-        match $EXPR {
-            Some(e) => write!($F, "{}", e),
-            None => write!($F, "NULL"),
-        }
-    }};
-}
-
-impl fmt::Display for ScalarValue {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        match self {
-            ScalarValue::Decimal128(v, p, s) => {
-                write!(f, "{}", format_args!("{:?},{:?},{:?}", v, p, s))?;
-            }
-            ScalarValue::Boolean(e) => format_option!(f, e)?,
-            ScalarValue::Float32(e) => format_option!(f, e)?,
-            ScalarValue::Float64(e) => format_option!(f, e)?,
-            ScalarValue::Int8(e) => format_option!(f, e)?,
-            ScalarValue::Int16(e) => format_option!(f, e)?,
-            ScalarValue::Int32(e) => format_option!(f, e)?,
-            ScalarValue::Int64(e) => format_option!(f, e)?,
-            ScalarValue::UInt8(e) => format_option!(f, e)?,
-            ScalarValue::UInt16(e) => format_option!(f, e)?,
-            ScalarValue::UInt32(e) => format_option!(f, e)?,
-            ScalarValue::UInt64(e) => format_option!(f, e)?,
-            ScalarValue::TimestampSecond(e, _) => format_option!(f, e)?,
-            ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?,
-            ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?,
-            ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?,
-            ScalarValue::Utf8(e) => format_option!(f, e)?,
-            ScalarValue::LargeUtf8(e) => format_option!(f, e)?,
-            ScalarValue::Binary(e) => match e {
-                Some(l) => write!(
-                    f,
-                    "{}",
-                    l.iter()
-                        .map(|v| format!("{}", v))
-                        .collect::<Vec<_>>()
-                        .join(",")
-                )?,
-                None => write!(f, "NULL")?,
-            },
-            ScalarValue::LargeBinary(e) => match e {
-                Some(l) => write!(
-                    f,
-                    "{}",
-                    l.iter()
-                        .map(|v| format!("{}", v))
-                        .collect::<Vec<_>>()
-                        .join(",")
-                )?,
-                None => write!(f, "NULL")?,
-            },
-            ScalarValue::List(e, _) => match e {
-                Some(l) => write!(
-                    f,
-                    "{}",
-                    l.iter()
-                        .map(|v| format!("{}", v))
-                        .collect::<Vec<_>>()
-                        .join(",")
-                )?,
-                None => write!(f, "NULL")?,
-            },
-            ScalarValue::Date32(e) => format_option!(f, e)?,
-            ScalarValue::Date64(e) => format_option!(f, e)?,
-            ScalarValue::IntervalDayTime(e) => format_option!(f, e)?,
-            ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?,
-            ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?,
-            ScalarValue::Struct(e, fields) => match e {
-                Some(l) => write!(
-                    f,
-                    "{{{}}}",
-                    l.iter()
-                        .zip(fields.iter())
-                        .map(|(value, field)| format!("{}:{}", field.name(), value))
-                        .collect::<Vec<_>>()
-                        .join(",")
-                )?,
-                None => write!(f, "NULL")?,
-            },
-        };
-        Ok(())
-    }
-}
-
-impl fmt::Debug for ScalarValue {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        match self {
-            ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self),
-            ScalarValue::Boolean(_) => write!(f, "Boolean({})", self),
-            ScalarValue::Float32(_) => write!(f, "Float32({})", self),
-            ScalarValue::Float64(_) => write!(f, "Float64({})", self),
-            ScalarValue::Int8(_) => write!(f, "Int8({})", self),
-            ScalarValue::Int16(_) => write!(f, "Int16({})", self),
-            ScalarValue::Int32(_) => write!(f, "Int32({})", self),
-            ScalarValue::Int64(_) => write!(f, "Int64({})", self),
-            ScalarValue::UInt8(_) => write!(f, "UInt8({})", self),
-            ScalarValue::UInt16(_) => write!(f, "UInt16({})", self),
-            ScalarValue::UInt32(_) => write!(f, "UInt32({})", self),
-            ScalarValue::UInt64(_) => write!(f, "UInt64({})", self),
-            ScalarValue::TimestampSecond(_, tz_opt) => {
-                write!(f, "TimestampSecond({}, {:?})", self, tz_opt)
-            }
-            ScalarValue::TimestampMillisecond(_, tz_opt) => {
-                write!(f, "TimestampMillisecond({}, {:?})", self, tz_opt)
-            }
-            ScalarValue::TimestampMicrosecond(_, tz_opt) => {
-                write!(f, "TimestampMicrosecond({}, {:?})", self, tz_opt)
-            }
-            ScalarValue::TimestampNanosecond(_, tz_opt) => {
-                write!(f, "TimestampNanosecond({}, {:?})", self, tz_opt)
-            }
-            ScalarValue::Utf8(None) => write!(f, "Utf8({})", self),
-            ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self),
-            ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({})", self),
-            ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{}\")", self),
-            ScalarValue::Binary(None) => write!(f, "Binary({})", self),
-            ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{}\")", self),
-            ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({})", self),
-            ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{}\")", self),
-            ScalarValue::List(_, dt) => write!(f, "List[{:?}]([{}])", dt, self),
-            ScalarValue::Date32(_) => write!(f, "Date32(\"{}\")", self),
-            ScalarValue::Date64(_) => write!(f, "Date64(\"{}\")", self),
-            ScalarValue::IntervalDayTime(_) => {
-                write!(f, "IntervalDayTime(\"{}\")", self)
-            }
-            ScalarValue::IntervalYearMonth(_) => {
-                write!(f, "IntervalYearMonth(\"{}\")", self)
-            }
-            ScalarValue::IntervalMonthDayNano(_) => {
-                write!(f, "IntervalMonthDayNano(\"{}\")", self)
-            }
-            ScalarValue::Struct(e, fields) => {
-                // Use Debug representation of field values
-                match e {
-                    Some(l) => write!(
-                        f,
-                        "Struct({{{}}})",
-                        l.iter()
-                            .zip(fields.iter())
-                            .map(|(value, field)| format!("{}:{:?}", field.name(), value))
-                            .collect::<Vec<_>>()
-                            .join(",")
-                    ),
-                    None => write!(f, "Struct(NULL)"),
-                }
-            }
-        }
-    }
-}
+
+pub use datafusion_common::{
+    ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128,
+};
 
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::field_util::struct_array_from;
+    use arrow::types::days_ms;
+    use arrow::{array::*, datatypes::*};
+    use std::cmp::Ordering;
+    use std::sync::Arc;
+
+    type StringArray = Utf8Array<i32>;
+    type LargeStringArray = Utf8Array<i64>;
+    type SmallBinaryArray = BinaryArray<i32>;
+    type LargeBinaryArray = BinaryArray<i64>;
 
     #[test]
     fn scalar_decimal_test() {
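With scalar.rs reduced to the re-export above, downstream import paths keep working unchanged; a minimal sketch (assuming both crates from this branch):

```rust
use datafusion::scalar::ScalarValue;

fn main() {
    // The crate-local path is now an alias for the datafusion-common type,
    // so the two names are interchangeable.
    let v: datafusion_common::ScalarValue = ScalarValue::Int32(Some(1));
    assert_eq!(v, ScalarValue::Int32(Some(1)));
}
```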
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index e07b41ff764c..3339e89ca7bc 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -36,7 +36,7 @@ use crate::logical_plan::{
 use crate::optimizer::utils::exprlist_to_columns;
 use crate::prelude::JoinType;
 use crate::scalar::ScalarValue;
-use crate::sql::utils::make_decimal_type;
+use crate::sql::utils::{make_decimal_type, normalize_ident};
 use crate::{
     error::{DataFusionError, Result},
     physical_plan::udaf::AggregateUDF,
@@ -1194,7 +1194,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             SelectItem::UnnamedExpr(expr) => self.sql_to_rex(expr, schema),
             SelectItem::ExprWithAlias { expr, alias } => Ok(Alias(
                 Box::new(self.sql_to_rex(expr, schema)?),
-                alias.value.clone(),
+                normalize_ident(alias),
             )),
             SelectItem::Wildcard => Ok(Expr::Wildcard),
             SelectItem::QualifiedWildcard(_) => Err(DataFusionError::NotImplemented(
@@ -1395,6 +1395,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 
             SQLExpr::Identifier(ref id) => {
                 if id.value.starts_with('@') {
+                    // TODO: figure out whether ScalarVariables should be case-insensitive.
                     let var_names = vec![id.value.clone()];
                     Ok(Expr::ScalarVariable(var_names))
                 } else {
@@ -1404,7 +1405,7 @@
                     // identifier. (e.g. it is "foo.bar" not foo.bar)
                     Ok(Expr::Column(Column {
                         relation: None,
-                        name: id.value.clone(),
+                        name: normalize_ident(id),
                     }))
                 }
             }
@@ -1421,8 +1422,7 @@
             }
 
             SQLExpr::CompoundIdentifier(ids) => {
-                let mut var_names: Vec<_> =
-                    ids.iter().map(|id| id.value.clone()).collect();
+                let mut var_names: Vec<_> = ids.iter().map(normalize_ident).collect();
 
                 if &var_names[0][0..1] == "@" {
                     Ok(Expr::ScalarVariable(var_names))
@@ -1642,13 +1642,7 @@
             // (e.g. "foo.bar") for function names yet
             function.name.to_string()
         } else {
-            // if there is a quote style, then don't normalize
-            // the name, otherwise normalize to lowercase
-            let ident = &function.name.0[0];
-            match ident.quote_style {
-                Some(_) => ident.value.clone(),
-                None => ident.value.to_ascii_lowercase(),
-            }
+            normalize_ident(&function.name.0[0])
         };
 
         // first, scalar built-in
@@ -2181,11 +2175,10 @@ pub fn convert_data_type(sql_type: &SQLDataType) -> Result<DataType> {
 
 #[cfg(test)]
 mod tests {
-    use functions::ScalarFunctionImplementation;
-
     use crate::datasource::empty::EmptyTable;
     use crate::physical_plan::functions::Volatility;
     use crate::{logical_plan::create_udf, sql::parser::DFParser};
+    use datafusion_expr::ScalarFunctionImplementation;
 
     use super::*;
diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs
index 0ede5ad8559e..cbe40d6dc51d 100644
--- a/datafusion/src/sql/utils.rs
+++ b/datafusion/src/sql/utils.rs
@@ -18,7 +18,9 @@
 //! SQL Utility Functions
 
 use arrow::datatypes::DataType;
+use sqlparser::ast::Ident;
 
+use crate::logical_plan::ExprVisitable;
 use crate::logical_plan::{Expr, LogicalPlan};
 use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128};
 use crate::{
@@ -532,6 +534,14 @@ pub(crate) fn make_decimal_type(
     }
 }
 
+// Normalize an identifier to a lowercase string unless the identifier is quoted.
+pub(crate) fn normalize_ident(id: &Ident) -> String {
+    match id.quote_style {
+        Some(_) => id.value.clone(),
+        None => id.value.to_ascii_lowercase(),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
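The normalization rule added to utils.rs is small enough to demonstrate in isolation. A sketch against the `sqlparser` crate (the function body is reproduced here because the original is `pub(crate)`):

```rust
use sqlparser::ast::Ident;

// Same rule as `normalize_ident` above: quoted identifiers keep their
// case, unquoted ones fold to lowercase.
fn normalize_ident(id: &Ident) -> String {
    match id.quote_style {
        Some(_) => id.value.clone(),
        None => id.value.to_ascii_lowercase(),
    }
}

fn main() {
    assert_eq!(normalize_ident(&Ident::new("MyCol")), "mycol");
    assert_eq!(normalize_ident(&Ident::with_quote('"', "MyCol")), "MyCol");
}
```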
diff --git a/datafusion/tests/order_spill_fuzz.rs b/datafusion/tests/order_spill_fuzz.rs
index 9fd38f1e5b4a..179d940b1f24 100644
--- a/datafusion/tests/order_spill_fuzz.rs
+++ b/datafusion/tests/order_spill_fuzz.rs
@@ -17,8 +17,10 @@
 //! Fuzz test for corner cases where sorting RecordBatches exceeds available memory and should spill
 
-use arrow::array::{ArrayRef, Int32Array};
-use arrow::compute::sort::SortOptions;
+use arrow::{
+    array::{ArrayRef, Int32Array},
+    compute::sort::SortOptions,
+};
 use datafusion::execution::memory_manager::MemoryManagerConfig;
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::physical_plan::expressions::{col, PhysicalSortExpr};
diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs
index c41095d0e9eb..b4a02fe8d027 100644
--- a/datafusion/tests/parquet_pruning.rs
+++ b/datafusion/tests/parquet_pruning.rs
@@ -21,14 +21,11 @@ use std::sync::Arc;
 
 use arrow::array::PrimitiveArray;
 use arrow::datatypes::TimeUnit;
-use arrow::error::ArrowError;
+use arrow::io::parquet::write::{FileWriter, RowGroupIterator};
 use arrow::{
     array::{Array, ArrayRef, Float64Array, Int32Array, Int64Array, Utf8Array},
     datatypes::{DataType, Field, Schema},
-    io::parquet::write::{
-        array_to_pages, to_parquet_schema, write_file, Compression, Compressor, DynIter,
-        DynStreamingIterator, Encoding, FallibleStreamingIterator, Version, WriteOptions,
-    },
+    io::parquet::write::{Compression, Encoding, Version, WriteOptions},
 };
 use chrono::{Datelike, Duration};
 use datafusion::field_util::SchemaExt;
@@ -631,50 +628,34 @@ async fn make_test_file(scenario: Scenario) -> NamedTempFile {
         write_statistics: true,
         version: Version::V1,
     };
-    let parquet_schema = to_parquet_schema(schema.as_ref()).unwrap();
-    let descritors = parquet_schema.columns().to_vec().into_iter();
-
-    let row_groups = batches.iter().map(|batch| {
-        let iterator =
-            batch
-                .columns()
-                .iter()
-                .zip(descritors.clone())
-                .map(|(array, type_)| {
-                    let encoding =
-                        if let DataType::Dictionary(_, _, _) = array.data_type() {
-                            Encoding::RleDictionary
-                        } else {
-                            Encoding::Plain
-                        };
-                    array_to_pages(array.as_ref(), type_, options, encoding).map(
-                        move |pages| {
-                            let encoded_pages = DynIter::new(pages.map(|x| Ok(x?)));
-                            let compressed_pages = Compressor::new(
-                                encoded_pages,
-                                options.compression,
-                                vec![],
-                            )
-                            .map_err(ArrowError::from);
-                            DynStreamingIterator::new(compressed_pages)
-                        },
-                    )
-                });
-        let iterator = DynIter::new(iterator);
-        Ok(iterator)
-    });
-
-    let mut writer = output_file.as_file();
-
-    write_file(
-        &mut writer,
-        row_groups,
-        schema,
-        parquet_schema,
-        options,
-        None,
-    )
-    .unwrap();
+    let encodings: Vec<Encoding> = schema
+        .fields()
+        .iter()
+        .map(|field| {
+            if let DataType::Dictionary(_, _, _) = field.data_type() {
+                Encoding::RleDictionary
+            } else {
+                Encoding::Plain
+            }
+        })
+        .collect();
+    let row_groups = RowGroupIterator::try_new(
+        batches.iter().map(|batch| Ok(batch.into())),
+        schema,
+        options,
+        encodings,
+    );
+
+    let mut file = output_file.as_file();
+
+    let mut writer =
+        FileWriter::try_new(&mut file, schema.as_ref().clone(), options).unwrap();
+    writer.start().unwrap();
+    for rg in row_groups.unwrap() {
+        let (group, len) = rg.unwrap();
+        writer.write(group, len).unwrap();
+    }
+    writer.end(None).unwrap();
 
     output_file
 }
diff --git a/datafusion/tests/simplification.rs b/datafusion/tests/simplification.rs
index 7bd62401d4fa..942f3852c8d6 100644
--- a/datafusion/tests/simplification.rs
+++ b/datafusion/tests/simplification.rs
@@ -19,6 +19,8 @@
 
 use arrow::datatypes::{DataType, Field, Schema};
 use datafusion::field_util::SchemaExt;
+use datafusion::logical_plan::ExprSchemable;
+use datafusion::logical_plan::ExprSimplifiable;
 use datafusion::{
     error::Result,
     execution::context::ExecutionProps,
diff --git a/datafusion/tests/sql/explain.rs b/datafusion/tests/sql/explain.rs
new file mode 100644
index 000000000000..00842b5eb8ab
--- /dev/null
+++ b/datafusion/tests/sql/explain.rs
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion::{
+    logical_plan::{LogicalPlan, LogicalPlanBuilder, PlanType},
+    prelude::ExecutionContext,
+};
+
+#[test]
+fn optimize_explain() {
+    let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
+
+    let plan = LogicalPlanBuilder::scan_empty(Some("employee"), &schema, None)
+        .unwrap()
+        .explain(true, false)
+        .unwrap()
+        .build()
+        .unwrap();
+
+    if let LogicalPlan::Explain(e) = &plan {
+        assert_eq!(e.stringified_plans.len(), 1);
+    } else {
+        panic!("plan was not an explain: {:?}", plan);
+    }
+
+    // now optimize the plan and expect to see more plans
+    let optimized_plan = ExecutionContext::new().optimize(&plan).unwrap();
+    if let LogicalPlan::Explain(e) = &optimized_plan {
+        // should have more than one plan
+        assert!(
+            e.stringified_plans.len() > 1,
+            "plans: {:#?}",
+            e.stringified_plans
+        );
+        // should have at least one optimized plan
+        let opt = e
+            .stringified_plans
+            .iter()
+            .any(|p| matches!(p.plan_type, PlanType::OptimizedLogicalPlan { .. }));
})); + + assert!(opt, "plans: {:#?}", e.stringified_plans); + } else { + panic!("plan was not an explain: {:?}", plan); + } +} diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index 26374cd5151b..710faf687fee 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -93,7 +93,9 @@ pub mod udf; pub mod union; pub mod window; +mod explain; pub mod information_schema; +mod partitioned_csv; #[cfg_attr(not(feature = "unicode_expressions"), ignore)] pub mod unicode; @@ -665,6 +667,21 @@ pub fn table_with_sequence( Ok(Arc::new(MemTable::try_new(schema, partitions)?)) } +/// Return a new table provider that has a single Int32 column with +/// values between `seq_start` and `seq_end` +pub fn table_with_sequence( + seq_start: i32, + seq_end: i32, +) -> Result> { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::>())); + let partitions = vec![vec![RecordBatch::try_new( + schema.clone(), + vec![arr as ArrayRef], + )?]]; + Ok(Arc::new(MemTable::try_new(schema, partitions)?)) +} + // Normalizes parts of an explain plan that vary from run to run (such as path) fn normalize_for_explain(s: &str) -> String { // Convert things like /Users/alamb/Software/arrow/testing/data/csv/aggregate_test_100.csv diff --git a/datafusion/tests/sql/partitioned_csv.rs b/datafusion/tests/sql/partitioned_csv.rs new file mode 100644 index 000000000000..5efc837d5c95 --- /dev/null +++ b/datafusion/tests/sql/partitioned_csv.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
diff --git a/datafusion/tests/sql/partitioned_csv.rs b/datafusion/tests/sql/partitioned_csv.rs
new file mode 100644
index 000000000000..5efc837d5c95
--- /dev/null
+++ b/datafusion/tests/sql/partitioned_csv.rs
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Utility functions for running with a partitioned csv dataset.
+
+use std::{io::Write, sync::Arc};
+
+use arrow::{
+    datatypes::{DataType, Field, Schema, SchemaRef},
+    record_batch::RecordBatch,
+};
+use datafusion::{
+    error::Result,
+    prelude::{CsvReadOptions, ExecutionConfig, ExecutionContext},
+};
+use tempfile::TempDir;
+
+/// Execute SQL against the given context and return results
+async fn plan_and_collect(
+    ctx: &mut ExecutionContext,
+    sql: &str,
+) -> Result<Vec<RecordBatch>> {
+    ctx.sql(sql).await?.collect().await
+}
+
+/// Execute SQL against a freshly created partitioned CSV context and return results
+pub async fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> {
+    let tmp_dir = TempDir::new()?;
+    let mut ctx = create_ctx(&tmp_dir, partition_count).await?;
+    plan_and_collect(&mut ctx, sql).await
+}
+
+/// Generate CSV partitions within the supplied directory
+fn populate_csv_partitions(
+    tmp_dir: &TempDir,
+    partition_count: usize,
+    file_extension: &str,
+) -> Result<SchemaRef> {
+    // define schema for data source (csv file)
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::UInt32, false),
+        Field::new("c2", DataType::UInt64, false),
+        Field::new("c3", DataType::Boolean, false),
+    ]));
+
+    // generate a partitioned file
+    for partition in 0..partition_count {
+        let filename = format!("partition-{}.{}", partition, file_extension);
+        let file_path = tmp_dir.path().join(&filename);
+        let mut file = std::fs::File::create(file_path)?;
+
+        // generate some data
+        for i in 0..=10 {
+            let data = format!("{},{},{}\n", partition, i, i % 2 == 0);
+            file.write_all(data.as_bytes())?;
+        }
+    }
+
+    Ok(schema)
+}
+
+/// Generate a partitioned CSV file and register it with an execution context
+pub async fn create_ctx(
+    tmp_dir: &TempDir,
+    partition_count: usize,
+) -> Result<ExecutionContext> {
+    let mut ctx =
+        ExecutionContext::with_config(ExecutionConfig::new().with_target_partitions(8));
+
+    let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?;
+
+    // register csv file with the execution context
+    ctx.register_csv(
+        "test",
+        tmp_dir.path().to_str().unwrap(),
+        CsvReadOptions::new().schema(&schema),
+    )
+    .await?;
+
+    Ok(ctx)
+}
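And a sketch of how a sibling test module consumes the new `partitioned_csv` helpers (mirroring the `parallel_projection` test added below):

```rust
// Hypothetical test using the helpers above; it would live alongside them
// in datafusion/tests/sql/, so `partitioned_csv` and `Result` are in scope.
#[tokio::test]
async fn count_rows_over_partitions() -> datafusion::error::Result<()> {
    // Four partitions, eleven rows each (0..=10 in the generator above).
    let results = partitioned_csv::execute("SELECT count(c1) FROM test", 4).await?;
    assert_eq!(results.len(), 1);
    Ok(())
}
```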
+use datafusion::logical_plan::{LogicalPlanBuilder, UNNAMED_TABLE}; +use tempfile::TempDir; + use super::*; #[tokio::test] @@ -73,3 +76,192 @@ async fn csv_query_group_by_avg_with_projection() -> Result<()> { assert_batches_sorted_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn parallel_projection() -> Result<()> { + let partition_count = 4; + let results = + partitioned_csv::execute("SELECT c1, c2 FROM test", partition_count).await?; + + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 3 | 1 |", + "| 3 | 2 |", + "| 3 | 3 |", + "| 3 | 4 |", + "| 3 | 5 |", + "| 3 | 6 |", + "| 3 | 7 |", + "| 3 | 8 |", + "| 3 | 9 |", + "| 3 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 0 | 1 |", + "| 0 | 2 |", + "| 0 | 3 |", + "| 0 | 4 |", + "| 0 | 5 |", + "| 0 | 6 |", + "| 0 | 7 |", + "| 0 | 8 |", + "| 0 | 9 |", + "| 0 | 10 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn projection_on_table_scan() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + let runtime = ctx.state.lock().runtime_env.clone(); + + let table = ctx.table("test")?; + let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) + .project(vec![col("c2")])? + .build()?; + + let optimized_plan = ctx.optimize(&logical_plan)?; + match &optimized_plan { + LogicalPlan::Projection(Projection { input, .. }) => match &**input { + LogicalPlan::TableScan(TableScan { + source, + projected_schema, + .. + }) => { + assert_eq!(source.schema().fields().len(), 3); + assert_eq!(projected_schema.fields().len(), 1); + } + _ => panic!("input to projection should be TableScan"), + }, + _ => panic!("expect optimized_plan to be projection"), + } + + let expected = "Projection: #test.c2\ + \n TableScan: test projection=Some([1])"; + assert_eq!(format!("{:?}", optimized_plan), expected); + + let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; + + assert_eq!(1, physical_plan.schema().fields().len()); + assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); + + let batches = collect(physical_plan, runtime).await?; + assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::<usize>()); + + Ok(()) +} + +#[tokio::test] +async fn preserve_nullability_on_projection() -> Result<()> { + let tmp_dir = TempDir::new()?; + let ctx = partitioned_csv::create_ctx(&tmp_dir, 1).await?; + + let schema: Schema = ctx.table("test").unwrap().schema().clone().into(); + assert!(!schema.field_with_name("c1")?.is_nullable()); + + let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)? + .project(vec![col("c1")])?
+ .build()?; + + let plan = ctx.optimize(&plan)?; + let physical_plan = ctx.create_physical_plan(&Arc::new(plan)).await?; + assert!(!physical_plan.schema().field_with_name("c1")?.is_nullable()); + Ok(()) +} + +#[tokio::test] +async fn projection_on_memory_scan() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let schema = SchemaRef::new(schema); + + let partitions = vec![vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])), + Arc::new(Int32Array::from_slice(&[2, 12, 12, 120])), + Arc::new(Int32Array::from_slice(&[3, 12, 12, 120])), + ], + )?]]; + + let plan = LogicalPlanBuilder::scan_memory(partitions, schema, None)? + .project(vec![col("b")])? + .build()?; + assert_fields_eq(&plan, vec!["b"]); + + let ctx = ExecutionContext::new(); + let optimized_plan = ctx.optimize(&plan)?; + match &optimized_plan { + LogicalPlan::Projection(Projection { input, .. }) => match &**input { + LogicalPlan::TableScan(TableScan { + source, + projected_schema, + .. + }) => { + assert_eq!(source.schema().fields().len(), 3); + assert_eq!(projected_schema.fields().len(), 1); + } + _ => panic!("input to projection should be InMemoryScan"), + }, + _ => panic!("expect optimized_plan to be projection"), + } + + let expected = format!( + "Projection: #{}.b\ + \n TableScan: {} projection=Some([1])", + UNNAMED_TABLE, UNNAMED_TABLE + ); + assert_eq!(format!("{:?}", optimized_plan), expected); + + let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; + + assert_eq!(1, physical_plan.schema().fields().len()); + assert_eq!("b", physical_plan.schema().field(0).name().as_str()); + + let runtime = ctx.state.lock().runtime_env.clone(); + let batches = collect(physical_plan, runtime).await?; + assert_eq!(1, batches.len()); + assert_eq!(1, batches[0].num_columns()); + assert_eq!(4, batches[0].num_rows()); + + Ok(()) +} + +fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { + let actual: Vec<String> = plan + .schema() + .fields() + .iter() + .map(|f| f.name().clone()) + .collect(); + assert_eq!(actual, expected); +} diff --git a/datafusion/tests/sql/select.rs b/datafusion/tests/sql/select.rs index 89fd6f2b1571..132da0777058 100644 --- a/datafusion/tests/sql/select.rs +++ b/datafusion/tests/sql/select.rs @@ -16,6 +16,8 @@ // under the License.
use super::*; +use datafusion::physical_plan::collect_partitioned; +use tempfile::TempDir; #[tokio::test] async fn all_where_empty() -> Result<()> { @@ -924,3 +926,59 @@ async fn csv_select_nested() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn parallel_query_with_filter() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + + let logical_plan = + ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?; + let logical_plan = ctx.optimize(&logical_plan)?; + + let physical_plan = ctx.create_physical_plan(&logical_plan).await?; + + let runtime = ctx.state.lock().runtime_env.clone(); + let results = collect_partitioned(physical_plan, runtime).await?; + + // note that the order of partitions is not deterministic + let mut num_rows = 0; + for partition in &results { + for batch in partition { + num_rows += batch.num_rows(); + } + } + assert_eq!(20, num_rows); + + let results: Vec<RecordBatch> = results.into_iter().flatten().collect(); + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 10 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 2 | 1 |", + "| 2 | 10 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} diff --git a/docs/source/index.rst b/docs/source/index.rst index bf6b25096b4b..5109e60338fa 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -55,6 +55,7 @@ Table of content specification/roadmap specification/invariants specification/output-field-name-semantic + specification/quarterly_roadmap .. _toc.readme: diff --git a/docs/source/specification/quarterly_roadmap.md b/docs/source/specification/quarterly_roadmap.md new file mode 100644 index 000000000000..5bb805d7e7f0 --- /dev/null +++ b/docs/source/specification/quarterly_roadmap.md @@ -0,0 +1,72 @@ + + +# Roadmap + +A quarterly roadmap will be published to give the DataFusion community visibility into the priorities of the project's contributors. This roadmap is not binding. + +## 2022 Q1 + +### DataFusion Core + +- Publish official Arrow2 branch +- Implementation of memory manager (i.e. to enable spilling to disk as needed) + +### Benchmarking + +- Inclusion in Db-Benchmark with all queries covered +- All TPCH queries covered + +### Performance Improvements + +- Predicate evaluation +- Improve multi-column comparisons (that can't be vectorized at the moment) +- Null constant support + +### New Features + +- Read JSON as table +- Simplify DDL with Datafusion-Cli +- Add Decimal128 data type and the attendant features such as Arrow Kernel and UDF support +- Add new experimental e-graph based optimizer + +### Ballista + +- Begin work on design documents and plan / priorities for development + +### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib)) + +- Stable S3 support +- Begin design discussions and prototyping of a stream provider + +## Beyond 2022 Q1 + +There is no clear timeline for the below, but community members have expressed interest in working on these topics.
+ +### DataFusion Core + +- Custom SQL support +- Split DataFusion into multiple crates +- Push-based query execution and code generation + +### Ballista + +- Evolve architecture so that it can be deployed in a multi-tenant cloud native environment +- Ensure Ballista is scalable, elastic, and stable for production usage +- Develop distributed ML capabilities diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 2489f6ba1f10..fc96acc8733c 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -21,6 +21,7 @@ SQL Reference .. toctree:: :maxdepth: 2 + sql_status select ddl DataFusion Functions diff --git a/docs/source/user-guide/sql/sql_status.md b/docs/source/user-guide/sql/sql_status.md new file mode 100644 index 000000000000..0df14e58a8be --- /dev/null +++ b/docs/source/user-guide/sql/sql_status.md @@ -0,0 +1,241 @@ + + +# Status + +## General + +- [x] SQL Parser +- [x] SQL Query Planner +- [x] Query Optimizer +- [x] Constant folding +- [x] Join Reordering +- [x] Limit Pushdown +- [x] Projection push down +- [x] Predicate push down +- [x] Type coercion +- [x] Parallel query execution + +## SQL Support + +- [x] Projection +- [x] Filter (WHERE) +- [x] Filter post-aggregate (HAVING) +- [x] Limit +- [x] Aggregate +- [x] Common math functions +- [x] cast +- [x] try_cast +- [x] [`VALUES` lists](https://www.postgresql.org/docs/current/queries-values.html) +- Postgres-compatible string functions + - [x] ascii + - [x] bit_length + - [x] btrim + - [x] char_length + - [x] character_length + - [x] chr + - [x] concat + - [x] concat_ws + - [x] initcap + - [x] left + - [x] length + - [x] lpad + - [x] ltrim + - [x] octet_length + - [x] regexp_replace + - [x] repeat + - [x] replace + - [x] reverse + - [x] right + - [x] rpad + - [x] rtrim + - [x] split_part + - [x] starts_with + - [x] strpos + - [x] substr + - [x] to_hex + - [x] translate + - [x] trim +- Miscellaneous/Boolean functions + - [x] nullif +- Approximation functions + - [x] approx_distinct +- Common date/time functions + - [ ] Basic date functions + - [ ] Basic time functions + - [x] Basic timestamp functions + - [x] [to_timestamp](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp) + - [x] [to_timestamp_millis](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_millis) + - [x] [to_timestamp_micros](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_micros) + - [x] [to_timestamp_seconds](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_seconds) +- Nested functions + - [x] Array of columns +- [x] Schema Queries + - [x] SHOW TABLES + - [x] SHOW COLUMNS + - [x] information_schema.{tables, columns} + - [ ] information_schema other views +- [x] Sorting +- [ ] Nested types +- [ ] Lists +- [x] Subqueries +- [x] Common table expressions +- [x] Set Operations + - [x] UNION ALL + - [x] UNION + - [x] INTERSECT + - [x] INTERSECT ALL + - [x] EXCEPT + - [x] EXCEPT ALL +- [x] Joins + - [x] INNER JOIN + - [x] LEFT JOIN + - [x] RIGHT JOIN + - [x] FULL JOIN + - [x] CROSS JOIN +- [ ] Window + - [x] Empty window + - [x] Common window functions + - [x] Window with PARTITION BY clause + - [x] Window with ORDER BY clause + - [ ] Window with FILTER clause + - [ ] [Window with custom WINDOW FRAME](https://github.com/apache/arrow-datafusion/issues/361) + - [ ] UDF and UDAF for window functions + +## Data Sources + +- [x] CSV +- [x] Parquet primitive types +- [ ] Parquet nested types + +## Extensibility + +DataFusion is designed to be
extensible at all points. To that end, you can provide your own custom: + +- [x] User Defined Functions (UDFs) +- [x] User Defined Aggregate Functions (UDAFs) +- [x] User Defined Table Source (`TableProvider`) for tables +- [x] User Defined `Optimizer` passes (plan rewrites) +- [x] User Defined `LogicalPlan` nodes +- [x] User Defined `ExecutionPlan` nodes + +## Rust Version Compatibility + +This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. + +# Supported SQL + +This library currently supports many SQL constructs, including + +- `CREATE EXTERNAL TABLE X STORED AS PARQUET LOCATION '...';` to register a table's location +- `SELECT ... FROM ...` together with any expression +- `ALIAS` to name an expression +- `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)` +- Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`. +- `WHERE` to filter +- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population) +- `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST` + +## Supported Functions + +DataFusion strives to implement a subset of the [PostgreSQL SQL dialect](https://www.postgresql.org/docs/current/functions.html) where possible. We explicitly choose a single dialect to maximize interoperability with other tools and allow reuse of the PostgreSQL documents and tutorials as much as possible. + +Currently, only a subset of the PostgreSQL dialect is implemented, and we will document any deviations. + +## Schema Metadata / Information Schema Support + +DataFusion supports showing metadata about the tables available. This information can be accessed using the views of the ISO SQL `information_schema` schema or the DataFusion specific `SHOW TABLES` and `SHOW COLUMNS` commands. + +More information can be found in the [Postgres docs](https://www.postgresql.org/docs/13/infoschema-schema.html).
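+
+Note that the `information_schema` views are opt-in rather than enabled by default. From Rust, they can be switched on via `ExecutionConfig` before the statements below will work; the following is a minimal sketch (the table name `t` and the CSV path are illustrative placeholders, not part of this patch):
+
+```rust
+use datafusion::error::Result;
+use datafusion::prelude::*;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    // information_schema support is opt-in on the execution context
+    let config = ExecutionConfig::new().with_information_schema(true);
+    let mut ctx = ExecutionContext::with_config(config);
+
+    // register an example table; the name and path are placeholders
+    ctx.register_csv("t", "t.csv", CsvReadOptions::new()).await?;
+
+    // both forms below should return the same metadata
+    ctx.sql("SHOW TABLES").await?.show().await?;
+    ctx.sql("SELECT * FROM information_schema.tables").await?.show().await?;
+    Ok(())
+}
+```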
+ +To show tables available for use in DataFusion, use the `SHOW TABLES` command or the `information_schema.tables` view: + +```sql +> show tables; ++---------------+--------------------+------------+------------+ +| table_catalog | table_schema | table_name | table_type | ++---------------+--------------------+------------+------------+ +| datafusion | public | t | BASE TABLE | +| datafusion | information_schema | tables | VIEW | ++---------------+--------------------+------------+------------+ + +> select * from information_schema.tables; + ++---------------+--------------------+------------+--------------+ +| table_catalog | table_schema | table_name | table_type | ++---------------+--------------------+------------+--------------+ +| datafusion | public | t | BASE TABLE | +| datafusion | information_schema | TABLES | SYSTEM TABLE | ++---------------+--------------------+------------+--------------+ +``` + +To show the schema of a table in DataFusion, use the `SHOW COLUMNS` command or the `information_schema.columns` view: + +```sql +> show columns from t; ++---------------+--------------+------------+-------------+-----------+-------------+ +| table_catalog | table_schema | table_name | column_name | data_type | is_nullable | ++---------------+--------------+------------+-------------+-----------+-------------+ +| datafusion | public | t | a | Int32 | NO | +| datafusion | public | t | b | Utf8 | NO | +| datafusion | public | t | c | Float32 | NO | ++---------------+--------------+------------+-------------+-----------+-------------+ + +> select table_name, column_name, ordinal_position, is_nullable, data_type from information_schema.columns; ++------------+-------------+------------------+-------------+-----------+ +| table_name | column_name | ordinal_position | is_nullable | data_type | ++------------+-------------+------------------+-------------+-----------+ +| t | a | 0 | NO | Int32 | +| t | b | 1 | NO | Utf8 | +| t | c | 2 | NO | Float32 | ++------------+-------------+------------------+-------------+-----------+ +``` + +## Supported Data Types + +DataFusion uses Arrow, and thus the Arrow type system, for query +execution. The SQL types from +[sqlparser-rs](https://github.com/ballista-compute/sqlparser-rs/blob/main/src/ast/data_type.rs#L57) +are mapped to Arrow types according to the following table + +| SQL Data Type | Arrow DataType | +| ------------- | --------------------------------- | +| `CHAR` | `Utf8` | +| `VARCHAR` | `Utf8` | +| `UUID` | _Not yet supported_ | +| `CLOB` | _Not yet supported_ | +| `BINARY` | _Not yet supported_ | +| `VARBINARY` | _Not yet supported_ | +| `DECIMAL` | `Float64` | +| `FLOAT` | `Float32` | +| `SMALLINT` | `Int16` | +| `INT` | `Int32` | +| `BIGINT` | `Int64` | +| `REAL` | `Float32` | +| `DOUBLE` | `Float64` | +| `BOOLEAN` | `Boolean` | +| `DATE` | `Date32` | +| `TIME` | `Time64(TimeUnit::Millisecond)` | +| `TIMESTAMP` | `Timestamp(TimeUnit::Nanosecond)` | +| `INTERVAL` | _Not yet supported_ | +| `REGCLASS` | _Not yet supported_ | +| `TEXT` | _Not yet supported_ | +| `BYTEA` | _Not yet supported_ | +| `CUSTOM` | _Not yet supported_ | +| `ARRAY` | _Not yet supported_ |
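+
+As a quick way to observe this mapping, the sketch below (illustrative only, not part of this patch) casts literals to SQL types from the table above and prints the planned schema, which should report the corresponding Arrow types (`Int64`, `Float64`, `Utf8`):
+
+```rust
+use datafusion::error::Result;
+use datafusion::prelude::*;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    // CAST targets use the SQL names from the table above; the planned
+    // schema carries the mapped Arrow DataType for each output column
+    let df = ctx
+        .sql("SELECT CAST(1 AS BIGINT) AS a, CAST(1.5 AS DOUBLE) AS b, CAST('x' AS VARCHAR) AS c")
+        .await?;
+    println!("{:?}", df.schema());
+    Ok(())
+}
+```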