Skip to content

Commit

Permalink
chore(dao): Add generic type for better type checking (#24465)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored Jun 21, 2023
1 parent d5f88c1 commit 92e2ee9
Show file tree
Hide file tree
Showing 27 changed files with 68 additions and 64 deletions.
2 changes: 2 additions & 0 deletions superset/annotation_layers/annotations/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(self, model_id: int):

def run(self) -> Model:
self.validate()
assert self._model

try:
annotation = AnnotationDAO.delete(self._model)
except DAODeleteFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/annotation_layers/annotations/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def __init__(self, model_id: int, data: dict[str, Any]):

def run(self) -> Model:
self.validate()
assert self._model

try:
annotation = AnnotationDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/annotation_layers/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(self, model_id: int):

def run(self) -> Model:
self.validate()
assert self._model

try:
annotation_layer = AnnotationLayerDAO.delete(self._model)
except DAODeleteFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/annotation_layers/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __init__(self, model_id: int, data: dict[str, Any]):

def run(self) -> Model:
self.validate()
assert self._model

try:
annotation_layer = AnnotationLayerDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/charts/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self, model_id: int, data: dict[str, Any]):

def run(self) -> Model:
self.validate()
assert self._model

try:
if self._properties.get("query_context_generation") is None:
self._properties["last_saved_at"] = datetime.now()
Expand Down
2 changes: 1 addition & 1 deletion superset/commands/export/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


class ExportModelsCommand(BaseCommand):
dao: type[BaseDAO] = BaseDAO
dao: type[BaseDAO[Model]] = BaseDAO
not_found: type[CommandException] = CommandException

def __init__(self, model_ids: list[int], export_related: bool = True):
Expand Down
8 changes: 2 additions & 6 deletions superset/daos/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
logger = logging.getLogger(__name__)


class AnnotationDAO(BaseDAO):
model_cls = Annotation

class AnnotationDAO(BaseDAO[Annotation]):
@staticmethod
def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
Expand Down Expand Up @@ -64,9 +62,7 @@ def validate_update_uniqueness(
return not db.session.query(query.exists()).scalar()


class AnnotationLayerDAO(BaseDAO):
model_cls = AnnotationLayer

class AnnotationLayerDAO(BaseDAO[AnnotationLayer]):
@staticmethod
def bulk_delete(
models: Optional[list[AnnotationLayer]], commit: bool = True
Expand Down
35 changes: 17 additions & 18 deletions superset/daos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=isinstance-second-argument-not-valid-type
from typing import Any, Optional, Union
from typing import Any, Generic, get_args, Optional, TypeVar, Union

from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
Expand All @@ -31,8 +30,10 @@
)
from superset.extensions import db

T = TypeVar("T", bound=Model) # pylint: disable=invalid-name

class BaseDAO:

class BaseDAO(Generic[T]):
"""
Base DAO, implement base CRUD sqlalchemy operations
"""
Expand All @@ -48,6 +49,11 @@ class BaseDAO:
"""
id_column_name = "id"

def __init_subclass__(cls) -> None: # pylint: disable=arguments-differ
cls.model_cls = get_args(
cls.__orig_bases__[0] # type: ignore # pylint: disable=no-member
)[0]

@classmethod
def find_by_id(
cls,
Expand Down Expand Up @@ -78,7 +84,7 @@ def find_by_ids(
model_ids: Union[list[str], list[int]],
session: Session = None,
skip_base_filter: bool = False,
) -> list[Model]:
) -> list[T]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
Expand All @@ -95,7 +101,7 @@ def find_by_ids(
return query.all()

@classmethod
def find_all(cls) -> list[Model]:
def find_all(cls) -> list[T]:
"""
Get all that fit the `base_filter`
"""
Expand All @@ -108,7 +114,7 @@ def find_all(cls) -> list[Model]:
return query.all()

@classmethod
def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]:
def find_one_or_none(cls, **filter_by: Any) -> Optional[T]:
"""
Get the first that fit the `base_filter`
"""
Expand All @@ -121,7 +127,7 @@ def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]:
return query.filter_by(**filter_by).one_or_none()

@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
def create(cls, properties: dict[str, Any], commit: bool = True) -> T:
"""
Generic for creating models
:raises: DAOCreateFailedError
Expand All @@ -141,30 +147,23 @@ def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
return model

@classmethod
def save(cls, instance_model: Model, commit: bool = True) -> Model:
def save(cls, instance_model: T, commit: bool = True) -> None:
"""
Generic for saving models
:raises: DAOCreateFailedError
"""
if cls.model_cls is None:
raise DAOConfigError()
if not isinstance(instance_model, cls.model_cls):
raise DAOCreateFailedError(
"the instance model is not a type of the model class"
)
try:
db.session.add(instance_model)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
return instance_model

@classmethod
def update(
cls, model: Model, properties: dict[str, Any], commit: bool = True
) -> Model:
def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T:
"""
Generic update a model
:raises: DAOCreateFailedError
Expand All @@ -181,7 +180,7 @@ def update(
return model

@classmethod
def delete(cls, model: Model, commit: bool = True) -> Model:
def delete(cls, model: T, commit: bool = True) -> T:
"""
Generic delete a model
:raises: DAODeleteFailedError
Expand All @@ -196,7 +195,7 @@ def delete(cls, model: Model, commit: bool = True) -> Model:
return model

@classmethod
def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
def bulk_delete(cls, models: list[T], commit: bool = True) -> None:
try:
for model in models:
cls.delete(model, False)
Expand Down
3 changes: 1 addition & 2 deletions superset/daos/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@
logger = logging.getLogger(__name__)


class ChartDAO(BaseDAO):
model_cls = Slice
class ChartDAO(BaseDAO[Slice]):
base_filter = ChartFilter

@staticmethod
Expand Down
4 changes: 1 addition & 3 deletions superset/daos/css.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
logger = logging.getLogger(__name__)


class CssTemplateDAO(BaseDAO):
model_cls = CssTemplate

class CssTemplateDAO(BaseDAO[CssTemplate]):
@staticmethod
def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
Expand Down
10 changes: 3 additions & 7 deletions superset/daos/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@
logger = logging.getLogger(__name__)


class DashboardDAO(BaseDAO):
model_cls = Dashboard
class DashboardDAO(BaseDAO[Dashboard]):
base_filter = DashboardAccessFilter

@classmethod
Expand Down Expand Up @@ -379,8 +378,7 @@ def remove_favorite(dashboard: Dashboard) -> None:
db.session.commit()


class EmbeddedDashboardDAO(BaseDAO):
model_cls = EmbeddedDashboard
class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
# There isn't really a regular scenario where we would rather get Embedded by id
id_column_name = "uuid"

Expand All @@ -407,9 +405,7 @@ def create(cls, properties: dict[str, Any], commit: bool = True) -> Any:
raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.")


class FilterSetDAO(BaseDAO):
model_cls = FilterSet

class FilterSetDAO(BaseDAO[FilterSet]):
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
if cls.model_cls is None:
Expand Down
7 changes: 2 additions & 5 deletions superset/daos/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
logger = logging.getLogger(__name__)


class DatabaseDAO(BaseDAO):
model_cls = Database
class DatabaseDAO(BaseDAO[Database]):
base_filter = DatabaseFilter

@classmethod
Expand Down Expand Up @@ -138,9 +137,7 @@ def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]:
return ssh_tunnel


class SSHTunnelDAO(BaseDAO):
model_cls = SSHTunnel

class SSHTunnelDAO(BaseDAO[SSHTunnel]):
@classmethod
def update(
cls,
Expand Down
13 changes: 6 additions & 7 deletions superset/daos/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
logger = logging.getLogger(__name__)


class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
model_cls = SqlaTable
class DatasetDAO(BaseDAO[SqlaTable]): # pylint: disable=too-many-public-methods
base_filter = DatasourceFilter

@staticmethod
Expand Down Expand Up @@ -151,7 +150,7 @@ def update(
model: SqlaTable,
properties: dict[str, Any],
commit: bool = True,
) -> Optional[SqlaTable]:
) -> SqlaTable:
"""
Updates a Dataset model on the metadata DB
"""
Expand Down Expand Up @@ -397,9 +396,9 @@ def get_table_by_name(database_id: int, table_name: str) -> Optional[SqlaTable]:
)


class DatasetColumnDAO(BaseDAO):
model_cls = TableColumn
class DatasetColumnDAO(BaseDAO[TableColumn]):
pass


class DatasetMetricDAO(BaseDAO):
model_cls = SqlMetric
class DatasetMetricDAO(BaseDAO[SqlMetric]):
pass
2 changes: 1 addition & 1 deletion superset/daos/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]


class DatasourceDAO(BaseDAO):
class DatasourceDAO(BaseDAO[Datasource]):
sources: dict[Union[DatasourceType, str], type[Datasource]] = {
DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query,
Expand Down
4 changes: 1 addition & 3 deletions superset/daos/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
from superset.utils.dates import datetime_to_epoch


class LogDAO(BaseDAO):
model_cls = Log

class LogDAO(BaseDAO[Log]):
@staticmethod
def get_recent_activity(
actions: list[str],
Expand Down
6 changes: 2 additions & 4 deletions superset/daos/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@
logger = logging.getLogger(__name__)


class QueryDAO(BaseDAO):
model_cls = Query
class QueryDAO(BaseDAO[Query]):
base_filter = QueryFilter

@staticmethod
Expand Down Expand Up @@ -104,8 +103,7 @@ def stop_query(client_id: str) -> None:
db.session.commit()


class SavedQueryDAO(BaseDAO):
model_cls = SavedQuery
class SavedQueryDAO(BaseDAO[SavedQuery]):
base_filter = SavedQueryFilter

@staticmethod
Expand Down
3 changes: 1 addition & 2 deletions superset/daos/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER = "Notification sent with error"


class ReportScheduleDAO(BaseDAO):
model_cls = ReportSchedule
class ReportScheduleDAO(BaseDAO[ReportSchedule]):
base_filter = ReportScheduleFilter

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions superset/daos/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
from superset.daos.base import BaseDAO


class RLSDAO(BaseDAO):
model_cls = RowLevelSecurityFilter
class RLSDAO(BaseDAO[RowLevelSecurityFilter]):
pass
3 changes: 1 addition & 2 deletions superset/daos/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
logger = logging.getLogger(__name__)


class TagDAO(BaseDAO):
model_cls = Tag
class TagDAO(BaseDAO[Tag]):
# base_filter = TagAccessFilter

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions superset/dashboards/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def __init__(self, model_id: int):

def run(self) -> Model:
self.validate()
assert self._model

try:
dashboard = DashboardDAO.delete(self._model)
except DAODeleteFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/dashboards/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(self, model_id: int, data: dict[str, Any]):

def run(self) -> Model:
self.validate()
assert self._model

try:
dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
if self._properties.get("json_metadata"):
Expand Down
Loading

0 comments on commit 92e2ee9

Please sign in to comment.