Skip to content

Commit

Permalink
Feat: sqlite primary key transforms closes #1557
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed May 5, 2023
1 parent 6124d0c commit 1fa8ae9
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
35 changes: 35 additions & 0 deletions sqlglot/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,40 @@ def _date_add_sql(self, expression):
return self.func("DATE", expression.this, modifier)


def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this

if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema):
defs = {}
primary_key = None

for e in schema.expressions:
if isinstance(e, exp.ColumnDef):
defs[e.name] = e
elif isinstance(e, exp.PrimaryKey):
primary_key = e

if primary_key and len(primary_key.expressions) == 1:
column = defs[primary_key.expressions[0].name]
column.append(
"constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint())
)
schema.expressions.remove(primary_key)
else:
for column in defs.values():
auto_increment = None
for constraint in column.constraints.copy():
if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
break
if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):
auto_increment = constraint
if auto_increment:
column.constraints.remove(auto_increment)

return expression


class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
Expand Down Expand Up @@ -66,6 +100,7 @@ class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
Expand Down
8 changes: 8 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,10 @@ class ColumnDef(Expression):
"position": False,
}

@property
def constraints(self) -> t.List[ColumnConstraint]:
return self.args.get("constraints") or []


class AlterColumn(Expression):
arg_types = {
Expand Down Expand Up @@ -1110,6 +1114,10 @@ class Comment(Expression):
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}

@property
def kind(self) -> ColumnConstraintKind:
return self.args["kind"]


class ColumnConstraintKind(Expression):
pass
Expand Down
13 changes: 13 additions & 0 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ def test_ddl(self):
"INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1"
)

self.validate_all(
"CREATE TABLE x (id int not null auto_increment, primary key (id))",
write={
"sqlite": "CREATE TABLE x (id INTEGER NOT NULL AUTOINCREMENT PRIMARY KEY)",
},
)
self.validate_all(
"CREATE TABLE x (id int not null auto_increment)",
write={
"sqlite": "CREATE TABLE x (id INTEGER NOT NULL)",
},
)

def test_identity(self):
self.validate_identity("SELECT CURRENT_TIMESTAMP(6)")
self.validate_identity("x ->> '$.name'")
Expand Down

0 comments on commit 1fa8ae9

Please sign in to comment.