Skip to content

Commit

Permalink
refactor the way of extracting table name.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachelint committed Dec 7, 2022
1 parent 2cf3c3c commit 499e6ae
Showing 1 changed file with 24 additions and 109 deletions.
133 changes: 24 additions & 109 deletions sql/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use log::debug;
use paste::paste;
use sqlparser::{
ast::{
ColumnDef, ColumnOption, ColumnOptionDef, Expr, Ident, Join, ObjectName, SetExpr,
ColumnDef, ColumnOption, ColumnOptionDef, Expr, Ident, ObjectName, SetExpr,
Statement as SqlStatement, TableConstraint, TableFactor, TableWithJoins,
},
dialect::{keywords::Keyword, Dialect, MySqlDialect},
Expand Down Expand Up @@ -182,9 +182,9 @@ impl<'a> Parser<'a> {
}
_ => {
// use the native parser
Ok(Statement::Standard(Box::new(maybe_normalize_table_name(
self.parser.parse_statement()?,
))))
let mut statement = self.parser.parse_statement()?;
maybe_normalize_table_name(&mut statement);
Ok(Statement::Standard(Box::new(statement)))
}
}
}
Expand Down Expand Up @@ -550,121 +550,36 @@ fn build_timestamp_key_constraint(col_defs: &[ColumnDef], constraints: &mut Vec<
/// case-sensitive in sql.
// TODO: maybe other items(such as: alias, column name) need to be normalized,
// too.
pub fn maybe_normalize_table_name(statement: SqlStatement) -> SqlStatement {
let original_statement = statement.clone();
pub fn maybe_normalize_table_name(statement: &mut SqlStatement) {
if let SqlStatement::Query(query) = statement {
let sqlparser::ast::Query {
with,
body,
order_by,
limit,
offset,
fetch,
lock,
} = *query;

let body = if let SetExpr::Select(select) = *body {
let sqlparser::ast::Select {
distinct,
top,
projection,
into,
from,
lateral_views,
selection,
group_by,
cluster_by,
distribute_by,
sort_by,
having,
qualify,
} = *select;

let from: Vec<_> = from.into_iter().map(convert_one_from).collect();

Box::new(SetExpr::Select(Box::new(sqlparser::ast::Select {
distinct,
top,
projection,
into,
from,
lateral_views,
selection,
group_by,
cluster_by,
distribute_by,
sort_by,
having,
qualify,
})))
} else {
return original_statement;
};

SqlStatement::Query(Box::new(sqlparser::ast::Query {
with,
body,
order_by,
limit,
offset,
fetch,
lock,
}))
} else {
original_statement
if let SetExpr::Select(select) = query.body.as_mut() {
for one_from in &mut select.from {
maybe_convert_one_from(one_from);
}
}
}
}

fn convert_one_from(one_from: TableWithJoins) -> TableWithJoins {
fn maybe_convert_one_from(one_from: &mut TableWithJoins) {
let TableWithJoins { relation, joins } = one_from;

let relation = convert_relation(relation);
let joins: Vec<_> = joins
.into_iter()
.map(|join| Join {
relation: convert_relation(join.relation),
join_operator: join.join_operator,
})
.collect();

TableWithJoins { relation, joins }
maybe_convert_relation(relation);
joins.iter_mut().for_each(|join| {
maybe_convert_relation(&mut join.relation);
});
}

fn convert_relation(relation: TableFactor) -> TableFactor {
if let TableFactor::Table {
name,
alias,
args,
with_hints,
} = relation.clone()
{
let new_name = maybe_convert_table_name(name);

TableFactor::Table {
name: new_name,
alias,
args,
with_hints,
}
} else {
relation
fn maybe_convert_relation(relation: &mut TableFactor) {
if let TableFactor::Table { name, .. } = relation {
maybe_convert_table_name(name);
}
}

fn maybe_convert_table_name(object_name: ObjectName) -> ObjectName {
let quoteds: Vec<_> = object_name
.0
.into_iter()
.map(|id| {
if id.quote_style.is_none() {
Ident::with_quote('`', id.value)
} else {
id
}
})
.collect();

ObjectName(quoteds)
fn maybe_convert_table_name(object_name: &mut ObjectName) {
object_name.0.iter_mut().for_each(|id| {
if id.quote_style.is_none() {
let _ = std::mem::replace(id, Ident::with_quote('`', id.value.clone()));
}
})
}

#[cfg(test)]
Expand Down

0 comments on commit 499e6ae

Please sign in to comment.