Skip to content

Commit

Permalink
Feat(mysql): improve support for unsigned int types (#2172)
Browse files Browse the repository at this point in the history
* Feat(mysql): improve support for unsigned int types

* Refactor

* Add test for transpiling to DuckDB
  • Loading branch information
georgesittas authored Sep 6, 2023
1 parent 7e407ba commit 63ac621
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 1 deletion.
21 changes: 20 additions & 1 deletion sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,19 @@ class Generator(generator.Generator):
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
}

TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
UNSIGNED_TYPE_MAPPING = {
exp.DataType.Type.UBIGINT: "BIGINT",
exp.DataType.Type.UINT: "INT",
exp.DataType.Type.UMEDIUMINT: "MEDIUMINT",
exp.DataType.Type.USMALLINT: "SMALLINT",
exp.DataType.Type.UTINYINT: "TINYINT",
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**UNSIGNED_TYPE_MAPPING,
}

TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
Expand All @@ -580,6 +592,13 @@ class Generator(generator.Generator):
exp.DataType.Type.VARCHAR: "CHAR",
}

def datatype_sql(self, expression: exp.DataType) -> str:
# https://dev.mysql.com/doc/refman/8.0/en/numeric-type-syntax.html
result = super().datatype_sql(expression)
if expression.this in self.UNSIGNED_TYPE_MAPPING:
result = f"{result} UNSIGNED"
return result

def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
# MySQL requires simple literal values for its LIMIT clause.
expression = simplify_literal(expression.copy())
Expand Down
1 change: 1 addition & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3543,6 +3543,7 @@ class Type(AutoName):
UINT = auto()
UINT128 = auto()
UINT256 = auto()
UMEDIUMINT = auto()
UNIQUEIDENTIFIER = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
USERDEFINED = "USER-DEFINED"
Expand Down
16 changes: 16 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class Parser(metaclass=_Parser):
TokenType.INT256,
TokenType.UINT256,
TokenType.MEDIUMINT,
TokenType.UMEDIUMINT,
TokenType.FIXEDSTRING,
TokenType.FLOAT,
TokenType.DOUBLE,
Expand Down Expand Up @@ -206,6 +207,14 @@ class Parser(metaclass=_Parser):
*NESTED_TYPE_TOKENS,
}

SIGNED_TO_UNSIGNED_TYPE_TOKEN = {
TokenType.BIGINT: TokenType.UBIGINT,
TokenType.INT: TokenType.UINT,
TokenType.MEDIUMINT: TokenType.UMEDIUMINT,
TokenType.SMALLINT: TokenType.USMALLINT,
TokenType.TINYINT: TokenType.UTINYINT,
}

SUBQUERY_PREDICATES = {
TokenType.ANY: exp.Any,
TokenType.ALL: exp.All,
Expand Down Expand Up @@ -3359,6 +3368,13 @@ def _parse_types(
self._retreat(index2)

if not this:
if self._match_text_seq("UNSIGNED"):
unsigned_type_token = self.SIGNED_TO_UNSIGNED_TYPE_TOKEN.get(type_token)
if not unsigned_type_token:
self.raise_error(f"Cannot convert {type_token.value} to unsigned.")

type_token = unsigned_type_token or type_token

this = exp.DataType(
this=exp.DataType.Type[type_token.value],
expressions=expressions,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class TokenType(AutoName):
SMALLINT = auto()
USMALLINT = auto()
MEDIUMINT = auto()
UMEDIUMINT = auto()
INT = auto()
UINT = auto()
BIGINT = auto()
Expand Down
13 changes: 13 additions & 0 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ class TestMySQL(Validator):
dialect = "mysql"

def test_ddl(self):
int_types = {"BIGINT", "INT", "MEDIUMINT", "SMALLINT", "TINYINT"}

for t in int_types:
self.validate_identity(f"CREATE TABLE t (id {t} UNSIGNED)")
self.validate_identity(f"CREATE TABLE t (id {t}(10) UNSIGNED)")

self.validate_all(
"CREATE TABLE t (id INT UNSIGNED)",
write={
"duckdb": "CREATE TABLE t (id UINTEGER)",
},
)

self.validate_identity("CREATE TABLE foo (id BIGINT)")
self.validate_identity("CREATE TABLE 00f (1d BIGINT)")
self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10")
Expand Down

0 comments on commit 63ac621

Please sign in to comment.