Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(dao): Add generic type for better type checking #24465

Merged
merged 1 commit into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions superset/annotation_layers/annotations/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(self, model_id: int):

def run(self) -> Model:
self.validate()
assert self._model
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed because now Mypy correctly identified that on line #44 self._model could have been None which didn't adhere to the base class definition.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a draft SIP recommending a slight refactor of the commands which address this problem.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@betodealmeida I’ve been thinking about the DAOs, commands and validation a bit recently and wonder whether we should adopt the “ask for forgiveness” try approach more often which relies less on Python validation and more on the database schema (foreign keys, uniqueness constraints, etc.) for validation, i.e., the DAO or command would fail if invalid.

The benefits are that we would reduce the Python code footprint and prevent possible race conditions.


try:
annotation = AnnotationDAO.delete(self._model)
except DAODeleteFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/annotation_layers/annotations/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def __init__(self, model_id: int, data: dict[str, Any]):

def run(self) -> Model:
self.validate()
assert self._model

try:
annotation = AnnotationDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/annotation_layers/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(self, model_id: int):

def run(self) -> Model:
self.validate()
assert self._model

try:
annotation_layer = AnnotationLayerDAO.delete(self._model)
except DAODeleteFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/annotation_layers/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __init__(self, model_id: int, data: dict[str, Any]):

def run(self) -> Model:
self.validate()
assert self._model

try:
annotation_layer = AnnotationLayerDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/charts/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self, model_id: int, data: dict[str, Any]):

def run(self) -> Model:
self.validate()
assert self._model

try:
if self._properties.get("query_context_generation") is None:
self._properties["last_saved_at"] = datetime.now()
Expand Down
2 changes: 1 addition & 1 deletion superset/commands/export/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


class ExportModelsCommand(BaseCommand):
dao: type[BaseDAO] = BaseDAO
dao: type[BaseDAO[Model]] = BaseDAO
not_found: type[CommandException] = CommandException

def __init__(self, model_ids: list[int], export_related: bool = True):
Expand Down
8 changes: 2 additions & 6 deletions superset/daos/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
logger = logging.getLogger(__name__)


class AnnotationDAO(BaseDAO):
model_cls = Annotation

class AnnotationDAO(BaseDAO[Annotation]):
@staticmethod
def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
Expand Down Expand Up @@ -64,9 +62,7 @@ def validate_update_uniqueness(
return not db.session.query(query.exists()).scalar()


class AnnotationLayerDAO(BaseDAO):
model_cls = AnnotationLayer

class AnnotationLayerDAO(BaseDAO[AnnotationLayer]):
@staticmethod
def bulk_delete(
models: Optional[list[AnnotationLayer]], commit: bool = True
Expand Down
35 changes: 17 additions & 18 deletions superset/daos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=isinstance-second-argument-not-valid-type
from typing import Any, Optional, Union
from typing import Any, Generic, get_args, Optional, TypeVar, Union

from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
Expand All @@ -31,8 +30,10 @@
)
from superset.extensions import db

T = TypeVar("T", bound=Model) # pylint: disable=invalid-name

class BaseDAO:

class BaseDAO(Generic[T]):
"""
Base DAO, implement base CRUD sqlalchemy operations
"""
Expand All @@ -48,6 +49,11 @@ class BaseDAO:
"""
id_column_name = "id"

def __init_subclass__(cls) -> None: # pylint: disable=arguments-differ
cls.model_cls = get_args(
cls.__orig_bases__[0] # type: ignore # pylint: disable=no-member
)[0]

Comment on lines +52 to +56
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow I didn't know Python has added core support for generic types! Adding a reference here for other reviewers: https://peps.python.org/pep-0560/

@classmethod
def find_by_id(
cls,
Expand Down Expand Up @@ -78,7 +84,7 @@ def find_by_ids(
model_ids: Union[list[str], list[int]],
session: Session = None,
skip_base_filter: bool = False,
) -> list[Model]:
) -> list[T]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
Expand All @@ -95,7 +101,7 @@ def find_by_ids(
return query.all()

@classmethod
def find_all(cls) -> list[Model]:
def find_all(cls) -> list[T]:
"""
Get all that fit the `base_filter`
"""
Expand All @@ -108,7 +114,7 @@ def find_all(cls) -> list[Model]:
return query.all()

@classmethod
def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]:
def find_one_or_none(cls, **filter_by: Any) -> Optional[T]:
"""
Get the first that fit the `base_filter`
"""
Expand All @@ -121,7 +127,7 @@ def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]:
return query.filter_by(**filter_by).one_or_none()

@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
def create(cls, properties: dict[str, Any], commit: bool = True) -> T:
"""
Generic for creating models
:raises: DAOCreateFailedError
Expand All @@ -141,30 +147,23 @@ def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
return model

@classmethod
def save(cls, instance_model: Model, commit: bool = True) -> Model:
def save(cls, instance_model: T, commit: bool = True) -> None:
"""
Generic for saving models
:raises: DAOCreateFailedError
"""
if cls.model_cls is None:
raise DAOConfigError()
if not isinstance(instance_model, cls.model_cls):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is no longer needed. The Mypy checks ensure this can't happen.

raise DAOCreateFailedError(
"the instance model is not a type of the model class"
)
try:
db.session.add(instance_model)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
return instance_model
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent return type detected with the ChartDAO implementation. Note there's no need to return instance_model as it's passed in as in input.


@classmethod
def update(
cls, model: Model, properties: dict[str, Any], commit: bool = True
) -> Model:
def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T:
"""
Generic update a model
:raises: DAOCreateFailedError
Expand All @@ -181,7 +180,7 @@ def update(
return model

@classmethod
def delete(cls, model: Model, commit: bool = True) -> Model:
def delete(cls, model: T, commit: bool = True) -> T:
"""
Generic delete a model
:raises: DAODeleteFailedError
Expand All @@ -196,7 +195,7 @@ def delete(cls, model: Model, commit: bool = True) -> Model:
return model

@classmethod
def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
def bulk_delete(cls, models: list[T], commit: bool = True) -> None:
try:
for model in models:
cls.delete(model, False)
Expand Down
3 changes: 1 addition & 2 deletions superset/daos/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@
logger = logging.getLogger(__name__)


class ChartDAO(BaseDAO):
model_cls = Slice
class ChartDAO(BaseDAO[Slice]):
base_filter = ChartFilter

@staticmethod
Expand Down
4 changes: 1 addition & 3 deletions superset/daos/css.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
logger = logging.getLogger(__name__)


class CssTemplateDAO(BaseDAO):
model_cls = CssTemplate

class CssTemplateDAO(BaseDAO[CssTemplate]):
@staticmethod
def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
Expand Down
10 changes: 3 additions & 7 deletions superset/daos/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@
logger = logging.getLogger(__name__)


class DashboardDAO(BaseDAO):
model_cls = Dashboard
class DashboardDAO(BaseDAO[Dashboard]):
base_filter = DashboardAccessFilter

@classmethod
Expand Down Expand Up @@ -379,8 +378,7 @@ def remove_favorite(dashboard: Dashboard) -> None:
db.session.commit()


class EmbeddedDashboardDAO(BaseDAO):
model_cls = EmbeddedDashboard
class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
# There isn't really a regular scenario where we would rather get Embedded by id
id_column_name = "uuid"

Expand All @@ -407,9 +405,7 @@ def create(cls, properties: dict[str, Any], commit: bool = True) -> Any:
raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.")


class FilterSetDAO(BaseDAO):
model_cls = FilterSet

class FilterSetDAO(BaseDAO[FilterSet]):
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
if cls.model_cls is None:
Expand Down
7 changes: 2 additions & 5 deletions superset/daos/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
logger = logging.getLogger(__name__)


class DatabaseDAO(BaseDAO):
model_cls = Database
class DatabaseDAO(BaseDAO[Database]):
base_filter = DatabaseFilter

@classmethod
Expand Down Expand Up @@ -138,9 +137,7 @@ def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]:
return ssh_tunnel


class SSHTunnelDAO(BaseDAO):
model_cls = SSHTunnel

class SSHTunnelDAO(BaseDAO[SSHTunnel]):
@classmethod
def update(
cls,
Expand Down
13 changes: 6 additions & 7 deletions superset/daos/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
logger = logging.getLogger(__name__)


class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
model_cls = SqlaTable
class DatasetDAO(BaseDAO[SqlaTable]): # pylint: disable=too-many-public-methods
base_filter = DatasourceFilter

@staticmethod
Expand Down Expand Up @@ -151,7 +150,7 @@ def update(
model: SqlaTable,
properties: dict[str, Any],
commit: bool = True,
) -> Optional[SqlaTable]:
) -> SqlaTable:
"""
Updates a Dataset model on the metadata DB
"""
Expand Down Expand Up @@ -397,9 +396,9 @@ def get_table_by_name(database_id: int, table_name: str) -> Optional[SqlaTable]:
)


class DatasetColumnDAO(BaseDAO):
model_cls = TableColumn
class DatasetColumnDAO(BaseDAO[TableColumn]):
pass


class DatasetMetricDAO(BaseDAO):
model_cls = SqlMetric
class DatasetMetricDAO(BaseDAO[SqlMetric]):
pass
2 changes: 1 addition & 1 deletion superset/daos/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]


class DatasourceDAO(BaseDAO):
class DatasourceDAO(BaseDAO[Datasource]):
sources: dict[Union[DatasourceType, str], type[Datasource]] = {
DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query,
Expand Down
4 changes: 1 addition & 3 deletions superset/daos/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
from superset.utils.dates import datetime_to_epoch


class LogDAO(BaseDAO):
model_cls = Log

class LogDAO(BaseDAO[Log]):
@staticmethod
def get_recent_activity(
actions: list[str],
Expand Down
6 changes: 2 additions & 4 deletions superset/daos/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@
logger = logging.getLogger(__name__)


class QueryDAO(BaseDAO):
model_cls = Query
class QueryDAO(BaseDAO[Query]):
base_filter = QueryFilter

@staticmethod
Expand Down Expand Up @@ -101,8 +100,7 @@ def stop_query(client_id: str) -> None:
db.session.commit()


class SavedQueryDAO(BaseDAO):
model_cls = SavedQuery
class SavedQueryDAO(BaseDAO[SavedQuery]):
base_filter = SavedQueryFilter

@staticmethod
Expand Down
3 changes: 1 addition & 2 deletions superset/daos/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER = "Notification sent with error"


class ReportScheduleDAO(BaseDAO):
model_cls = ReportSchedule
class ReportScheduleDAO(BaseDAO[ReportSchedule]):
base_filter = ReportScheduleFilter

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions superset/daos/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
from superset.daos.base import BaseDAO


class RLSDAO(BaseDAO):
model_cls = RowLevelSecurityFilter
class RLSDAO(BaseDAO[RowLevelSecurityFilter]):
pass
3 changes: 1 addition & 2 deletions superset/daos/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
logger = logging.getLogger(__name__)


class TagDAO(BaseDAO):
model_cls = Tag
class TagDAO(BaseDAO[Tag]):
# base_filter = TagAccessFilter

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions superset/dashboards/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def __init__(self, model_id: int):

def run(self) -> Model:
self.validate()
assert self._model

try:
dashboard = DashboardDAO.delete(self._model)
except DAODeleteFailedError as ex:
Expand Down
2 changes: 2 additions & 0 deletions superset/dashboards/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(self, model_id: int, data: dict[str, Any]):

def run(self) -> Model:
self.validate()
assert self._model

try:
dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
if self._properties.get("json_metadata"):
Expand Down
Loading