Skip to content

Commit

Permalink
fix: TemporalWrapperType string representation (apache#16614)
Browse files Browse the repository at this point in the history
* fix: TemporalWrapperType string representation

* fix tests
  • Loading branch information
villebro authored Sep 7, 2021
1 parent f42d0d4 commit 4083fe4
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
2 changes: 1 addition & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ def get_column_spec(
# using datetimes
if generic_type == GenericDataType.TEMPORAL:
column_type = literal_dttm_type_factory(
type(column_type), cls, native_type or ""
column_type, cls, native_type or ""
)
is_dttm = generic_type == GenericDataType.TEMPORAL
return ColumnSpec(
Expand Down
10 changes: 4 additions & 6 deletions superset/models/sql_types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@


def literal_dttm_type_factory(
sqla_type: Type[types.TypeEngine],
db_engine_spec: Type["BaseEngineSpec"],
col_type: str,
) -> Type[types.TypeEngine]:
sqla_type: types.TypeEngine, db_engine_spec: Type["BaseEngineSpec"], col_type: str,
) -> types.TypeEngine:
"""
Create a custom SQLAlchemy type that supports datetime literal binds.
Expand All @@ -39,7 +37,7 @@ def literal_dttm_type_factory(
:return: SQLAlchemy type that supports using datetima as literal bind
"""
# pylint: disable=too-few-public-methods
class TemporalWrapperType(sqla_type): # type: ignore
class TemporalWrapperType(type(sqla_type)): # type: ignore
# pylint: disable=unused-argument
def literal_processor(self, dialect: Dialect) -> Callable[[Any], Any]:
def process(value: Any) -> Any:
Expand All @@ -58,4 +56,4 @@ def process(value: Any) -> Any:

return process

return TemporalWrapperType
return TemporalWrapperType()
6 changes: 4 additions & 2 deletions tests/integration_tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,11 +561,13 @@ def test_get_sqla_column_type(self):
self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC)

column_spec = PrestoEngineSpec.get_column_spec("time")
assert issubclass(column_spec.sqla_type, types.Time)
assert isinstance(column_spec.sqla_type, types.Time)
assert type(column_spec.sqla_type).__name__ == "TemporalWrapperType"
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)

column_spec = PrestoEngineSpec.get_column_spec("timestamp")
assert issubclass(column_spec.sqla_type, types.TIMESTAMP)
assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
assert type(column_spec.sqla_type).__name__ == "TemporalWrapperType"
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)

sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
Expand Down
10 changes: 10 additions & 0 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@

import pytest
from sqlalchemy.engine.url import make_url
from sqlalchemy.types import DateTime

import tests.integration_tests.test_app
from superset import app, db as metadata_db
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.models.core import Database
from superset.models.slice import Slice
from superset.models.sql_types.base import literal_dttm_type_factory
from superset.utils.core import get_example_database, QueryStatus

from .base_tests import SupersetTestCase
Expand Down Expand Up @@ -516,3 +519,10 @@ def test_data_for_slices(self):
assert set(data_for_slices["verbose_map"].keys()) == set(
["__timestamp", "sum__num", "gender",]
)


def test_literal_dttm_type_factory():
orig_type = DateTime()
new_type = literal_dttm_type_factory(orig_type, PostgresEngineSpec, "TIMESTAMP")
assert type(new_type).__name__ == "TemporalWrapperType"
assert str(new_type) == str(orig_type)

0 comments on commit 4083fe4

Please sign in to comment.