diff --git a/superset/annotation_layers/annotations/commands/delete.py b/superset/annotation_layers/annotations/commands/delete.py index b86ae997a4b4e..2af01f57f4708 100644 --- a/superset/annotation_layers/annotations/commands/delete.py +++ b/superset/annotation_layers/annotations/commands/delete.py @@ -38,6 +38,8 @@ def __init__(self, model_id: int): def run(self) -> Model: self.validate() + assert self._model + try: annotation = AnnotationDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py index 03797a555ba0d..76287d24a99db 100644 --- a/superset/annotation_layers/annotations/commands/update.py +++ b/superset/annotation_layers/annotations/commands/update.py @@ -45,6 +45,8 @@ def __init__(self, model_id: int, data: dict[str, Any]): def run(self) -> Model: self.validate() + assert self._model + try: annotation = AnnotationDAO.update(self._model, self._properties) except DAOUpdateFailedError as ex: diff --git a/superset/annotation_layers/commands/delete.py b/superset/annotation_layers/commands/delete.py index 0692d4dd834f7..1af4242dce762 100644 --- a/superset/annotation_layers/commands/delete.py +++ b/superset/annotation_layers/commands/delete.py @@ -39,6 +39,8 @@ def __init__(self, model_id: int): def run(self) -> Model: self.validate() + assert self._model + try: annotation_layer = AnnotationLayerDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/annotation_layers/commands/update.py b/superset/annotation_layers/commands/update.py index ca3a288413f06..e7f6963e820c4 100644 --- a/superset/annotation_layers/commands/update.py +++ b/superset/annotation_layers/commands/update.py @@ -42,6 +42,8 @@ def __init__(self, model_id: int, data: dict[str, Any]): def run(self) -> Model: self.validate() + assert self._model + try: annotation_layer = AnnotationLayerDAO.update(self._model, self._properties) except DAOUpdateFailedError as ex: diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index 9a5b4e1f291d8..32fd49e7cd4a1 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -56,6 +56,8 @@ def __init__(self, model_id: int, data: dict[str, Any]): def run(self) -> Model: self.validate() + assert self._model + try: if self._properties.get("query_context_generation") is None: self._properties["last_saved_at"] = datetime.now() diff --git a/superset/commands/export/models.py b/superset/commands/export/models.py index 27f4572af3fd4..61532d4a031c4 100644 --- a/superset/commands/export/models.py +++ b/superset/commands/export/models.py @@ -30,7 +30,7 @@ class ExportModelsCommand(BaseCommand): - dao: type[BaseDAO] = BaseDAO + dao: type[BaseDAO[Model]] = BaseDAO not_found: type[CommandException] = CommandException def __init__(self, model_ids: list[int], export_related: bool = True): diff --git a/superset/daos/annotation.py b/superset/daos/annotation.py index 171a708fa422b..2df336647a1eb 100644 --- a/superset/daos/annotation.py +++ b/superset/daos/annotation.py @@ -27,9 +27,7 @@ logger = logging.getLogger(__name__) -class AnnotationDAO(BaseDAO): - model_cls = Annotation - +class AnnotationDAO(BaseDAO[Annotation]): @staticmethod def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] @@ -64,9 +62,7 @@ def validate_update_uniqueness( return not db.session.query(query.exists()).scalar() -class AnnotationLayerDAO(BaseDAO): - model_cls = AnnotationLayer - +class AnnotationLayerDAO(BaseDAO[AnnotationLayer]): @staticmethod def bulk_delete( models: Optional[list[AnnotationLayer]], commit: bool = True diff --git a/superset/daos/base.py b/superset/daos/base.py index 6465e5b177b08..c0758f51dd9f3 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=isinstance-second-argument-not-valid-type -from typing import Any, Optional, Union +from typing import Any, Generic, get_args, Optional, TypeVar, Union from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model @@ -31,8 +30,10 @@ ) from superset.extensions import db +T = TypeVar("T", bound=Model) # pylint: disable=invalid-name -class BaseDAO: + +class BaseDAO(Generic[T]): """ Base DAO, implement base CRUD sqlalchemy operations """ @@ -48,6 +49,11 @@ class BaseDAO: """ id_column_name = "id" + def __init_subclass__(cls) -> None: # pylint: disable=arguments-differ + cls.model_cls = get_args( + cls.__orig_bases__[0] # type: ignore # pylint: disable=no-member + )[0] + @classmethod def find_by_id( cls, @@ -78,7 +84,7 @@ def find_by_ids( model_ids: Union[list[str], list[int]], session: Session = None, skip_base_filter: bool = False, - ) -> list[Model]: + ) -> list[T]: """ Find a List of models by a list of ids, if defined applies `base_filter` """ @@ -95,7 +101,7 @@ def find_by_ids( return query.all() @classmethod - def find_all(cls) -> list[Model]: + def find_all(cls) -> list[T]: """ Get all that fit the `base_filter` """ @@ -108,7 +114,7 @@ def find_all(cls) -> list[Model]: return query.all() @classmethod - def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]: + def find_one_or_none(cls, **filter_by: Any) -> Optional[T]: """ Get the first that fit the `base_filter` """ @@ -121,7 +127,7 @@ def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]: return query.filter_by(**filter_by).one_or_none() @classmethod - def create(cls, properties: dict[str, Any], commit: bool = True) -> Model: + def create(cls, properties: dict[str, Any], commit: bool = True) -> T: """ Generic for creating models :raises: DAOCreateFailedError @@ -141,17 +147,13 @@ def create(cls, properties: dict[str, Any], commit: bool = True) -> Model: return model @classmethod - def save(cls, instance_model: Model, commit: bool = True) -> Model: + def save(cls, instance_model: T, commit: bool = True) -> None: """ Generic for saving models :raises: DAOCreateFailedError """ if cls.model_cls is None: raise DAOConfigError() - if not isinstance(instance_model, cls.model_cls): - raise DAOCreateFailedError( - "the instance model is not a type of the model class" - ) try: db.session.add(instance_model) if commit: @@ -159,12 +161,9 @@ def save(cls, instance_model: Model, commit: bool = True) -> Model: except SQLAlchemyError as ex: # pragma: no cover db.session.rollback() raise DAOCreateFailedError(exception=ex) from ex - return instance_model @classmethod - def update( - cls, model: Model, properties: dict[str, Any], commit: bool = True - ) -> Model: + def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T: """ Generic update a model :raises: DAOCreateFailedError @@ -181,7 +180,7 @@ def update( return model @classmethod - def delete(cls, model: Model, commit: bool = True) -> Model: + def delete(cls, model: T, commit: bool = True) -> T: """ Generic delete a model :raises: DAODeleteFailedError @@ -196,7 +195,7 @@ def delete(cls, model: Model, commit: bool = True) -> Model: return model @classmethod - def bulk_delete(cls, models: list[Model], commit: bool = True) -> None: + def bulk_delete(cls, models: list[T], commit: bool = True) -> None: try: for model in models: cls.delete(model, False) diff --git a/superset/daos/chart.py b/superset/daos/chart.py index 838d93abdf8b9..1a13965022221 100644 --- a/superset/daos/chart.py +++ b/superset/daos/chart.py @@ -34,8 +34,7 @@ logger = logging.getLogger(__name__) -class ChartDAO(BaseDAO): - model_cls = Slice +class ChartDAO(BaseDAO[Slice]): base_filter = ChartFilter @staticmethod diff --git a/superset/daos/css.py b/superset/daos/css.py index 224277a40ad1c..3a1cbe8fda148 100644 --- a/superset/daos/css.py +++ b/superset/daos/css.py @@ -27,9 +27,7 @@ logger = logging.getLogger(__name__) -class CssTemplateDAO(BaseDAO): - model_cls = CssTemplate - +class CssTemplateDAO(BaseDAO[CssTemplate]): @staticmethod def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] diff --git a/superset/daos/dashboard.py b/superset/daos/dashboard.py index 1e31591e1f56b..1650711d50188 100644 --- a/superset/daos/dashboard.py +++ b/superset/daos/dashboard.py @@ -49,8 +49,7 @@ logger = logging.getLogger(__name__) -class DashboardDAO(BaseDAO): - model_cls = Dashboard +class DashboardDAO(BaseDAO[Dashboard]): base_filter = DashboardAccessFilter @classmethod @@ -379,8 +378,7 @@ def remove_favorite(dashboard: Dashboard) -> None: db.session.commit() -class EmbeddedDashboardDAO(BaseDAO): - model_cls = EmbeddedDashboard +class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]): # There isn't really a regular scenario where we would rather get Embedded by id id_column_name = "uuid" @@ -407,9 +405,7 @@ def create(cls, properties: dict[str, Any], commit: bool = True) -> Any: raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.") -class FilterSetDAO(BaseDAO): - model_cls = FilterSet - +class FilterSetDAO(BaseDAO[FilterSet]): @classmethod def create(cls, properties: dict[str, Any], commit: bool = True) -> Model: if cls.model_cls is None: diff --git a/superset/daos/database.py b/superset/daos/database.py index 569568472ae16..0a3cb65b2ef1a 100644 --- a/superset/daos/database.py +++ b/superset/daos/database.py @@ -31,8 +31,7 @@ logger = logging.getLogger(__name__) -class DatabaseDAO(BaseDAO): - model_cls = Database +class DatabaseDAO(BaseDAO[Database]): base_filter = DatabaseFilter @classmethod @@ -138,9 +137,7 @@ def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]: return ssh_tunnel -class SSHTunnelDAO(BaseDAO): - model_cls = SSHTunnel - +class SSHTunnelDAO(BaseDAO[SSHTunnel]): @classmethod def update( cls, diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py index 3937e6c312024..4634a7e46f96e 100644 --- a/superset/daos/dataset.py +++ b/superset/daos/dataset.py @@ -31,8 +31,7 @@ logger = logging.getLogger(__name__) -class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods - model_cls = SqlaTable +class DatasetDAO(BaseDAO[SqlaTable]): # pylint: disable=too-many-public-methods base_filter = DatasourceFilter @staticmethod @@ -151,7 +150,7 @@ def update( model: SqlaTable, properties: dict[str, Any], commit: bool = True, - ) -> Optional[SqlaTable]: + ) -> SqlaTable: """ Updates a Dataset model on the metadata DB """ @@ -397,9 +396,9 @@ def get_table_by_name(database_id: int, table_name: str) -> Optional[SqlaTable]: ) -class DatasetColumnDAO(BaseDAO): - model_cls = TableColumn +class DatasetColumnDAO(BaseDAO[TableColumn]): + pass -class DatasetMetricDAO(BaseDAO): - model_cls = SqlMetric +class DatasetMetricDAO(BaseDAO[SqlMetric]): + pass diff --git a/superset/daos/datasource.py b/superset/daos/datasource.py index 684106161c34b..2bdf4ca21fb7a 100644 --- a/superset/daos/datasource.py +++ b/superset/daos/datasource.py @@ -33,7 +33,7 @@ Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery] -class DatasourceDAO(BaseDAO): +class DatasourceDAO(BaseDAO[Datasource]): sources: dict[Union[DatasourceType, str], type[Datasource]] = { DatasourceType.TABLE: SqlaTable, DatasourceType.QUERY: Query, diff --git a/superset/daos/log.py b/superset/daos/log.py index 81767a48cba90..002c3f2307255 100644 --- a/superset/daos/log.py +++ b/superset/daos/log.py @@ -30,9 +30,7 @@ from superset.utils.dates import datetime_to_epoch -class LogDAO(BaseDAO): - model_cls = Log - +class LogDAO(BaseDAO[Log]): @staticmethod def get_recent_activity( actions: list[str], diff --git a/superset/daos/query.py b/superset/daos/query.py index 8aca1a4e25649..6c8f4d572fd3a 100644 --- a/superset/daos/query.py +++ b/superset/daos/query.py @@ -35,8 +35,7 @@ logger = logging.getLogger(__name__) -class QueryDAO(BaseDAO): - model_cls = Query +class QueryDAO(BaseDAO[Query]): base_filter = QueryFilter @staticmethod @@ -101,8 +100,7 @@ def stop_query(client_id: str) -> None: db.session.commit() -class SavedQueryDAO(BaseDAO): - model_cls = SavedQuery +class SavedQueryDAO(BaseDAO[SavedQuery]): base_filter = SavedQueryFilter @staticmethod diff --git a/superset/daos/report.py b/superset/daos/report.py index 4f8d914adc9b4..70a87a6454110 100644 --- a/superset/daos/report.py +++ b/superset/daos/report.py @@ -42,8 +42,7 @@ REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER = "Notification sent with error" -class ReportScheduleDAO(BaseDAO): - model_cls = ReportSchedule +class ReportScheduleDAO(BaseDAO[ReportSchedule]): base_filter = ReportScheduleFilter @staticmethod diff --git a/superset/daos/security.py b/superset/daos/security.py index a435f224a651e..392d741e3d50d 100644 --- a/superset/daos/security.py +++ b/superset/daos/security.py @@ -19,5 +19,5 @@ from superset.daos.base import BaseDAO -class RLSDAO(BaseDAO): - model_cls = RowLevelSecurityFilter +class RLSDAO(BaseDAO[RowLevelSecurityFilter]): + pass diff --git a/superset/daos/tag.py b/superset/daos/tag.py index ec991edb13680..90b0134ca7c99 100644 --- a/superset/daos/tag.py +++ b/superset/daos/tag.py @@ -31,8 +31,7 @@ logger = logging.getLogger(__name__) -class TagDAO(BaseDAO): - model_cls = Tag +class TagDAO(BaseDAO[Tag]): # base_filter = TagAccessFilter @staticmethod diff --git a/superset/dashboards/commands/delete.py b/superset/dashboards/commands/delete.py index f774b92a51196..1f5eb4ae3b69c 100644 --- a/superset/dashboards/commands/delete.py +++ b/superset/dashboards/commands/delete.py @@ -44,6 +44,8 @@ def __init__(self, model_id: int): def run(self) -> Model: self.validate() + assert self._model + try: dashboard = DashboardDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py index cd9c07e0fdfdf..c880eebe899e0 100644 --- a/superset/dashboards/commands/update.py +++ b/superset/dashboards/commands/update.py @@ -48,6 +48,8 @@ def __init__(self, model_id: int, data: dict[str, Any]): def run(self) -> Model: self.validate() + assert self._model + try: dashboard = DashboardDAO.update(self._model, self._properties, commit=False) if self._properties.get("json_metadata"): diff --git a/superset/dashboards/filter_sets/commands/delete.py b/superset/dashboards/filter_sets/commands/delete.py index 93f43833994ad..c058354245204 100644 --- a/superset/dashboards/filter_sets/commands/delete.py +++ b/superset/dashboards/filter_sets/commands/delete.py @@ -36,8 +36,10 @@ def __init__(self, dashboard_id: int, filter_set_id: int): self._filter_set_id = filter_set_id def run(self) -> Model: + self.validate() + assert self._filter_set + try: - self.validate() return FilterSetDAO.delete(self._filter_set, commit=True) except DAODeleteFailedError as err: raise FilterSetDeleteFailedError(str(self._filter_set_id), "") from err diff --git a/superset/dashboards/filter_sets/commands/update.py b/superset/dashboards/filter_sets/commands/update.py index eecaa34aeb8e9..a63c8d46f2b6b 100644 --- a/superset/dashboards/filter_sets/commands/update.py +++ b/superset/dashboards/filter_sets/commands/update.py @@ -39,6 +39,8 @@ def __init__(self, dashboard_id: int, filter_set_id: int, data: dict[str, Any]): def run(self) -> Model: try: self.validate() + assert self._filter_set + if ( OWNER_TYPE_FIELD in self._properties and self._properties[OWNER_TYPE_FIELD] == "Dashboard" diff --git a/superset/databases/commands/delete.py b/superset/databases/commands/delete.py index b8eb3f6e5e637..95d212e2903e9 100644 --- a/superset/databases/commands/delete.py +++ b/superset/databases/commands/delete.py @@ -42,6 +42,8 @@ def __init__(self, model_id: int): def run(self) -> Model: self.validate() + assert self._model + try: database = DatabaseDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/databases/ssh_tunnel/commands/delete.py b/superset/databases/ssh_tunnel/commands/delete.py index 910df35a19ab1..375c496f2a488 100644 --- a/superset/databases/ssh_tunnel/commands/delete.py +++ b/superset/databases/ssh_tunnel/commands/delete.py @@ -42,6 +42,8 @@ def run(self) -> Model: if not is_feature_enabled("SSH_TUNNELING"): raise SSHTunnelingNotEnabledError() self.validate() + assert self._model + try: ssh_tunnel = SSHTunnelDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/reports/commands/delete.py b/superset/reports/commands/delete.py index 3f7e4e5d232a1..f52d96f7f5b8a 100644 --- a/superset/reports/commands/delete.py +++ b/superset/reports/commands/delete.py @@ -41,6 +41,8 @@ def __init__(self, model_id: int): def run(self) -> Model: self.validate() + assert self._model + try: report_schedule = ReportScheduleDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/row_level_security/commands/update.py b/superset/row_level_security/commands/update.py index d44aa3efaf9e8..bc5ef368bacfd 100644 --- a/superset/row_level_security/commands/update.py +++ b/superset/row_level_security/commands/update.py @@ -41,6 +41,8 @@ def __init__(self, model_id: int, data: dict[str, Any]): def run(self) -> Any: self.validate() + assert self._model + try: rule = RLSDAO.update(self._model, self._properties) except DAOUpdateFailedError as ex: