Skip to content

Commit

Permalink
refactor the way of extracting table name.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachelint committed Dec 7, 2022
1 parent e7e78e1 commit 15bcd41
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 113 deletions.
131 changes: 22 additions & 109 deletions sql/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use log::debug;
use paste::paste;
use sqlparser::{
ast::{
ColumnDef, ColumnOption, ColumnOptionDef, Expr, Ident, Join, ObjectName, SetExpr,
ColumnDef, ColumnOption, ColumnOptionDef, Expr, Ident, ObjectName, SetExpr,
Statement as SqlStatement, TableConstraint, TableFactor, TableWithJoins,
},
dialect::{keywords::Keyword, Dialect, MySqlDialect},
Expand Down Expand Up @@ -182,9 +182,9 @@ impl<'a> Parser<'a> {
}
_ => {
// use the native parser
Ok(Statement::Standard(Box::new(maybe_normalize_table_name(
self.parser.parse_statement()?,
))))
let mut statement = self.parser.parse_statement()?;
maybe_normalize_table_name(&mut statement);
Ok(Statement::Standard(Box::new(statement)))
}
}
}
Expand Down Expand Up @@ -550,121 +550,34 @@ fn build_timestamp_key_constraint(col_defs: &[ColumnDef], constraints: &mut Vec<
/// case-sensitive in sql.
// TODO: maybe other items(such as: alias, column name) need to be normalized,
// too.
pub fn maybe_normalize_table_name(statement: SqlStatement) -> SqlStatement {
let original_statement = statement.clone();
pub fn maybe_normalize_table_name(statement: &mut SqlStatement) {
if let SqlStatement::Query(query) = statement {
let sqlparser::ast::Query {
with,
body,
order_by,
limit,
offset,
fetch,
lock,
} = *query;

let body = if let SetExpr::Select(select) = *body {
let sqlparser::ast::Select {
distinct,
top,
projection,
into,
from,
lateral_views,
selection,
group_by,
cluster_by,
distribute_by,
sort_by,
having,
qualify,
} = *select;

let from: Vec<_> = from.into_iter().map(convert_one_from).collect();

Box::new(SetExpr::Select(Box::new(sqlparser::ast::Select {
distinct,
top,
projection,
into,
from,
lateral_views,
selection,
group_by,
cluster_by,
distribute_by,
sort_by,
having,
qualify,
})))
} else {
return original_statement;
};

SqlStatement::Query(Box::new(sqlparser::ast::Query {
with,
body,
order_by,
limit,
offset,
fetch,
lock,
}))
} else {
original_statement
if let SetExpr::Select(select) = query.body.as_mut() {
select.from.iter_mut().for_each(maybe_convert_one_from);
}
}
}

fn convert_one_from(one_from: TableWithJoins) -> TableWithJoins {
fn maybe_convert_one_from(one_from: &mut TableWithJoins) {
let TableWithJoins { relation, joins } = one_from;

let relation = convert_relation(relation);
let joins: Vec<_> = joins
.into_iter()
.map(|join| Join {
relation: convert_relation(join.relation),
join_operator: join.join_operator,
})
.collect();

TableWithJoins { relation, joins }
maybe_convert_relation(relation);
joins.iter_mut().for_each(|join| {
maybe_convert_relation(&mut join.relation);
});
}

fn convert_relation(relation: TableFactor) -> TableFactor {
if let TableFactor::Table {
name,
alias,
args,
with_hints,
} = relation.clone()
{
let new_name = maybe_convert_table_name(name);

TableFactor::Table {
name: new_name,
alias,
args,
with_hints,
}
} else {
relation
fn maybe_convert_relation(relation: &mut TableFactor) {
if let TableFactor::Table { name, .. } = relation {
maybe_convert_table_name(name);
}
}

fn maybe_convert_table_name(object_name: ObjectName) -> ObjectName {
let quoteds: Vec<_> = object_name
.0
.into_iter()
.map(|id| {
if id.quote_style.is_none() {
Ident::with_quote('`', id.value)
} else {
id
}
})
.collect();

ObjectName(quoteds)
fn maybe_convert_table_name(object_name: &mut ObjectName) {
object_name.0.iter_mut().for_each(|id| {
if id.quote_style.is_none() {
let _ = std::mem::replace(id, Ident::with_quote('`', id.value.clone()));
}
})
}

#[cfg(test)]
Expand Down
8 changes: 4 additions & 4 deletions tests/cases/local/03_dml/case_insensitive.result
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ SELECT
FROM
`case_insensitive_table1`;

ts,tsid,value1,
Timestamp(Timestamp(1)),Int64(0),Double(10.0),
Timestamp(Timestamp(2)),Int64(0),Double(20.0),
Timestamp(Timestamp(3)),Int64(0),Double(30.0),
tsid,ts,value1,
Int64(0),Timestamp(Timestamp(1)),Double(10.0),
Int64(0),Timestamp(Timestamp(2)),Double(20.0),
Int64(0),Timestamp(Timestamp(3)),Double(30.0),


SELECT
Expand Down

0 comments on commit 15bcd41

Please sign in to comment.