Skip to content

Commit

Permalink
Feat(tsql): insert output closes #1901
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jul 8, 2023
1 parent 3b215ad commit d68f844
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 12 deletions.
7 changes: 7 additions & 0 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class Tokenizer(tokens.Tokenizer):
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
"OUTPUT": TokenType.RETURNING,
"SYSTEM_USER": TokenType.CURRENT_USER,
}

Expand Down Expand Up @@ -469,6 +470,7 @@ class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
LIMIT_IS_TOP = True
QUERY_HINTS = False
RETURNING_END = False

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
Expand Down Expand Up @@ -532,3 +534,8 @@ def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
table = expression.args.get("table")
table = f"{table} " if table else ""
return f"RETURNS {table}{self.sql(expression, 'this')}"

def returning_sql(self, expression: exp.Returning) -> str:
into = self.sql(expression, "into")
into = self.seg(f"INTO {into}") if into else ""
return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}"
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,7 +1576,7 @@ class OnConflict(Expression):


class Returning(Expression):
arg_types = {"expressions": True}
arg_types = {"expressions": True, "into": False}


# https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html
Expand Down
24 changes: 20 additions & 4 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ class Generator:
# Whether or not to generate the limit as TOP <value> instead of LIMIT <value>
LIMIT_IS_TOP = False

# Whether or not to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
RETURNING_END = True

# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")

Expand Down Expand Up @@ -836,8 +839,11 @@ def delete_sql(self, expression: exp.Delete) -> str:
limit = self.sql(expression, "limit")
tables = self.expressions(expression, key="tables")
tables = f" {tables}" if tables else ""
sql = f"DELETE{tables}{this}{using}{where}{returning}{limit}"
return self.prepend_ctes(expression, sql)
if self.RETURNING_END:
expression_sql = f"{this}{using}{where}{returning}{limit}"
else:
expression_sql = f"{returning}{this}{using}{where}{limit}"
return self.prepend_ctes(expression, f"DELETE{tables}{expression_sql}")

def drop_sql(self, expression: exp.Drop) -> str:
this = self.sql(expression, "this")
Expand Down Expand Up @@ -1134,7 +1140,13 @@ def insert_sql(self, expression: exp.Insert) -> str:
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
conflict = self.sql(expression, "conflict")
returning = self.sql(expression, "returning")
sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}"

if self.RETURNING_END:
expression_sql = f"{expression_sql}{conflict}{returning}"
else:
expression_sql = f"{returning}{expression_sql}{conflict}"

sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}"
return self.prepend_ctes(expression, sql)

def intersect_sql(self, expression: exp.Intersect) -> str:
Expand Down Expand Up @@ -1276,7 +1288,11 @@ def update_sql(self, expression: exp.Update) -> str:
where_sql = self.sql(expression, "where")
returning = self.sql(expression, "returning")
limit = self.sql(expression, "limit")
sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}{returning}{limit}"
if self.RETURNING_END:
expression_sql = f"{from_sql}{where_sql}{returning}{limit}"
else:
expression_sql = f"{returning}{from_sql}{where_sql}{limit}"
sql = f"UPDATE {this} SET {set_sql}{expression_sql}"
return self.prepend_ctes(expression, sql)

def values_sql(self, expression: exp.Values) -> str:
Expand Down
24 changes: 17 additions & 7 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,8 @@ def _parse_insert(self) -> exp.Insert:
self._match(TokenType.TABLE)
this = self._parse_table(schema=True)

returning = self._parse_returning()

return self.expression(
exp.Insert,
this=this,
Expand All @@ -1717,7 +1719,7 @@ def _parse_insert(self) -> exp.Insert:
and self._parse_conjunction(),
expression=self._parse_ddl_select(),
conflict=self._parse_on_conflict(),
returning=self._parse_returning(),
returning=returning or self._parse_returning(),
overwrite=overwrite,
alternative=alternative,
ignore=ignore,
Expand Down Expand Up @@ -1761,8 +1763,11 @@ def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]:
def _parse_returning(self) -> t.Optional[exp.Returning]:
if not self._match(TokenType.RETURNING):
return None

return self.expression(exp.Returning, expressions=self._parse_csv(self._parse_column))
return self.expression(
exp.Returning,
expressions=self._parse_csv(self._parse_column),
into=self._match(TokenType.INTO) and self._parse_table_part(),
)

def _parse_row(self) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]:
if not self._match(TokenType.FORMAT):
Expand Down Expand Up @@ -1824,25 +1829,30 @@ def _parse_delete(self) -> exp.Delete:
if not self._match(TokenType.FROM, advance=False):
tables = self._parse_csv(self._parse_table) or None

returning = self._parse_returning()

return self.expression(
exp.Delete,
tables=tables,
this=self._match(TokenType.FROM) and self._parse_table(joins=True),
using=self._match(TokenType.USING) and self._parse_table(joins=True),
where=self._parse_where(),
returning=self._parse_returning(),
returning=returning or self._parse_returning(),
limit=self._parse_limit(),
)

def _parse_update(self) -> exp.Update:
this = self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS)
expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
returning = self._parse_returning()
return self.expression(
exp.Update,
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
"this": this,
"expressions": expressions,
"from": self._parse_from(joins=True),
"where": self._parse_where(),
"returning": self._parse_returning(),
"returning": returning or self._parse_returning(),
"limit": self._parse_limit(),
},
)
Expand Down
5 changes: 5 additions & 0 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ class TestTSQL(Validator):
dialect = "tsql"

def test_tsql(self):
self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b INTO @y FROM y")
self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b FROM y")
self.validate_identity("INSERT INTO x (y) OUTPUT x.a, x.b INTO l SELECT * FROM z")
self.validate_identity("INSERT INTO x (y) OUTPUT x.a, x.b SELECT * FROM z")
self.validate_identity("DELETE x OUTPUT x.a FROM z")
self.validate_identity("SELECT * FROM t WITH (TABLOCK, INDEX(myindex))")
self.validate_identity("SELECT * FROM t WITH (NOWAIT)")
self.validate_identity("SELECT CASE WHEN a > 1 THEN b END")
Expand Down

0 comments on commit d68f844

Please sign in to comment.