Skip to content

Commit

Permalink
Fix(schema): ensure the correct dialect is used in schema methods (#1710
Browse files Browse the repository at this point in the history
)
  • Loading branch information
georgesittas authored May 31, 2023
1 parent 12d3cca commit dd5457c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 9 additions & 4 deletions sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def add_table(
dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
normalized_table = self._normalize_table(
exp.maybe_parse(table, into=exp.Table, dialect=dialect), dialect=dialect
self._ensure_table(table, dialect=dialect), dialect=dialect
)
normalized_column_mapping = {
self._normalize_name(key, dialect=dialect): value
Expand All @@ -250,7 +250,7 @@ def column_names(
dialect: DialectType = None,
) -> t.List[str]:
normalized_table = self._normalize_table(
exp.maybe_parse(table, into=exp.Table, dialect=dialect), dialect=dialect
self._ensure_table(table, dialect=dialect), dialect=dialect
)

schema = self.find(normalized_table)
Expand All @@ -270,7 +270,7 @@ def get_column_type(
dialect: DialectType = None,
) -> exp.DataType:
normalized_table = self._normalize_table(
exp.maybe_parse(table, into=exp.Table, dialect=dialect), dialect=dialect
self._ensure_table(table, dialect=dialect), dialect=dialect
)
normalized_column_name = self._normalize_name(
column if isinstance(column, str) else column.this, dialect=dialect
Expand Down Expand Up @@ -322,7 +322,9 @@ def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp
for arg in TABLE_ARGS:
value = normalized_table.args.get(arg)
if isinstance(value, (str, exp.Identifier)):
normalized_table.set(arg, self._normalize_name(value, dialect=dialect))
normalized_table.set(
arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
)

return normalized_table

Expand All @@ -345,6 +347,9 @@ def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1

def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect)

def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
Expand Down
4 changes: 4 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,7 @@ def test_schema_normalization(self):
# Check that switching off the normalization logic works as expected
schema = MappingSchema(schema={"x": {"foo": "int"}}, normalize=False, dialect="snowflake")
self.assertEqual(schema.column_names(exp.Table(this="x")), ["foo"])

# Check that the correct dialect is used when calling schema methods
schema = MappingSchema(schema={"[Fo]": {"x": "int"}}, dialect="tsql")
self.assertEqual(schema.column_names("[Fo]"), schema.column_names("`Fo`", dialect="spark"))

0 comments on commit dd5457c

Please sign in to comment.