Skip to content

Commit

Permalink
Fix: create index with order closes #1692
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed May 26, 2023
1 parent 6cce5fc commit fbf5f47
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 63 deletions.
9 changes: 1 addition & 8 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
return f"TO_DATE({this})"


def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
return f"{this} ON TABLE {table} {columns}"


class Hive(Dialect):
alias_post_tablesample = True

Expand Down Expand Up @@ -289,6 +282,7 @@ class Generator(generator.Generator):
TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
INDEX_ON = "ON TABLE"

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
Expand Down Expand Up @@ -325,7 +319,6 @@ class Generator(generator.Generator):
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
Expand Down
44 changes: 14 additions & 30 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ class Generator:
# The separator for grouping sets and rollups
GROUPINGS_SEP = ","

# The string used for creating index on a table
INDEX_ON = "ON"

TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
Expand Down Expand Up @@ -683,33 +686,9 @@ def create_sql(self, expression: exp.Create) -> str:
prefix=" ",
)

indexes = expression.args.get("indexes")
if indexes:
indexes_sql: t.List[str] = []
for index in indexes:
ind_unique = " UNIQUE" if index.args.get("unique") else ""
ind_primary = " PRIMARY" if index.args.get("primary") else ""
ind_amp = " AMP" if index.args.get("amp") else ""
ind_name = f" {index.name}" if index.name else ""
ind_columns = (
f' ({self.expressions(index, key="columns", flat=True)})'
if index.args.get("columns")
else ""
)
ind_sql = f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"

if indexes_sql:
indexes_sql.append(ind_sql)
else:
indexes_sql.append(
f"{ind_sql}{postindex_props_sql}"
if index.args.get("primary")
else f"{postindex_props_sql}{ind_sql}"
)

index_sql = "".join(indexes_sql)
else:
index_sql = postindex_props_sql
indexes = self.expressions(expression, "indexes", indent=False, sep=" ")
indexes = f" {indexes}" if indexes else ""
index_sql = indexes + postindex_props_sql

replace = " OR REPLACE" if expression.args.get("replace") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
Expand Down Expand Up @@ -891,10 +870,15 @@ def hint_sql(self, expression: exp.Hint) -> str:
return ""

def index_sql(self, expression: exp.Index) -> str:
this = self.sql(expression, "this")
unique = "UNIQUE " if expression.args.get("unique") else ""
primary = "PRIMARY " if expression.args.get("primary") else ""
amp = "AMP " if expression.args.get("amp") else ""
name = f"{expression.name} " if expression.name else ""
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
return f"{this} ON {table} {columns}"
table = f"{self.INDEX_ON} {table} " if table else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
return f"{unique}{primary}{amp}{index}{name}{table}({columns})"

def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
Expand Down
47 changes: 26 additions & 21 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ def _parse_create(self) -> t.Optional[exp.Expression]:
if return_:
expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
table_parts = self._parse_table_parts(schema=True)

Expand Down Expand Up @@ -1183,7 +1183,7 @@ def _parse_create(self) -> t.Optional[exp.Expression]:
if create_token.token_type == TokenType.TABLE:
indexes = []
while True:
index = self._parse_create_table_index()
index = self._parse_index()

# exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX
temp_properties = self._parse_properties()
Expand Down Expand Up @@ -2193,31 +2193,36 @@ def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expressio

return self.expression(exp.Join, **kwargs) # type: ignore

def _parse_index(self) -> exp.Expression:
index = self._parse_id_var()
self._match(TokenType.ON)
self._match(TokenType.TABLE) # hive
def _parse_index(
self,
index: t.Optional[exp.Expression] = None,
) -> t.Optional[exp.Expression]:
if index:
unique = None
primary = None
amp = None

return self.expression(
exp.Index,
this=index,
table=self.expression(exp.Table, this=self._parse_id_var()),
columns=self._parse_expression(),
)
self._match(TokenType.ON)
self._match(TokenType.TABLE) # hive
table = self._parse_table_parts(schema=True)
else:
unique = self._match(TokenType.UNIQUE)
primary = self._match_text_seq("PRIMARY")
amp = self._match_text_seq("AMP")
if not self._match(TokenType.INDEX):
return None
index = self._parse_id_var()
table = None

def _parse_create_table_index(self) -> t.Optional[exp.Expression]:
unique = self._match(TokenType.UNIQUE)
primary = self._match_text_seq("PRIMARY")
amp = self._match_text_seq("AMP")
if not self._match(TokenType.INDEX):
return None
index = self._parse_id_var()
columns = None
if self._match(TokenType.L_PAREN, advance=False):
columns = self._parse_wrapped_csv(self._parse_column)
columns = self._parse_wrapped_csv(self._parse_ordered)
else:
columns = None

return self.expression(
exp.Index,
this=index,
table=table,
columns=columns,
unique=unique,
primary=primary,
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ def test_limit(self):
},
write={
"hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)",
"postgres": "CREATE INDEX my_idx ON tbl (a, b)",
"postgres": "CREATE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)",
"sqlite": "CREATE INDEX my_idx ON tbl (a, b)",
},
)
Expand All @@ -1366,7 +1366,7 @@ def test_limit(self):
},
write={
"hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)",
"postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
"postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)",
"sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
},
)
Expand Down
9 changes: 7 additions & 2 deletions tests/dialects/test_teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ def test_update(self):
)

def test_create(self):
self.validate_identity("CREATE TABLE x (y INT) PRIMARY INDEX (y) PARTITION BY y INDEX (y)")
self.validate_all(
"CREATE TABLE x (y INT) PRIMARY INDEX (y) PARTITION BY y INDEX (y)",
write={
"teradata": "CREATE TABLE x (y INT) PRIMARY INDEX (y) INDEX (y) PARTITION BY y",
},
)
self.validate_identity(
"CREATE MULTISET VOLATILE TABLE my_table (id INT) PRIMARY INDEX (id) ON COMMIT PRESERVE ROWS"
)
Expand All @@ -37,7 +42,7 @@ def test_create(self):
"CREATE TABLE a (b INT) PARTITION BY RANGE_N(b BETWEEN 0, 1 AND 2 EACH 1)"
)
self.validate_identity(
"CREATE TABLE a (b INT) PARTITION BY RANGE_N(b BETWEEN *, 1 AND * EACH b) INDEX (a)"
"CREATE TABLE a (b INT) INDEX (a) PARTITION BY RANGE_N(b BETWEEN *, 1 AND * EACH b)"
)

self.validate_all(
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ CREATE FUNCTION a.b(x INT) RETURNS INT AS RETURN x + 1
CREATE FUNCTION a.b.c()
CREATE INDEX abc ON t (a)
CREATE INDEX abc ON t (a, b, b)
CREATE INDEX abc ON t (a NULLS LAST)
CREATE UNIQUE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b)
CREATE SCHEMA x
Expand Down

0 comments on commit fbf5f47

Please sign in to comment.