Skip to content

Commit

Permalink
Fix(executor): add table normalization, fix python type mapping (#2015)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Aug 9, 2023
1 parent f4e5858 commit bc46c3d
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 28 deletions.
10 changes: 8 additions & 2 deletions sqlglot/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from sqlglot.schema import Schema


# Maps Python runtime type names (as produced by `type(value).__name__`) to the
# SQLGlot type names used when inferring a schema from sample table rows in
# `execute` below; names not listed here fall through unchanged.
PYTHON_TYPE_TO_SQLGLOT = {
    "dict": "MAP",
}


def execute(
sql: str | Expression,
schema: t.Optional[t.Dict | Schema] = None,
Expand All @@ -50,7 +55,7 @@ def execute(
Returns:
Simple columnar data structure.
"""
tables_ = ensure_tables(tables)
tables_ = ensure_tables(tables, dialect=read)

if not schema:
schema = {}
Expand All @@ -61,7 +66,8 @@ def execute(
assert table is not None

for column in table.columns:
nested_set(schema, [*keys, column], type(table[0][column]).__name__)
py_type = type(table[0][column]).__name__
nested_set(schema, [*keys, column], PYTHON_TYPE_TO_SQLGLOT.get(py_type) or py_type)

schema = ensure_schema(schema, dialect=read)

Expand Down
34 changes: 23 additions & 11 deletions sqlglot/executor/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import typing as t

from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import dict_depth
from sqlglot.schema import AbstractMappingSchema
from sqlglot.schema import AbstractMappingSchema, normalize_name


class Table:
Expand Down Expand Up @@ -108,26 +109,37 @@ class Tables(AbstractMappingSchema[Table]):
pass


def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
return Tables(_ensure_tables(d))
def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables:
    """Convert a (possibly nested) dict of table data into a `Tables` mapping.

    Args:
        d: mapping of table names (optionally nested under db/catalog keys) to
            either `Table` instances or lists of row dicts; may be None/empty.
        dialect: dialect used to normalize table and column identifiers.
    """
    return Tables(_ensure_tables(d, dialect=dialect))


def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict:
    """Recursively normalize a nested table mapping into `Table` objects.

    Args:
        d: mapping of names to either nested mappings (db/catalog levels),
            `Table` instances, or lists of row dicts; may be None/empty.
        dialect: dialect used to normalize table and column identifiers.

    Returns:
        A dict with the same nesting, with all names normalized and every
        leaf converted to a `Table`.
    """
    if not d:
        return {}

    depth = dict_depth(d)

    if depth > 1:
        # Keys at this level are qualifiers (catalog/db); normalize them as
        # table-path parts and recurse into the nested mapping.
        return {
            normalize_name(k, dialect=dialect, is_table=True): _ensure_tables(v, dialect=dialect)
            for k, v in d.items()
        }

    result = {}
    for table_name, table in d.items():
        table_name = normalize_name(table_name, dialect=dialect)

        if isinstance(table, Table):
            result[table_name] = table
        else:
            # `table` is a list of row dicts: normalize each row's column
            # names, then build the columnar `Table` from the first row's keys.
            table = [
                {
                    normalize_name(column_name, dialect=dialect): value
                    for column_name, value in row.items()
                }
                for row in table
            ]
            column_names = tuple(table[0]) if table else ()
            rows = [tuple(row[name] for name in column_names) for row in table]
            result[table_name] = Table(columns=column_names, rows=rows)

    return result
41 changes: 26 additions & 15 deletions sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,21 +372,12 @@ def _normalize_name(
is_table: bool = False,
normalize: t.Optional[bool] = None,
) -> str:
dialect = dialect or self.dialect
normalize = self.normalize if normalize is None else normalize

try:
identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
except ParseError:
return name if isinstance(name, str) else name.name

name = identifier.name
if not normalize:
return name

# This can be useful for normalize_identifier
identifier.meta["is_table"] = is_table
return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
return normalize_name(
name,
dialect=dialect or self.dialect,
is_table=is_table,
normalize=self.normalize if normalize is None else normalize,
)

def depth(self) -> int:
if not self.empty and not self._depth:
Expand Down Expand Up @@ -418,6 +409,26 @@ def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.Da
return self._type_mapping_cache[schema_type]


def normalize_name(
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool = True,
) -> str:
    """Parse `name` as an identifier and normalize it per `dialect` rules.

    Args:
        name: the identifier to normalize, as a string or `exp.Identifier`.
        dialect: dialect whose identifier-normalization rules to apply.
        is_table: whether the identifier names a table (some dialects
            normalize table names differently).
        normalize: if False, only parse and return the identifier's name
            without applying dialect normalization.

    Returns:
        The normalized identifier name; if `name` cannot be parsed, it is
        returned as-is (or its `.name` for an `exp.Identifier`).
    """
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        # Unparseable input: fall back to the raw value rather than failing.
        return name if isinstance(name, str) else name.name

    parsed_name = identifier.name
    if not normalize:
        return parsed_name

    # This can be useful for normalize_identifier
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
if isinstance(schema, Schema):
return schema
Expand Down
21 changes: 21 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,24 @@ def test_group_by(self):
result = execute(sql, tables=tables)
self.assertEqual(result.columns, columns)
self.assertEqual(result.rows, expected)

def test_dict_values(self):
tables = {
"foo": [{"raw": {"name": "Hello, World"}}],
}
result = execute("SELECT raw:name AS name FROM foo", read="snowflake", tables=tables)

self.assertEqual(result.columns, ("NAME",))
self.assertEqual(result.rows, [("Hello, World",)])

tables = {
'"ITEM"': [
{"id": 1, "attributes": {"flavor": "cherry", "taste": "sweet"}},
{"id": 2, "attributes": {"flavor": "lime", "taste": "sour"}},
{"id": 3, "attributes": {"flavor": "apple", "taste": None}},
]
}
result = execute("SELECT i.attributes.flavor FROM `ITEM` i", read="bigquery", tables=tables)

self.assertEqual(result.columns, ("flavor",))
self.assertEqual(result.rows, [("cherry",), ("lime",), ("apple",)])

0 comments on commit bc46c3d

Please sign in to comment.