Skip to content

Commit

Permalink
Feat(databricks): add support for REPLACE WHERE in INSERT statement (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Jun 23, 2023
1 parent 4255faf commit 088e745
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 11 deletions.
1 change: 1 addition & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,7 @@ class Insert(Expression):
"exists": False,
"partition": False,
"alternative": False,
"where": False,
}

def with_(
Expand Down
21 changes: 11 additions & 10 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ def national_sql(self, expression: exp.National, prefix: str = "N") -> str:
return f"{prefix}{string}"

def partition_sql(self, expression: exp.Partition) -> str:
return f"PARTITION({self.expressions(expression)})"
return f"PARTITION({self.expressions(expression, flat=True)})"

def properties_sql(self, expression: exp.Properties) -> str:
root_properties = []
Expand Down Expand Up @@ -1102,23 +1102,24 @@ def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")

if isinstance(expression.this, exp.Directory):
this = "OVERWRITE " if overwrite else "INTO "
this = " OVERWRITE" if overwrite else " INTO"
else:
this = "OVERWRITE TABLE " if overwrite else "INTO "
this = " OVERWRITE TABLE" if overwrite else " INTO"

alternative = expression.args.get("alternative")
alternative = f" OR {alternative} " if alternative else " "
this = f"{this}{self.sql(expression, 'this')}"
alternative = f" OR {alternative}" if alternative else ""
this = f"{this} {self.sql(expression, 'this')}"

exists = " IF EXISTS " if expression.args.get("exists") else " "
exists = " IF EXISTS" if expression.args.get("exists") else ""
partition_sql = (
self.sql(expression, "partition") if expression.args.get("partition") else ""
f" {self.sql(expression, 'partition')}" if expression.args.get("partition") else ""
)
expression_sql = self.sql(expression, "expression")
where = self.sql(expression, "where")
where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
conflict = self.sql(expression, "conflict")
returning = self.sql(expression, "returning")
sep = self.sep() if partition_sql else ""
sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{conflict}{returning}"
sql = f"INSERT{alternative}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}"
return self.prepend_ctes(expression, sql)

def intersect_sql(self, expression: exp.Intersect) -> str:
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,8 @@ def _parse_insert(self) -> exp.Insert:
this=this,
exists=self._parse_exists(),
partition=self._parse_partition(),
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)
and self._parse_conjunction(),
expression=self._parse_ddl_select(),
conflict=self._parse_on_conflict(),
returning=self._parse_returning(),
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ class TestDatabricks(Validator):
dialect = "databricks"

def test_databricks(self):
self.validate_identity("INSERT INTO a REPLACE WHERE cond VALUES (1), (2)")
self.validate_identity("SELECT c1 : price")
self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1")
self.validate_identity("CREATE FUNCTION a AS b")
Expand Down
15 changes: 14 additions & 1 deletion tests/fixtures/pretty.sql
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,23 @@ FROM (
) AS x;

INSERT OVERWRITE TABLE x VALUES (1, 2.0, '3.0'), (4, 5.0, '6.0');
INSERT OVERWRITE TABLE x VALUES
INSERT OVERWRITE TABLE x
VALUES
(1, 2.0, '3.0'),
(4, 5.0, '6.0');

INSERT INTO TABLE foo REPLACE WHERE cond SELECT * FROM bar;
INSERT INTO foo
REPLACE WHERE cond
SELECT
*
FROM bar;

INSERT OVERWRITE TABLE zipcodes PARTITION(state = '0') VALUES (896, 'US', 'TAMPA', 33607);
INSERT OVERWRITE TABLE zipcodes PARTITION(state = '0')
VALUES
(896, 'US', 'TAMPA', 33607);

WITH regional_sales AS (
SELECT region, SUM(amount) AS total_sales
FROM orders
Expand Down

0 comments on commit 088e745

Please sign in to comment.