diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs new file mode 100644 index 000000000000..955aabe74c22 --- /dev/null +++ b/datafusion/sql/src/unparser/ast.rs @@ -0,0 +1,585 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file contains builders to create SQL ASTs. They are purposefully +//! not exported as they will eventually be move to the SQLparser package. +//! +//! +//! See + +use core::fmt; + +use sqlparser::ast; + +#[derive(Clone)] +pub(super) struct QueryBuilder { + with: Option, + body: Option>, + order_by: Vec, + limit: Option, + limit_by: Vec, + offset: Option, + fetch: Option, + locks: Vec, + for_clause: Option, +} + +#[allow(dead_code)] +impl QueryBuilder { + pub fn with(&mut self, value: Option) -> &mut Self { + let new = self; + new.with = value; + new + } + pub fn body(&mut self, value: Box) -> &mut Self { + let new = self; + new.body = Option::Some(value); + new + } + pub fn order_by(&mut self, value: Vec) -> &mut Self { + let new = self; + new.order_by = value; + new + } + pub fn limit(&mut self, value: Option) -> &mut Self { + let new = self; + new.limit = value; + new + } + pub fn limit_by(&mut self, value: Vec) -> &mut Self { + let new = self; + new.limit_by = value; + new + } + pub fn offset(&mut self, value: Option) -> &mut Self { + let new = self; + new.offset = value; + new + } + pub fn fetch(&mut self, value: Option) -> &mut Self { + let new = self; + new.fetch = value; + new + } + pub fn locks(&mut self, value: Vec) -> &mut Self { + let new = self; + new.locks = value; + new + } + pub fn for_clause(&mut self, value: Option) -> &mut Self { + let new = self; + new.for_clause = value; + new + } + pub fn build(&self) -> Result { + Ok(ast::Query { + with: self.with.clone(), + body: match self.body { + Some(ref value) => value.clone(), + None => { + return Result::Err(Into::into(UninitializedFieldError::from("body"))) + } + }, + order_by: self.order_by.clone(), + limit: self.limit.clone(), + limit_by: self.limit_by.clone(), + offset: self.offset.clone(), + fetch: self.fetch.clone(), + locks: self.locks.clone(), + for_clause: self.for_clause.clone(), + }) + } + fn create_empty() -> Self { + Self { + with: Default::default(), + body: Default::default(), + order_by: Default::default(), + limit: Default::default(), + limit_by: Default::default(), + offset: Default::default(), + fetch: Default::default(), + locks: Default::default(), + for_clause: Default::default(), + } + } +} +impl Default for QueryBuilder { + fn default() -> Self { + Self::create_empty() + } +} + +#[derive(Clone)] +pub(super) struct SelectBuilder { + distinct: Option, + top: Option, + projection: Vec, + into: Option, + from: Vec, + lateral_views: Vec, + selection: Option, + group_by: Option, + cluster_by: Vec, + distribute_by: Vec, + sort_by: Vec, + having: Option, + named_window: Vec, + qualify: Option, + value_table_mode: Option, +} + +#[allow(dead_code)] +impl SelectBuilder { + pub fn distinct(&mut self, value: Option) -> &mut Self { + let new = self; + new.distinct = value; + new + } + pub fn top(&mut self, value: Option) -> &mut Self { + let new = self; + new.top = value; + new + } + pub fn projection(&mut self, value: Vec) -> &mut Self { + let new = self; + new.projection = value; + new + } + pub fn into(&mut self, value: Option) -> &mut Self { + let new = self; + new.into = value; + new + } + pub fn from(&mut self, value: Vec) -> &mut Self { + let new = self; + new.from = value; + new + } + pub fn push_from(&mut self, value: TableWithJoinsBuilder) -> &mut Self { + let new = self; + new.from.push(value); + new + } + pub fn pop_from(&mut self) -> Option { + self.from.pop() + } + pub fn lateral_views(&mut self, value: Vec) -> &mut Self { + let new = self; + new.lateral_views = value; + new + } + pub fn selection(&mut self, value: Option) -> &mut Self { + let new = self; + new.selection = value; + new + } + pub fn group_by(&mut self, value: ast::GroupByExpr) -> &mut Self { + let new = self; + new.group_by = Option::Some(value); + new + } + pub fn cluster_by(&mut self, value: Vec) -> &mut Self { + let new = self; + new.cluster_by = value; + new + } + pub fn distribute_by(&mut self, value: Vec) -> &mut Self { + let new = self; + new.distribute_by = value; + new + } + pub fn sort_by(&mut self, value: Vec) -> &mut Self { + let new = self; + new.sort_by = value; + new + } + pub fn having(&mut self, value: Option) -> &mut Self { + let new = self; + new.having = value; + new + } + pub fn named_window(&mut self, value: Vec) -> &mut Self { + let new = self; + new.named_window = value; + new + } + pub fn qualify(&mut self, value: Option) -> &mut Self { + let new = self; + new.qualify = value; + new + } + pub fn value_table_mode(&mut self, value: Option) -> &mut Self { + let new = self; + new.value_table_mode = value; + new + } + pub fn build(&self) -> Result { + Ok(ast::Select { + distinct: self.distinct.clone(), + top: self.top.clone(), + projection: self.projection.clone(), + into: self.into.clone(), + from: self + .from + .iter() + .map(|b| b.build()) + .collect::, BuilderError>>()?, + lateral_views: self.lateral_views.clone(), + selection: self.selection.clone(), + group_by: match self.group_by { + Some(ref value) => value.clone(), + None => { + return Result::Err(Into::into(UninitializedFieldError::from( + "group_by", + ))) + } + }, + cluster_by: self.cluster_by.clone(), + distribute_by: self.distribute_by.clone(), + sort_by: self.sort_by.clone(), + having: self.having.clone(), + named_window: self.named_window.clone(), + qualify: self.qualify.clone(), + value_table_mode: self.value_table_mode, + }) + } + fn create_empty() -> Self { + Self { + distinct: Default::default(), + top: Default::default(), + projection: Default::default(), + into: Default::default(), + from: Default::default(), + lateral_views: Default::default(), + selection: Default::default(), + group_by: Some(ast::GroupByExpr::Expressions(Vec::new())), + cluster_by: Default::default(), + distribute_by: Default::default(), + sort_by: Default::default(), + having: Default::default(), + named_window: Default::default(), + qualify: Default::default(), + value_table_mode: Default::default(), + } + } +} +impl Default for SelectBuilder { + fn default() -> Self { + Self::create_empty() + } +} + +#[derive(Clone)] +pub(super) struct TableWithJoinsBuilder { + relation: Option, + joins: Vec, +} + +#[allow(dead_code)] +impl TableWithJoinsBuilder { + pub fn relation(&mut self, value: RelationBuilder) -> &mut Self { + let new = self; + new.relation = Option::Some(value); + new + } + + pub fn joins(&mut self, value: Vec) -> &mut Self { + let new = self; + new.joins = value; + new + } + pub fn push_join(&mut self, value: ast::Join) -> &mut Self { + let new = self; + new.joins.push(value); + new + } + + pub fn build(&self) -> Result { + Ok(ast::TableWithJoins { + relation: match self.relation { + Some(ref value) => value.build()?, + None => { + return Result::Err(Into::into(UninitializedFieldError::from( + "relation", + ))) + } + }, + joins: self.joins.clone(), + }) + } + fn create_empty() -> Self { + Self { + relation: Default::default(), + joins: Default::default(), + } + } +} +impl Default for TableWithJoinsBuilder { + fn default() -> Self { + Self::create_empty() + } +} + +#[derive(Clone)] +pub(super) struct RelationBuilder { + relation: Option, +} + +#[allow(dead_code)] +#[derive(Clone)] +enum TableFactorBuilder { + Table(TableRelationBuilder), + Derived(DerivedRelationBuilder), +} + +#[allow(dead_code)] +impl RelationBuilder { + pub fn has_relation(&self) -> bool { + self.relation.is_some() + } + pub fn table(&mut self, value: TableRelationBuilder) -> &mut Self { + let new = self; + new.relation = Option::Some(TableFactorBuilder::Table(value)); + new + } + pub fn derived(&mut self, value: DerivedRelationBuilder) -> &mut Self { + let new = self; + new.relation = Option::Some(TableFactorBuilder::Derived(value)); + new + } + pub fn alias(&mut self, value: Option) -> &mut Self { + let new = self; + match new.relation { + Some(TableFactorBuilder::Table(ref mut rel_builder)) => { + rel_builder.alias = value; + } + Some(TableFactorBuilder::Derived(ref mut rel_builder)) => { + rel_builder.alias = value; + } + None => (), + } + new + } + pub fn build(&self) -> Result { + Ok(match self.relation { + Some(TableFactorBuilder::Table(ref value)) => value.build()?, + Some(TableFactorBuilder::Derived(ref value)) => value.build()?, + None => { + return Result::Err(Into::into(UninitializedFieldError::from("relation"))) + } + }) + } + fn create_empty() -> Self { + Self { + relation: Default::default(), + } + } +} +impl Default for RelationBuilder { + fn default() -> Self { + Self::create_empty() + } +} + +#[derive(Clone)] +pub(super) struct TableRelationBuilder { + name: Option, + alias: Option, + args: Option>, + with_hints: Vec, + version: Option, + partitions: Vec, +} + +#[allow(dead_code)] +impl TableRelationBuilder { + pub fn name(&mut self, value: ast::ObjectName) -> &mut Self { + let new = self; + new.name = Option::Some(value); + new + } + pub fn alias(&mut self, value: Option) -> &mut Self { + let new = self; + new.alias = value; + new + } + pub fn args(&mut self, value: Option>) -> &mut Self { + let new = self; + new.args = value; + new + } + pub fn with_hints(&mut self, value: Vec) -> &mut Self { + let new = self; + new.with_hints = value; + new + } + pub fn version(&mut self, value: Option) -> &mut Self { + let new = self; + new.version = value; + new + } + pub fn partitions(&mut self, value: Vec) -> &mut Self { + let new = self; + new.partitions = value; + new + } + pub fn build(&self) -> Result { + Ok(ast::TableFactor::Table { + name: match self.name { + Some(ref value) => value.clone(), + None => { + return Result::Err(Into::into(UninitializedFieldError::from("name"))) + } + }, + alias: self.alias.clone(), + args: self.args.clone(), + with_hints: self.with_hints.clone(), + version: self.version.clone(), + partitions: self.partitions.clone(), + }) + } + fn create_empty() -> Self { + Self { + name: Default::default(), + alias: Default::default(), + args: Default::default(), + with_hints: Default::default(), + version: Default::default(), + partitions: Default::default(), + } + } +} +impl Default for TableRelationBuilder { + fn default() -> Self { + Self::create_empty() + } +} +#[derive(Clone)] +pub(super) struct DerivedRelationBuilder { + lateral: Option, + subquery: Option>, + alias: Option, +} + +#[allow(dead_code)] +impl DerivedRelationBuilder { + pub fn lateral(&mut self, value: bool) -> &mut Self { + let new = self; + new.lateral = Option::Some(value); + new + } + pub fn subquery(&mut self, value: Box) -> &mut Self { + let new = self; + new.subquery = Option::Some(value); + new + } + pub fn alias(&mut self, value: Option) -> &mut Self { + let new = self; + new.alias = value; + new + } + fn build(&self) -> Result { + Ok(ast::TableFactor::Derived { + lateral: match self.lateral { + Some(ref value) => *value, + None => { + return Result::Err(Into::into(UninitializedFieldError::from( + "lateral", + ))) + } + }, + subquery: match self.subquery { + Some(ref value) => value.clone(), + None => { + return Result::Err(Into::into(UninitializedFieldError::from( + "subquery", + ))) + } + }, + alias: self.alias.clone(), + }) + } + fn create_empty() -> Self { + Self { + lateral: Default::default(), + subquery: Default::default(), + alias: Default::default(), + } + } +} +impl Default for DerivedRelationBuilder { + fn default() -> Self { + Self::create_empty() + } +} + +/// Runtime error when a `build()` method is called and one or more required fields +/// do not have a value. +#[derive(Debug, Clone)] +pub(super) struct UninitializedFieldError(&'static str); + +impl UninitializedFieldError { + /// Create a new `UnitializedFieldError` for the specified field name. + pub fn new(field_name: &'static str) -> Self { + UninitializedFieldError(field_name) + } + + /// Get the name of the first-declared field that wasn't initialized + pub fn field_name(&self) -> &'static str { + self.0 + } +} + +impl fmt::Display for UninitializedFieldError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Field not initialized: {}", self.0) + } +} + +impl From<&'static str> for UninitializedFieldError { + fn from(field_name: &'static str) -> Self { + Self::new(field_name) + } +} +impl std::error::Error for UninitializedFieldError {} + +#[derive(Debug)] +pub enum BuilderError { + UninitializedField(&'static str), + ValidationError(String), +} +impl From for BuilderError { + fn from(s: UninitializedFieldError) -> Self { + Self::UninitializedField(s.field_name()) + } +} +impl From for BuilderError { + fn from(s: String) -> Self { + Self::ValidationError(s) + } +} +impl fmt::Display for BuilderError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::UninitializedField(ref field) => { + write!(f, "`{}` must be initialized", field) + } + Self::ValidationError(ref error) => write!(f, "{}", error), + } + } +} +impl std::error::Error for BuilderError {} diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index bb14c8a70739..2a9fdd47ad93 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -118,14 +118,14 @@ impl Unparser<'_> { Ok(ast::Expr::Identifier(self.new_ident(col.name.to_string()))) } - fn new_ident(&self, str: String) -> ast::Ident { + pub(super) fn new_ident(&self, str: String) -> ast::Ident { ast::Ident { value: str, quote_style: self.dialect.identifier_quote_style(), } } - fn binary_op_to_sql( + pub(super) fn binary_op_to_sql( &self, lhs: ast::Expr, rhs: ast::Expr, @@ -312,19 +312,18 @@ mod tests { use super::*; + // See sql::tests for E2E tests. + #[test] fn expr_to_sql_ok() -> Result<()> { - let tests: Vec<(Expr, &str)> = vec![ - (col("a").gt(lit(4)), r#"a > 4"#), - ( - Expr::Column(Column { - relation: Some(TableReference::partial("a", "b")), - name: "c".to_string(), - }) - .gt(lit(4)), - r#"a.b.c > 4"#, - ), - ]; + let tests: Vec<(Expr, &str)> = vec![( + Expr::Column(Column { + relation: Some(TableReference::partial("a", "b")), + name: "c".to_string(), + }) + .gt(lit(4)), + r#"a.b.c > 4"#, + )]; for (expr, expected) in tests { let ast = expr_to_sql(&expr)?; diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index 77a9de0975ed..e67ebc198018 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -15,9 +15,12 @@ // specific language governing permissions and limitations // under the License. +mod ast; mod expr; +mod plan; pub use expr::expr_to_sql; +pub use plan::plan_to_sql; use self::dialect::{DefaultDialect, Dialect}; pub mod dialect; diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs new file mode 100644 index 000000000000..21e4427c1f46 --- /dev/null +++ b/datafusion/sql/src/unparser/plan.rs @@ -0,0 +1,361 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_expr::{expr::Alias, Expr, JoinConstraint, JoinType, LogicalPlan}; +use sqlparser::ast; + +use super::{ + ast::{ + BuilderError, QueryBuilder, RelationBuilder, SelectBuilder, TableRelationBuilder, + TableWithJoinsBuilder, + }, + Unparser, +}; + +/// Convert a DataFusion [`LogicalPlan`] to `sqlparser::ast::Statement` +/// +/// This function is the opposite of `SqlToRel::sql_statement_to_plan` and can +/// be used to, among other things, convert `LogicalPlan`s to strings. +/// +/// # Example +/// ``` +/// use arrow::datatypes::{DataType, Field, Schema}; +/// use datafusion_expr::{col, logical_plan::table_scan}; +/// use datafusion_sql::unparser::plan_to_sql; +/// let schema = Schema::new(vec![ +/// Field::new("id", DataType::Utf8, false), +/// Field::new("value", DataType::Utf8, false), +/// ]); +/// let plan = table_scan(Some("table"), &schema, None) +/// .unwrap() +/// .project(vec![col("id"), col("value")]) +/// .unwrap() +/// .build() +/// .unwrap(); +/// let sql = plan_to_sql(&plan).unwrap(); +/// +/// assert_eq!(format!("{}", sql), "SELECT table.id, table.value FROM table") +/// ``` +pub fn plan_to_sql(plan: &LogicalPlan) -> Result { + let unparser = Unparser::default(); + unparser.plan_to_sql(plan) +} + +impl Unparser<'_> { + pub fn plan_to_sql(&self, plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Projection(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Window(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Sort(_) + | LogicalPlan::Join(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Union(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::EmptyRelation(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::Distinct(_) => self.select_to_sql(plan), + LogicalPlan::Dml(_) => self.dml_to_sql(plan), + LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Extension(_) + | LogicalPlan::Prepare(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Unnest(_) => not_impl_err!("Unsupported plan: {plan:?}"), + } + } + + fn select_to_sql(&self, plan: &LogicalPlan) -> Result { + let mut query_builder = QueryBuilder::default(); + let mut select_builder = SelectBuilder::default(); + select_builder.push_from(TableWithJoinsBuilder::default()); + let mut relation_builder = RelationBuilder::default(); + self.select_to_sql_recursively( + plan, + &mut query_builder, + &mut select_builder, + &mut relation_builder, + )?; + + let mut twj = select_builder.pop_from().unwrap(); + twj.relation(relation_builder); + select_builder.push_from(twj); + + let body = ast::SetExpr::Select(Box::new(select_builder.build()?)); + let query = query_builder.body(Box::new(body)).build()?; + + Ok(ast::Statement::Query(Box::new(query))) + } + + fn select_to_sql_recursively( + &self, + plan: &LogicalPlan, + query: &mut QueryBuilder, + select: &mut SelectBuilder, + relation: &mut RelationBuilder, + ) -> Result<()> { + match plan { + LogicalPlan::TableScan(scan) => { + let mut builder = TableRelationBuilder::default(); + builder.name(ast::ObjectName(vec![ + self.new_ident(scan.table_name.table().to_string()) + ])); + relation.table(builder); + + Ok(()) + } + LogicalPlan::Projection(p) => { + let items = p + .expr + .iter() + .map(|e| self.select_item_to_sql(e)) + .collect::>>()?; + select.projection(items); + + self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) + } + LogicalPlan::Filter(filter) => { + let filter_expr = self.expr_to_sql(&filter.predicate)?; + + select.selection(Some(filter_expr)); + + self.select_to_sql_recursively( + filter.input.as_ref(), + query, + select, + relation, + ) + } + LogicalPlan::Limit(limit) => { + if let Some(fetch) = limit.fetch { + query.limit(Some(ast::Expr::Value(ast::Value::Number( + fetch.to_string(), + false, + )))); + } + + self.select_to_sql_recursively( + limit.input.as_ref(), + query, + select, + relation, + ) + } + LogicalPlan::Sort(sort) => { + query.order_by(self.sort_to_sql(sort.expr.clone())?); + + self.select_to_sql_recursively( + sort.input.as_ref(), + query, + select, + relation, + ) + } + LogicalPlan::Aggregate(_agg) => { + not_impl_err!("Unsupported operator: {plan:?}") + } + LogicalPlan::Distinct(_distinct) => { + not_impl_err!("Unsupported operator: {plan:?}") + } + LogicalPlan::Join(join) => { + match join.join_constraint { + JoinConstraint::On => {} + JoinConstraint::Using => { + return not_impl_err!( + "Unsupported join constraint: {:?}", + join.join_constraint + ) + } + } + + // parse filter if exists + let join_filter = match &join.filter { + Some(filter) => Some(self.expr_to_sql(filter)?), + None => None, + }; + + // map join.on to `l.a = r.a AND l.b = r.b AND ...` + let eq_op = ast::BinaryOperator::Eq; + let join_on = self.join_conditions_to_sql(&join.on, eq_op)?; + + // Merge `join_on` and `join_filter` + let join_expr = match (join_filter, join_on) { + (Some(filter), Some(on)) => Some(self.and_op_to_sql(filter, on)), + (Some(filter), None) => Some(filter), + (None, Some(on)) => Some(on), + (None, None) => None, + }; + let join_constraint = match join_expr { + Some(expr) => ast::JoinConstraint::On(expr), + None => ast::JoinConstraint::None, + }; + + let mut right_relation = RelationBuilder::default(); + + self.select_to_sql_recursively( + join.left.as_ref(), + query, + select, + relation, + )?; + self.select_to_sql_recursively( + join.right.as_ref(), + query, + select, + &mut right_relation, + )?; + + let ast_join = ast::Join { + relation: right_relation.build()?, + join_operator: self + .join_operator_to_sql(join.join_type, join_constraint), + }; + let mut from = select.pop_from().unwrap(); + from.push_join(ast_join); + select.push_from(from); + + Ok(()) + } + LogicalPlan::SubqueryAlias(plan_alias) => { + // Handle bottom-up to allocate relation + self.select_to_sql_recursively( + plan_alias.input.as_ref(), + query, + select, + relation, + )?; + + relation.alias(Some( + self.new_table_alias(plan_alias.alias.table().to_string()), + )); + + Ok(()) + } + LogicalPlan::Union(_union) => { + not_impl_err!("Unsupported operator: {plan:?}") + } + LogicalPlan::Window(_window) => { + not_impl_err!("Unsupported operator: {plan:?}") + } + LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), + _ => not_impl_err!("Unsupported operator: {plan:?}"), + } + } + + fn select_item_to_sql(&self, expr: &Expr) -> Result { + match expr { + Expr::Alias(Alias { expr, name, .. }) => { + let inner = self.expr_to_sql(expr)?; + + Ok(ast::SelectItem::ExprWithAlias { + expr: inner, + alias: self.new_ident(name.to_string()), + }) + } + _ => { + let inner = self.expr_to_sql(expr)?; + + Ok(ast::SelectItem::UnnamedExpr(inner)) + } + } + } + + fn sort_to_sql(&self, sort_exprs: Vec) -> Result> { + sort_exprs + .iter() + .map(|expr: &Expr| match expr { + Expr::Sort(sort_expr) => { + let col = self.expr_to_sql(&sort_expr.expr)?; + Ok(ast::OrderByExpr { + asc: Some(sort_expr.asc), + expr: col, + nulls_first: Some(sort_expr.nulls_first), + }) + } + _ => plan_err!("Expecting Sort expr"), + }) + .collect::>>() + } + + fn join_operator_to_sql( + &self, + join_type: JoinType, + constraint: ast::JoinConstraint, + ) -> ast::JoinOperator { + match join_type { + JoinType::Inner => ast::JoinOperator::Inner(constraint), + JoinType::Left => ast::JoinOperator::LeftOuter(constraint), + JoinType::Right => ast::JoinOperator::RightOuter(constraint), + JoinType::Full => ast::JoinOperator::FullOuter(constraint), + JoinType::LeftAnti => ast::JoinOperator::LeftAnti(constraint), + JoinType::LeftSemi => ast::JoinOperator::LeftSemi(constraint), + JoinType::RightAnti => ast::JoinOperator::RightAnti(constraint), + JoinType::RightSemi => ast::JoinOperator::RightSemi(constraint), + } + } + + fn join_conditions_to_sql( + &self, + join_conditions: &Vec<(Expr, Expr)>, + eq_op: ast::BinaryOperator, + ) -> Result> { + // Only support AND conjunction for each binary expression in join conditions + let mut exprs: Vec = vec![]; + for (left, right) in join_conditions { + // Parse left + let l = self.expr_to_sql(left)?; + // Parse right + let r = self.expr_to_sql(right)?; + // AND with existing expression + exprs.push(self.binary_op_to_sql(l, r, eq_op.clone())); + } + let join_expr: Option = + exprs.into_iter().reduce(|r, l| self.and_op_to_sql(r, l)); + Ok(join_expr) + } + + fn and_op_to_sql(&self, lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr { + self.binary_op_to_sql(lhs, rhs, ast::BinaryOperator::And) + } + + fn new_table_alias(&self, alias: String) -> ast::TableAlias { + ast::TableAlias { + name: self.new_ident(alias), + columns: Vec::new(), + } + } + + fn dml_to_sql(&self, plan: &LogicalPlan) -> Result { + not_impl_err!("Unsupported plan: {plan:?}") + } +} + +impl From for DataFusionError { + fn from(e: BuilderError) -> Self { + DataFusionError::External(Box::new(e)) + } +} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 6681c3d02564..fdf7ab8c3d28 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,12 +22,14 @@ use std::{sync::Arc, vec}; use arrow_schema::TimeUnit::Nanosecond; use arrow_schema::*; +use datafusion_sql::planner::PlannerContext; +use datafusion_sql::unparser::{expr_to_sql, plan_to_sql}; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use datafusion_common::{ config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, }; -use datafusion_common::{plan_err, ParamValues}; +use datafusion_common::{plan_err, DFSchema, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, TableSource, @@ -39,6 +41,7 @@ use datafusion_sql::{ }; use rstest::rstest; +use sqlparser::parser::Parser; #[test] fn parse_decimals() { @@ -4487,6 +4490,87 @@ impl TableSource for EmptyTable { } } +#[test] +fn roundtrip_expr() { + let tests: Vec<(TableReference, &str, &str)> = vec![ + (TableReference::bare("person"), "age > 35", "age > 35"), + (TableReference::bare("person"), "id = '10'", "id = '10'"), + ]; + + let roundtrip = |table, sql: &str| -> Result { + let dialect = GenericDialect {}; + let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?; + + let context = MockContextProvider::default(); + let schema = context.get_table_source(table)?.schema(); + let df_schema = DFSchema::try_from(schema.as_ref().clone())?; + let sql_to_rel = SqlToRel::new(&context); + let expr = + sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; + + let ast = expr_to_sql(&expr)?; + + Ok(format!("{}", ast)) + }; + + for (table, query, expected) in tests { + let actual = roundtrip(table, query).unwrap(); + assert_eq!(actual, expected); + } +} + +#[test] +fn roundtrip_statement() { + let tests: Vec<(&str, &str)> = vec![ + ( + "select ta.j1_id from j1 ta;", + r#"SELECT ta.j1_id FROM j1 AS ta"#, + ), + ( + "select ta.j1_id from j1 ta order by ta.j1_id;", + r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST"#, + ), + ( + "select * from j1 ta order by ta.j1_id, ta.j1_string desc;", + r#"SELECT ta.j1_id, ta.j1_string FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST, ta.j1_string DESC NULLS FIRST"#, + ), + ( + "select * from j1 limit 10;", + r#"SELECT j1.j1_id, j1.j1_string FROM j1 LIMIT 10"#, + ), + ( + "select ta.j1_id from j1 ta where ta.j1_id > 1;", + r#"SELECT ta.j1_id FROM j1 AS ta WHERE ta.j1_id > 1"#, + ), + ( + "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on ta.j1_id = tb.j2_id;", + r#"SELECT ta.j1_id, tb.j2_string FROM j1 AS ta JOIN j2 AS tb ON ta.j1_id = tb.j2_id"#, + ), + ( + "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on ta.j1_id = tb.j2_id join j3 tc on ta.j1_id = tc.j3_id;", + r#"SELECT ta.j1_id, tb.j2_string, tc.j3_string FROM j1 AS ta JOIN j2 AS tb ON ta.j1_id = tb.j2_id JOIN j3 AS tc ON ta.j1_id = tc.j3_id"#, + ), + ]; + + let roundtrip = |sql: &str| -> Result { + let dialect = GenericDialect {}; + let statement = Parser::new(&dialect).try_with_sql(sql)?.parse_statement()?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement)?; + + let ast = plan_to_sql(&plan)?; + + Ok(format!("{}", ast)) + }; + + for (query, expected) in tests { + let actual = roundtrip(query).unwrap(); + assert_eq!(actual, expected); + } +} + #[cfg(test)] #[ctor::ctor] fn init() {