Skip to content

Commit

Permalink
Feat(tsql): improve transpilation of temp table DDLs (#1958)
Browse files Browse the repository at this point in the history
* Feat(tsql): improve transpilation of temp table DDLs

* PR comment
  • Loading branch information
georgesittas authored Jul 25, 2023
1 parent 59847f5 commit 8448141
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 12 deletions.
10 changes: 7 additions & 3 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")

if kind.upper() == "TABLE" and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
if (
kind.upper() == "TABLE"
and e.expression
and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
)
):
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return create_with_partitions_sql(self, e)
Expand Down
61 changes: 57 additions & 4 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,6 @@ class Tokenizer(tokens.Tokenizer):
"SYSTEM_USER": TokenType.CURRENT_USER,
}

# TSQL allows @, # to appear as a variable/identifier prefix
SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
SINGLE_TOKENS.pop("#")

class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
Expand Down Expand Up @@ -518,6 +514,36 @@ def _parse_user_defined_function(
expressions = self._parse_csv(self._parse_function_parameter)
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)

def _parse_id_var(
self,
any_token: bool = True,
tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
is_temporary = self._match(TokenType.HASH)
is_global = is_temporary and self._match(TokenType.HASH)

this = super()._parse_id_var(any_token=any_token, tokens=tokens)
if this:
if is_global:
this.set("global", True)
elif is_temporary:
this.set("temporary", True)

return this

def _parse_create(self) -> exp.Create | exp.Command:
create = super()._parse_create()

if isinstance(create, exp.Create):
table = create.this.this if isinstance(create.this, exp.Schema) else create.this
if isinstance(table, exp.Table) and table.this.args.get("temporary"):
if not create.args.get("properties"):
create.set("properties", exp.Properties(expressions=[]))

create.args["properties"].append("expressions", exp.TemporaryProperty())

return create

class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
LIMIT_IS_TOP = True
Expand Down Expand Up @@ -552,6 +578,7 @@ class Generator(generator.Generator):
exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
e.this,
),
exp.TemporaryProperty: lambda self, e: "",
exp.TimeToStr: _format_sql,
}

Expand All @@ -564,6 +591,22 @@ class Generator(generator.Generator):

LIMIT_FETCH = "FETCH"

def createable_sql(
self,
expression: exp.Create,
locations: dict[exp.Properties.Location, list[exp.Property]],
) -> str:
sql = self.sql(expression, "this")
properties = expression.args.get("properties")

if sql[:1] != "#" and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
):
sql = f"#{sql}"

return sql

def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"

Expand Down Expand Up @@ -616,3 +659,13 @@ def rollback_sql(self, expression: exp.Rollback) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
return f"ROLLBACK TRANSACTION{this}"

def identifier_sql(self, expression: exp.Identifier) -> str:
identifier = super().identifier_sql(expression)

if expression.args.get("global"):
identifier = f"##{identifier}"
elif expression.args.get("temporary"):
identifier = f"#{identifier}"

return identifier
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ class JoinHint(Expression):


class Identifier(Expression):
arg_types = {"this": True, "quoted": False}
arg_types = {"this": True, "quoted": False, "global": False, "temporary": False}

@property
def quoted(self) -> bool:
Expand Down
10 changes: 7 additions & 3 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,8 +988,9 @@ def properties(
) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
expressions = self.wrap(expressions) if wrapped else expressions
return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
if expressions:
expressions = self.wrap(expressions) if wrapped else expressions
return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
return ""

def with_properties(self, properties: exp.Properties) -> str:
Expand Down Expand Up @@ -2415,7 +2416,7 @@ def expressions(
return ""

if flat:
return sep.join(self.sql(e) for e in expressions)
return sep.join(sql for sql in (self.sql(e) for e in expressions) if sql)

num_sqls = len(expressions)

Expand All @@ -2426,6 +2427,9 @@ def expressions(
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
if not sql:
continue

comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""

if self.pretty:
Expand Down
50 changes: 49 additions & 1 deletion tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,30 @@ def test_ddl(self):
self.validate_all(
"CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))",
write={
"tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIMESTAMP(4), d FLOAT(24))"
"spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT)",
"tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIMESTAMP(4), d FLOAT(24))",
},
)
self.validate_all(
"CREATE TABLE #mytemptable (a INTEGER)",
read={
"duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)",
},
write={
"tsql": "CREATE TABLE #mytemptable (a INTEGER)",
"snowflake": "CREATE TEMPORARY TABLE mytemptable (a INT)",
"duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)",
"oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)",
},
)
self.validate_all(
"CREATE TABLE #mytemptable AS SELECT a FROM Source_Table",
write={
"duckdb": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table",
"oracle": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table",
"snowflake": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table",
"spark": "CREATE TEMPORARY VIEW mytemptable AS SELECT a FROM Source_Table",
"tsql": "CREATE TABLE #mytemptable AS SELECT a FROM Source_Table",
},
)

Expand Down Expand Up @@ -943,8 +966,15 @@ def test_identifier_prefixes(self):
expr = parse_one("#x", read="tsql")
self.assertIsInstance(expr, exp.Column)
self.assertIsInstance(expr.this, exp.Identifier)
self.assertTrue(expr.this.args.get("temporary"))
self.assertEqual(expr.sql("tsql"), "#x")

expr = parse_one("##x", read="tsql")
self.assertIsInstance(expr, exp.Column)
self.assertIsInstance(expr.this, exp.Identifier)
self.assertTrue(expr.this.args.get("global"))
self.assertEqual(expr.sql("tsql"), "##x")

expr = parse_one("@x", read="tsql")
self.assertIsInstance(expr, exp.Parameter)
self.assertIsInstance(expr.this, exp.Var)
Expand All @@ -955,6 +985,24 @@ def test_identifier_prefixes(self):
self.assertIsInstance(table.this, exp.Parameter)
self.assertIsInstance(table.this.this, exp.Var)

def test_temp_table(self):
self.validate_all(
"SELECT * FROM #mytemptable",
write={
"duckdb": "SELECT * FROM mytemptable",
"spark": "SELECT * FROM mytemptable",
"tsql": "SELECT * FROM #mytemptable",
},
)
self.validate_all(
"SELECT * FROM ##mytemptable",
write={
"duckdb": "SELECT * FROM mytemptable",
"spark": "SELECT * FROM mytemptable",
"tsql": "SELECT * FROM ##mytemptable",
},
)

def test_system_time(self):
self.validate_all(
"SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'",
Expand Down

0 comments on commit 8448141

Please sign in to comment.