diff --git a/scripts/tests/run.sh b/scripts/tests/run.sh index 24233010107dd..6ba578d989513 100755 --- a/scripts/tests/run.sh +++ b/scripts/tests/run.sh @@ -138,5 +138,5 @@ fi if [ $RUN_TESTS -eq 1 ] then - pytest --durations=0 --maxfail=1 "${TEST_MODULE}" + pytest --durations=0 --maxfail=10 "${TEST_MODULE}" fi diff --git a/superset-frontend/src/constants.ts b/superset-frontend/src/constants.ts index 1d8698334ecec..a7228b857bdf5 100644 --- a/superset-frontend/src/constants.ts +++ b/superset-frontend/src/constants.ts @@ -71,6 +71,10 @@ export const URL_PARAMS = { name: 'datasource_id', type: 'string', }, + datasetId: { + name: 'dataset_id', + type: 'string', + }, datasourceType: { name: 'datasource_type', type: 'string', @@ -90,6 +94,7 @@ export const RESERVED_CHART_URL_PARAMS: string[] = [ URL_PARAMS.sliceId.name, URL_PARAMS.datasourceId.name, URL_PARAMS.datasourceType.name, + URL_PARAMS.datasetId.name, ]; export const RESERVED_DASHBOARD_URL_PARAMS: string[] = [ URL_PARAMS.nativeFilters.name, diff --git a/superset-frontend/src/explore/components/controls/DatasourceControl/index.jsx b/superset-frontend/src/explore/components/controls/DatasourceControl/index.jsx index 1fd07668b1c7e..3d6ea2fdd2662 100644 --- a/superset-frontend/src/explore/components/controls/DatasourceControl/index.jsx +++ b/superset-frontend/src/explore/components/controls/DatasourceControl/index.jsx @@ -190,9 +190,8 @@ class DatasourceControl extends React.PureComponent { let isMissingParams = false; if (isMissingDatasource) { const datasourceId = getUrlParam(URL_PARAMS.datasourceId); - const datasourceType = getUrlParam(URL_PARAMS.datasourceType); const sliceId = getUrlParam(URL_PARAMS.sliceId); - if (!datasourceId && !sliceId && !datasourceType) { + if (!datasourceId && !sliceId) { isMissingParams = true; } } diff --git a/superset/cachekeys/schemas.py b/superset/cachekeys/schemas.py index a44a7c545add4..3d913e8b5f6e7 100644 --- a/superset/cachekeys/schemas.py +++ b/superset/cachekeys/schemas.py @@ -22,6 +22,7 @@ datasource_type_description, datasource_uid_description, ) +from superset.utils.core import DatasourceType class Datasource(Schema): @@ -36,7 +37,7 @@ class Datasource(Schema): ) datasource_type = fields.String( description=datasource_type_description, - validate=validate.OneOf(choices=("druid", "table", "view")), + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), required=True, ) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 6a05e4d9942bc..8a82e364be47c 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -31,6 +31,7 @@ from superset.utils import pandas_postprocessing, schema as utils from superset.utils.core import ( AnnotationType, + DatasourceType, FilterOperator, PostProcessingBoxplotWhiskerType, PostProcessingContributionOrientation, @@ -198,7 +199,7 @@ class ChartPostSchema(Schema): datasource_id = fields.Integer(description=datasource_id_description, required=True) datasource_type = fields.String( description=datasource_type_description, - validate=validate.OneOf(choices=("druid", "table", "view")), + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), required=True, ) datasource_name = fields.String( @@ -244,7 +245,7 @@ class ChartPutSchema(Schema): ) datasource_type = fields.String( description=datasource_type_description, - validate=validate.OneOf(choices=("druid", "table", "view")), + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), allow_none=True, ) dashboards = fields.List(fields.Integer(description=dashboards_description)) @@ -983,7 +984,7 @@ class ChartDataDatasourceSchema(Schema): ) type = fields.String( description="Datasource type", - validate=validate.OneOf(choices=("druid", "table")), + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), ) diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index ad590a596c59a..a661ef4d6047d 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -115,8 +115,24 @@ def __init__(self) -> None: super().__init__([_("Some roles do not exist")], field_name="roles") +class DatasourceTypeInvalidError(ValidationError): + status = 422 + + def __init__(self) -> None: + super().__init__( + [_("Datasource type is invalid")], field_name="datasource_type" + ) + + class DatasourceNotFoundValidationError(ValidationError): status = 404 def __init__(self) -> None: super().__init__([_("Datasource does not exist")], field_name="datasource_id") + + +class QueryNotFoundValidationError(ValidationError): + status = 404 + + def __init__(self) -> None: + super().__init__([_("Query does not exist")], field_name="datasource_id") diff --git a/superset/dao/datasource/dao.py b/superset/dao/datasource/dao.py index 8b4845db3c51b..caa45564aa250 100644 --- a/superset/dao/datasource/dao.py +++ b/superset/dao/datasource/dao.py @@ -39,11 +39,11 @@ class DatasourceDAO(BaseDAO): sources: Dict[DatasourceType, Type[Datasource]] = { - DatasourceType.SQLATABLE: SqlaTable, + DatasourceType.TABLE: SqlaTable, DatasourceType.QUERY: Query, DatasourceType.SAVEDQUERY: SavedQuery, DatasourceType.DATASET: Dataset, - DatasourceType.TABLE: Table, + DatasourceType.SLTABLE: Table, } @classmethod @@ -66,7 +66,7 @@ def get_datasource( @classmethod def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]: - source_class = DatasourceDAO.sources[DatasourceType.SQLATABLE] + source_class = DatasourceDAO.sources[DatasourceType.TABLE] qry = session.query(source_class) qry = source_class.default_query(qry) return qry.all() diff --git a/superset/explore/form_data/commands/create.py b/superset/explore/form_data/commands/create.py index 5f8aeabb98751..7946980c82684 100644 --- a/superset/explore/form_data/commands/create.py +++ b/superset/explore/form_data/commands/create.py @@ -47,8 +47,7 @@ def run(self) -> str: form_data = self._cmd_params.form_data check_access(datasource_id, chart_id, actor, datasource_type) contextual_key = cache_key( - session.get( - "_id"), tab_id, datasource_id, chart_id, datasource_type + session.get("_id"), tab_id, datasource_id, chart_id, datasource_type ) key = cache_manager.explore_form_data_cache.get(contextual_key) if not key or not tab_id: diff --git a/superset/explore/form_data/commands/delete.py b/superset/explore/form_data/commands/delete.py index 25619d91a4985..f11bcb90c4310 100644 --- a/superset/explore/form_data/commands/delete.py +++ b/superset/explore/form_data/commands/delete.py @@ -32,6 +32,7 @@ TemporaryCacheDeleteFailedError, ) from superset.temporary_cache.utils import cache_key +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -50,7 +51,7 @@ def run(self) -> bool: if state: datasource_id: int = state["datasource_id"] chart_id: Optional[int] = state["chart_id"] - datasource_type: str = state["datasource_type"] + datasource_type = DatasourceType(state["datasource_type"]) check_access(datasource_id, chart_id, actor, datasource_type) if state["owner"] != get_owner(actor): raise TemporaryCacheAccessDeniedError() @@ -58,6 +59,11 @@ def run(self) -> bool: contextual_key = cache_key( session.get("_id"), tab_id, datasource_id, chart_id, datasource_type ) + if contextual_key is None: + # check again with old keys + contextual_key = cache_key( + session.get("_id"), tab_id, datasource_id, chart_id + ) cache_manager.explore_form_data_cache.delete(contextual_key) return cache_manager.explore_form_data_cache.delete(key) return False diff --git a/superset/explore/form_data/commands/get.py b/superset/explore/form_data/commands/get.py index b52681027b12e..982c8e3b4b7d7 100644 --- a/superset/explore/form_data/commands/get.py +++ b/superset/explore/form_data/commands/get.py @@ -27,6 +27,7 @@ from superset.explore.form_data.commands.utils import check_access from superset.extensions import cache_manager from superset.temporary_cache.commands.exceptions import TemporaryCacheGetFailedError +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -49,7 +50,7 @@ def run(self) -> Optional[str]: state["datasource_id"], state["chart_id"], actor, - state["datasource_type"], + DatasourceType(state["datasource_type"]), ) if self._refresh_timeout: cache_manager.explore_form_data_cache.set(key, state) diff --git a/superset/explore/form_data/commands/parameters.py b/superset/explore/form_data/commands/parameters.py index 1b30bdde67d5f..fec06a581fb79 100644 --- a/superset/explore/form_data/commands/parameters.py +++ b/superset/explore/form_data/commands/parameters.py @@ -19,11 +19,13 @@ from flask_appbuilder.security.sqla.models import User +from superset.utils.core import DatasourceType + @dataclass class CommandParameters: actor: User - datasource_type: str = "" + datasource_type: DatasourceType = DatasourceType.TABLE datasource_id: int = 0 chart_id: int = 0 tab_id: Optional[int] = None diff --git a/superset/explore/form_data/commands/update.py b/superset/explore/form_data/commands/update.py index b01846af77d4c..2190b04da28e2 100644 --- a/superset/explore/form_data/commands/update.py +++ b/superset/explore/form_data/commands/update.py @@ -65,14 +65,17 @@ def run(self) -> Optional[str]: # Generate a new key if tab_id changes or equals 0 tab_id = self._cmd_params.tab_id contextual_key = cache_key( - session.get( - "_id"), tab_id, datasource_id, chart_id, datasource_type + session.get("_id"), tab_id, datasource_id, chart_id, datasource_type ) + if contextual_key is None: + # check again with old keys + contextual_key = cache_key( + session.get("_id"), tab_id, datasource_id, chart_id + ) key = cache_manager.explore_form_data_cache.get(contextual_key) if not key or not tab_id: key = random_key() - cache_manager.explore_form_data_cache.set( - contextual_key, key) + cache_manager.explore_form_data_cache.set(contextual_key, key) new_state: TemporaryExploreState = { "owner": owner, diff --git a/superset/explore/form_data/commands/utils.py b/superset/explore/form_data/commands/utils.py index 4f12bd9a97a18..7927457178c9e 100644 --- a/superset/explore/form_data/commands/utils.py +++ b/superset/explore/form_data/commands/utils.py @@ -31,9 +31,15 @@ TemporaryCacheAccessDeniedError, TemporaryCacheResourceNotFoundError, ) +from superset.utils.core import DatasourceType -def check_access(datasource_id: int, chart_id: Optional[int], actor: User, datasource_type: str) -> None: +def check_access( + datasource_id: int, + chart_id: Optional[int], + actor: User, + datasource_type: DatasourceType, +) -> None: try: explore_check_access(datasource_id, chart_id, actor, datasource_type) except (ChartNotFoundError, DatasetNotFoundError) as ex: diff --git a/superset/explore/form_data/schemas.py b/superset/explore/form_data/schemas.py index 187d1b86469f2..192df089e818b 100644 --- a/superset/explore/form_data/schemas.py +++ b/superset/explore/form_data/schemas.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from marshmallow import fields, Schema +from marshmallow import fields, Schema, validate + +from superset.utils.core import DatasourceType class FormDataPostSchema(Schema): @@ -22,7 +24,10 @@ class FormDataPostSchema(Schema): required=True, allow_none=False, description="The datasource ID" ) datasource_type = fields.String( - required=True, allow_none=False, description="The datasource type" + required=True, + allow_none=False, + description="The datasource type", + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), ) chart_id = fields.Integer(required=False, description="The chart ID") form_data = fields.String( @@ -35,7 +40,10 @@ class FormDataPutSchema(Schema): required=True, allow_none=False, description="The datasource ID" ) datasource_type = fields.String( - required=True, allow_none=False, description="The datasource type" + required=True, + allow_none=False, + description="The datasource type", + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), ) chart_id = fields.Integer(required=False, description="The chart ID") form_data = fields.String( diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index 57f4fbd9020e7..7bd6365d814bd 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -25,6 +25,7 @@ from superset.explore.utils import check_access as check_chart_access from superset.key_value.commands.create import CreateKeyValueCommand from superset.key_value.utils import encode_permalink_key +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -39,9 +40,9 @@ def __init__(self, actor: User, state: Dict[str, Any]): def run(self) -> str: self.validate() try: - datasource = self.datasource.split("__") - datasource_id: int = int(datasource[0]) - datasource_type: str = datasource[1] + d_id, d_type = self.datasource.split("__") + datasource_id = int(d_id) + datasource_type = DatasourceType(d_type) check_chart_access( datasource_id, self.chart_id, self.actor, datasource_type ) @@ -59,8 +60,7 @@ def run(self) -> str: ) key = command.run() if key.id is None: - raise ExplorePermalinkCreateFailedError( - "Unexpected missing key id") + raise ExplorePermalinkCreateFailedError("Unexpected missing key id") return encode_permalink_key(key=key.id, salt=self.salt) except SQLAlchemyError as ex: logger.exception("Error running create command") diff --git a/superset/explore/permalink/commands/get.py b/superset/explore/permalink/commands/get.py index 6aa85242a1fd5..f75df69d7a63e 100644 --- a/superset/explore/permalink/commands/get.py +++ b/superset/explore/permalink/commands/get.py @@ -28,6 +28,7 @@ from superset.key_value.commands.get import GetKeyValueCommand from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError from superset.key_value.utils import decode_permalink_id +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -48,9 +49,8 @@ def run(self) -> Optional[ExplorePermalinkValue]: if value: chart_id: Optional[int] = value.get("chartId") datasource_id: int = value["datasourceId"] - datasource_type: str = value["datasourceType"] - check_chart_access(datasource_id, chart_id, - self.actor, datasource_type) + datasource_type = DatasourceType(value["datasourceType"]) + check_chart_access(datasource_id, chart_id, self.actor, datasource_type) return value return None except ( diff --git a/superset/explore/utils.py b/superset/explore/utils.py index 05e8599446cdc..f0bfd8f0aa40c 100644 --- a/superset/explore/utils.py +++ b/superset/explore/utils.py @@ -24,38 +24,64 @@ ChartNotFoundError, ) from superset.charts.dao import ChartDAO -from superset.commands.exceptions import DatasourceNotFoundValidationError +from superset.commands.exceptions import ( + DatasourceNotFoundValidationError, + DatasourceTypeInvalidError, + QueryNotFoundValidationError, +) from superset.datasets.commands.exceptions import ( DatasetAccessDeniedError, DatasetNotFoundError, ) from superset.datasets.dao import DatasetDAO +from superset.queries.dao import QueryDAO from superset.utils.core import DatasourceType from superset.views.base import is_user_admin from superset.views.utils import is_owner -def check_datasource_access(datasource_id: int, datasource_type: str) -> Optional[bool]: - if datasource_id: - if datasource_type == DatasourceType.TABLE: - return check_dataset_access(datasource_id) - raise DatasourceNotFoundValidationError - - def check_dataset_access(dataset_id: int) -> Optional[bool]: if dataset_id: dataset = DatasetDAO.find_by_id(dataset_id) if dataset: - can_access_datasource = security_manager.can_access_datasource( - dataset) + can_access_datasource = security_manager.can_access_datasource(dataset) if can_access_datasource: return True raise DatasetAccessDeniedError() raise DatasetNotFoundError() +def check_query_access(query_id: int) -> Optional[bool]: + if query_id: + query = QueryDAO.find_by_id(query_id) + if query: + security_manager.raise_for_access(query=query) + return True + raise QueryNotFoundValidationError() + + +ACCESS_FUNCTION_MAP = { + DatasourceType.TABLE: check_dataset_access, + DatasourceType.QUERY: check_query_access, +} + + +def check_datasource_access( + datasource_id: int, datasource_type: DatasourceType +) -> Optional[bool]: + if datasource_id: + try: + return ACCESS_FUNCTION_MAP[datasource_type](datasource_id) + except KeyError as ex: + raise DatasourceTypeInvalidError() from ex + raise DatasourceNotFoundValidationError() + + def check_access( - datasource_id: int, chart_id: Optional[int], actor: User, datasource_type: str + datasource_id: int, + chart_id: Optional[int], + actor: User, + datasource_type: DatasourceType, ) -> Optional[bool]: check_datasource_access(datasource_id, datasource_type) if not chart_id: diff --git a/superset/migrations/versions/b47bb0d9fddb_testing_performance.py b/superset/migrations/versions/b47bb0d9fddb_testing_performance.py deleted file mode 100644 index 4e51d43d8e28d..0000000000000 --- a/superset/migrations/versions/b47bb0d9fddb_testing_performance.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""testing performance - -Revision ID: b47bb0d9fddb -Revises: 6f139c533bea -Create Date: 2022-05-24 17:48:22.786769 - -""" - -# revision identifiers, used by Alembic. -from superset.models.sql_lab import Query -from sqlalchemy.dialects import postgresql -import sqlalchemy as sa -from alembic import op -from superset import db -revision = 'b47bb0d9fddb' -down_revision = '6f139c533bea' - - -def upgrade(): - bind = op.get_bind() - session = db.Session(bind=bind) - - count_query = session.query(Query).with_entities(sa.func.count(Query.id)) - count = count_query.scalar() - print(f"COUNTCOUNT {count}") - - -def downgrade(): - pass diff --git a/superset/utils/cache_manager.py b/superset/utils/cache_manager.py index 6399a94ab8942..d3b2dbdb00d5e 100644 --- a/superset/utils/cache_manager.py +++ b/superset/utils/cache_manager.py @@ -35,7 +35,7 @@ def get(self, *args: Any, **kwargs: Any) -> Optional[Union[str, Markup]]: if not cache: return None - # rename keys for existing cache based on new TemporaryExploreState model + # rename data keys for existing cache based on new TemporaryExploreState model if isinstance(cache, dict): cache = { ("datasource_id" if key == "dataset_id" else key): value diff --git a/superset/utils/core.py b/superset/utils/core.py index b15511224a1a8..6c90837959edb 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -176,11 +176,12 @@ class GenericDataType(IntEnum): class DatasourceType(str, Enum): - SQLATABLE = "sqlatable" + SLTABLE = "sl_table" TABLE = "table" DATASET = "dataset" QUERY = "query" SAVEDQUERY = "saved_query" + VIEW = "view" class DatasourceDict(TypedDict): @@ -331,8 +332,7 @@ class ReservedUrlParameters(str, Enum): @staticmethod def is_standalone_mode() -> Optional[bool]: - standalone_param = request.args.get( - ReservedUrlParameters.STANDALONE.value) + standalone_param = request.args.get(ReservedUrlParameters.STANDALONE.value) standalone: Optional[bool] = bool( standalone_param and standalone_param != "false" and standalone_param != "0" ) @@ -500,8 +500,7 @@ def default(self, o: Any) -> Union[Dict[Any, Any], str]: if isinstance(o, uuid.UUID): return str(o) try: - vals = {k: v for k, v in o.__dict__.items() if k != - "_sa_instance_state"} + vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"} return {"__{}__".format(o.__class__.__name__): vals} except Exception: # pylint: disable=broad-except if isinstance(o, datetime): @@ -585,8 +584,7 @@ def json_iso_dttm_ser(obj: Any, pessimistic: bool = False) -> str: if pessimistic: return "Unserializable [{}]".format(type(obj)) - raise TypeError( - "Unserializable object {} of type {}".format(obj, type(obj))) + raise TypeError("Unserializable object {} of type {}".format(obj, type(obj))) return obj @@ -607,8 +605,7 @@ def json_int_dttm_ser(obj: Any) -> float: elif isinstance(obj, date): obj = (obj - EPOCH.date()).total_seconds() * 1000 else: - raise TypeError( - "Unserializable object {} of type {}".format(obj, type(obj))) + raise TypeError("Unserializable object {} of type {}".format(obj, type(obj))) return obj @@ -1061,23 +1058,20 @@ def simple_filter_to_adhoc( } if filter_clause.get("isExtra"): result["isExtra"] = True - result["filterOptionName"] = md5_sha_from_dict( - cast(Dict[Any, Any], result)) + result["filterOptionName"] = md5_sha_from_dict(cast(Dict[Any, Any], result)) return result def form_data_to_adhoc(form_data: Dict[str, Any], clause: str) -> AdhocFilterClause: if clause not in ("where", "having"): - raise ValueError( - __("Unsupported clause type: %(clause)s", clause=clause)) + raise ValueError(__("Unsupported clause type: %(clause)s", clause=clause)) result: AdhocFilterClause = { "clause": clause.upper(), "expressionType": "SQL", "sqlExpression": form_data.get(clause), } - result["filterOptionName"] = md5_sha_from_dict( - cast(Dict[Any, Any], result)) + result["filterOptionName"] = md5_sha_from_dict(cast(Dict[Any, Any], result)) return result @@ -1089,8 +1083,7 @@ def merge_extra_form_data(form_data: Dict[str, Any]) -> None: """ filter_keys = ["filters", "adhoc_filters"] extra_form_data = form_data.pop("extra_form_data", {}) - append_filters: List[QueryObjectFilterClause] = extra_form_data.get( - "filters", None) + append_filters: List[QueryObjectFilterClause] = extra_form_data.get("filters", None) # merge append extras for key in [key for key in EXTRA_FORM_DATA_APPEND_KEYS if key not in filter_keys]: @@ -1127,8 +1120,7 @@ def merge_extra_form_data(form_data: Dict[str, Any]) -> None: for key, value in form_data.items(): if re.match("adhoc_filter.*", key): value.extend( - simple_filter_to_adhoc( - {"isExtra": True, **fltr}) # type: ignore + simple_filter_to_adhoc({"isExtra": True, **fltr}) # type: ignore for fltr in append_filters if fltr ) @@ -1171,8 +1163,7 @@ def get_filter_key(f: Dict[str, Any]) -> str: and existing.get("comparator") is not None and existing.get("subject") is not None ): - existing_filters[get_filter_key( - existing)] = existing["comparator"] + existing_filters[get_filter_key(existing)] = existing["comparator"] for filtr in form_data[ # pylint: disable=too-many-nested-blocks "extra_filters" @@ -1196,8 +1187,7 @@ def get_filter_key(f: Dict[str, Any]) -> str: # Add filters for unequal lists # order doesn't matter if set(existing_filters[filter_key]) != set(filtr["val"]): - adhoc_filters.append( - simple_filter_to_adhoc(filtr)) + adhoc_filters.append(simple_filter_to_adhoc(filtr)) else: adhoc_filters.append(simple_filter_to_adhoc(filtr)) else: @@ -1563,7 +1553,7 @@ def split( elif character == ")": parens -= 1 elif character == quote: - if quotes and string[j - len(escaped_quote) + 1: j + 1] != escaped_quote: + if quotes and string[j - len(escaped_quote) + 1 : j + 1] != escaped_quote: quotes = False elif not quotes: quotes = True @@ -1706,8 +1696,7 @@ def get_time_filter_status( datasource: "BaseDatasource", applied_time_extras: Dict[str, str], ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: - temporal_columns = { - col.column_name for col in datasource.columns if col.is_dttm} + temporal_columns = {col.column_name for col in datasource.columns if col.is_dttm} applied: List[Dict[str, str]] = [] rejected: List[Dict[str, str]] = [] time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL) diff --git a/superset/views/core.py b/superset/views/core.py index 2bb94dd1e2a5b..12b04dd706c4d 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -320,7 +320,9 @@ def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-u def clean_fulfilled_requests(session: Session) -> None: for dar in session.query(DAR).all(): datasource = ConnectorRegistry.get_datasource( - dar.datasource_type, dar.datasource_id, session + dar.datasource_type, + dar.datasource_id, + session, ) if not datasource or security_manager.can_access_datasource(datasource): # Dataset does not exist anymore @@ -460,7 +462,7 @@ def get_raw_results(self, viz_obj: BaseViz) -> FlaskResponse: "data": payload["df"].to_dict("records"), "colnames": payload.get("colnames"), "coltypes": payload.get("coltypes"), - } + }, ) def get_samples(self, viz_obj: BaseViz) -> FlaskResponse: @@ -628,7 +630,8 @@ def explore_json( and not security_manager.can_access("can_csv", "Superset") ): return json_error_response( - _("You don't have the rights to ") + _("download as csv"), status=403 + _("You don't have the rights to ") + _("download as csv"), + status=403, ) form_data = get_form_data()[0] @@ -948,7 +951,9 @@ def filter( # pylint: disable=no-self-use """ # TODO: Cache endpoint by user, datasource and column datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session + datasource_type, + datasource_id, + db.session, ) if not datasource: return json_error_response(DATASOURCE_MISSING_ERR) @@ -1413,7 +1418,10 @@ def get_user_activity_access_error(user_id: int) -> Optional[FlaskResponse]: try: security_manager.raise_for_user_activity_access(user_id) except SupersetSecurityException as ex: - return json_error_response(ex.message, status=403) + return json_error_response( + ex.message, + status=403, + ) return None @api @@ -1436,7 +1444,8 @@ def recent_activity( # pylint: disable=too-many-locals has_subject_title = or_( and_( - Dashboard.dashboard_title is not None, Dashboard.dashboard_title != "" + Dashboard.dashboard_title is not None, + Dashboard.dashboard_title != "", ), and_(Slice.slice_name is not None, Slice.slice_name != ""), ) @@ -1470,7 +1479,10 @@ def recent_activity( # pylint: disable=too-many-locals Slice.slice_name, ) .outerjoin(Dashboard, Dashboard.id == subqry.c.dashboard_id) - .outerjoin(Slice, Slice.id == subqry.c.slice_id) + .outerjoin( + Slice, + Slice.id == subqry.c.slice_id, + ) .filter(has_subject_title) .order_by(subqry.c.dttm.desc()) .limit(limit) @@ -1961,7 +1973,8 @@ def dashboard( @has_access @expose("/dashboard/p//", methods=["GET"]) def dashboard_permalink( # pylint: disable=no-self-use - self, key: str + self, + key: str, ) -> FlaskResponse: try: value = GetDashboardPermalinkCommand(g.user, key).run() @@ -2281,7 +2294,7 @@ def stop_query(self) -> FlaskResponse: QueryStatus.TIMED_OUT, ]: logger.warning( - "Query with client_id could not be stopped: query already complete" + "Query with client_id could not be stopped: query already complete", ) return self.json_response("OK") @@ -2435,7 +2448,8 @@ def _set_http_status_into_Sql_lab_exception(ex: SqlLabException) -> None: ex.status = 403 def _create_response_from_execution_context( # pylint: disable=invalid-name, no-self-use - self, command_result: CommandResult + self, + command_result: CommandResult, ) -> FlaskResponse: status_code = 200 @@ -2526,7 +2540,9 @@ def fetch_datasource_metadata(self) -> FlaskResponse: # pylint: disable=no-self datasource_id, datasource_type = request.args["datasourceKey"].split("__") datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session + datasource_type, + datasource_id, + db.session, ) # Check if datasource exists if not datasource: diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 3a1f4d6ea27a6..a37acf6eafc3a 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -520,7 +520,13 @@ def test_create_chart_validate_datasource(self): response = json.loads(rv.data.decode("utf-8")) self.assertEqual( response, - {"message": {"datasource_type": ["Must be one of: druid, table, view."]}}, + { + "message": { + "datasource_type": [ + "Must be one of: sl_table, table, dataset, query, saved_query, view." + ] + } + }, ) chart_data = { "slice_name": "title1", @@ -686,7 +692,13 @@ def test_update_chart_validate_datasource(self): response = json.loads(rv.data.decode("utf-8")) self.assertEqual( response, - {"message": {"datasource_type": ["Must be one of: druid, table, view."]}}, + { + "message": { + "datasource_type": [ + "Must be one of: sl_table, table, dataset, query, saved_query, view." + ] + } + }, ) chart_data = {"datasource_id": 0, "datasource_type": "table"} diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py index 965920b256d60..8b375df56ae38 100644 --- a/tests/integration_tests/explore/form_data/api_tests.py +++ b/tests/integration_tests/explore/form_data/api_tests.py @@ -56,7 +56,7 @@ def admin_id() -> int: @pytest.fixture -def datasource_id() -> int: +def datasource() -> int: with app.app_context() as ctx: session: Session = ctx.app.appbuilder.get_session dataset = ( @@ -64,38 +64,26 @@ def datasource_id() -> int: .filter_by(table_name="wb_health_population") .first() ) - return dataset.id - - -@pytest.fixture -def datasource_type() -> int: - with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - dataset = ( - session.query(SqlaTable) - .filter_by(table_name="wb_health_population") - .first() - ) - return dataset.type + return dataset @pytest.fixture(autouse=True) -def cache(chart_id, admin_id, datasource_id, datasource_type): +def cache(chart_id, admin_id, datasource): entry: TemporaryExploreState = { "owner": admin_id, - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } cache_manager.explore_form_data_cache.set(KEY, entry) -def test_post(client, chart_id: int, datasource_id: int, datasource_type: str): +def test_post(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } @@ -103,13 +91,11 @@ def test_post(client, chart_id: int, datasource_id: int, datasource_type: str): assert resp.status_code == 201 -def test_post_bad_request_non_string( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_post_bad_request_non_string(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": 1234, } @@ -117,13 +103,11 @@ def test_post_bad_request_non_string( assert resp.status_code == 400 -def test_post_bad_request_non_json_string( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_post_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": "foo", } @@ -131,13 +115,11 @@ def test_post_bad_request_non_json_string( assert resp.status_code == 400 -def test_post_access_denied( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_post_access_denied(client, chart_id: int, datasource: SqlaTable): login(client, "gamma") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } @@ -145,13 +127,11 @@ def test_post_access_denied( assert resp.status_code == 404 -def test_post_same_key_for_same_context( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_post_same_key_for_same_context(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -165,12 +145,12 @@ def test_post_same_key_for_same_context( def test_post_different_key_for_different_context( - client, chart_id: int, datasource_id: int, datasource_type: str + client, chart_id: int, datasource: SqlaTable ): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -178,8 +158,8 @@ def test_post_different_key_for_different_context( data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "form_data": json.dumps({"test": "initial value"}), } resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) @@ -188,13 +168,11 @@ def test_post_different_key_for_different_context( assert first_key != second_key -def test_post_same_key_for_same_tab_id( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_post_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": json.dumps({"test": "initial value"}), } @@ -208,12 +186,12 @@ def test_post_same_key_for_same_tab_id( def test_post_different_key_for_different_tab_id( - client, chart_id: int, datasource_id: int, datasource_type: str + client, chart_id: int, datasource: SqlaTable ): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": json.dumps({"test": "initial value"}), } @@ -226,13 +204,11 @@ def test_post_different_key_for_different_tab_id( assert first_key != second_key -def test_post_different_key_for_no_tab_id( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_post_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } @@ -245,11 +221,11 @@ def test_post_different_key_for_no_tab_id( assert first_key != second_key -def test_put(client, chart_id: int, datasource_id: int, datasource_type: str): +def test_put(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -257,13 +233,11 @@ def test_put(client, chart_id: int, datasource_id: int, datasource_type: str): assert resp.status_code == 200 -def test_put_same_key_for_same_tab_id( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_put_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -277,12 +251,12 @@ def test_put_same_key_for_same_tab_id( def test_put_different_key_for_different_tab_id( - client, chart_id: int, datasource_id: int, datasource_type: str + client, chart_id: int, datasource: SqlaTable ): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -295,13 +269,11 @@ def test_put_different_key_for_different_tab_id( assert first_key != second_key -def test_put_different_key_for_no_tab_id( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_put_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -314,13 +286,11 @@ def test_put_different_key_for_no_tab_id( assert first_key != second_key -def test_put_bad_request( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_put_bad_request(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": 1234, } @@ -328,13 +298,11 @@ def test_put_bad_request( assert resp.status_code == 400 -def test_put_bad_request_non_string( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_put_bad_request_non_string(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": 1234, } @@ -342,13 +310,11 @@ def test_put_bad_request_non_string( assert resp.status_code == 400 -def test_put_bad_request_non_json_string( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_put_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": "foo", } @@ -356,13 +322,11 @@ def test_put_bad_request_non_json_string( assert resp.status_code == 400 -def test_put_access_denied( - client, chart_id: int, datasource_id: int, datasource_type: str -): +def test_put_access_denied(client, chart_id: int, datasource: SqlaTable): login(client, "gamma") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -370,11 +334,11 @@ def test_put_access_denied( assert resp.status_code == 404 -def test_put_not_owner(client, chart_id: int, datasource_id: int, datasource_type: str): +def test_put_not_owner(client, chart_id: int, datasource: SqlaTable): login(client, "gamma") payload = { - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -422,15 +386,13 @@ def test_delete_access_denied(client): assert resp.status_code == 404 -def test_delete_not_owner( - client, chart_id: int, datasource_id: int, datasource_type: str, admin_id: int -): +def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id: int): another_key = "another_key" another_owner = admin_id + 1 entry: TemporaryExploreState = { "owner": another_owner, - "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } diff --git a/tests/integration_tests/explore/form_data/commands_tests.py b/tests/integration_tests/explore/form_data/commands_tests.py new file mode 100644 index 0000000000000..c2c4c767f0c44 --- /dev/null +++ b/tests/integration_tests/explore/form_data/commands_tests.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import patch + +import pytest + +from superset import app, db, security_manager +from superset.connectors.sqla.models import SqlaTable +from superset.explore.form_data.commands.create import CreateFormDataCommand +from superset.explore.form_data.commands.parameters import CommandParameters +from superset.models.slice import Slice +from superset.utils.core import DatasourceType, get_example_default_schema +from superset.utils.database import get_example_database +from tests.integration_tests.base_tests import SupersetTestCase + + +class TestCreateFormDataCommand(SupersetTestCase): + @pytest.fixture() + def create_dataset(self): + with self.create_app().app_context(): + dataset = SqlaTable( + table_name="dummy_sql_table", + database=get_example_database(), + schema=get_example_default_schema(), + sql="select 123 as intcol, 'abc' as strcol", + ) + session = db.session + session.add(dataset) + session.commit() + + yield dataset + + # rollback + session.delete(dataset) + session.commit() + + @pytest.fixture() + def create_slice(self): + with self.create_app().app_context(): + session = db.session + dataset = ( + session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = Slice( + datasource_id=dataset.id, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_perm_table", + slice_name="slice_name", + ) + + session.add(slice) + session.commit() + + yield slice + + # rollback + session.delete(slice) + session.commit() + + @patch("superset.security.manager.g") + @pytest.mark.usefixtures("create_dataset", "create_slice") + def test_create_form_data_command(self, mock_g): + mock_g.user = security_manager.find_user("admin") + + dataset = ( + db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = db.session.query(Slice).filter_by(slice_name="slice_name").first() + args = CommandParameters( + actor=mock_g.user, + datasource_id=dataset.id, + datasource_type=DatasourceType.TABLE, + chart_id=slice.id, + tab_id=1, + form_data="", + ) + command = CreateFormDataCommand(args) + + assert isinstance(command.run(), str) diff --git a/tests/integration_tests/utils/cache_manager_tests.py b/tests/integration_tests/utils/cache_manager_tests.py index 825f728eeeda0..c5d4b390f9c90 100644 --- a/tests/integration_tests/utils/cache_manager_tests.py +++ b/tests/integration_tests/utils/cache_manager_tests.py @@ -17,34 +17,33 @@ import pytest from superset.extensions import cache_manager -from superset.utils.core import DatasourceType - - -def test_get_set_explore_form_data_cache(): - key = "12345" - data = {"foo": "bar", "datasource_type": "query"} - cache_manager.explore_form_data_cache.set(key, data) - assert cache_manager.explore_form_data_cache.get(key) == data - - -def test_get_same_context_twice(): - key = "12345" - data = {"foo": "bar", "datasource_type": "query"} - cache_manager.explore_form_data_cache.set(key, data) - assert cache_manager.explore_form_data_cache.get(key) == data - assert cache_manager.explore_form_data_cache.get(key) == data - - -def test_get_set_explore_form_data_cache_no_datasource_type(): - key = "12345" - data = {"foo": "bar"} - cache_manager.explore_form_data_cache.set(key, data) - # datasource_type should be added because it is not present - assert cache_manager.explore_form_data_cache.get(key) == { - "datasource_type": DatasourceType.TABLE, - **data, - } - - -def test_get_explore_form_data_cache_invalid_key(): - assert cache_manager.explore_form_data_cache.get("foo") == None +from superset.utils.core import backend, DatasourceType +from tests.integration_tests.base_tests import SupersetTestCase + + +class UtilsCacheManagerTests(SupersetTestCase): + def test_get_set_explore_form_data_cache(self): + key = "12345" + data = {"foo": "bar", "datasource_type": "query"} + cache_manager.explore_form_data_cache.set(key, data) + assert cache_manager.explore_form_data_cache.get(key) == data + + def test_get_same_context_twice(self): + key = "12345" + data = {"foo": "bar", "datasource_type": "query"} + cache_manager.explore_form_data_cache.set(key, data) + assert cache_manager.explore_form_data_cache.get(key) == data + assert cache_manager.explore_form_data_cache.get(key) == data + + def test_get_set_explore_form_data_cache_no_datasource_type(self): + key = "12345" + data = {"foo": "bar"} + cache_manager.explore_form_data_cache.set(key, data) + # datasource_type should be added because it is not present + assert cache_manager.explore_form_data_cache.get(key) == { + "datasource_type": DatasourceType.TABLE, + **data, + } + + def test_get_explore_form_data_cache_invalid_key(self): + assert cache_manager.explore_form_data_cache.get("foo") == None diff --git a/tests/unit_tests/dao/datasource_test.py b/tests/unit_tests/dao/datasource_test.py index dd0db265e7a02..a15684d71e699 100644 --- a/tests/unit_tests/dao/datasource_test.py +++ b/tests/unit_tests/dao/datasource_test.py @@ -106,7 +106,7 @@ def test_get_datasource_sqlatable( from superset.dao.datasource.dao import DatasourceDAO result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.SQLATABLE, + datasource_type=DatasourceType.TABLE, datasource_id=1, session=session_with_data, ) @@ -151,7 +151,9 @@ def test_get_datasource_sl_table(app_context: None, session_with_data: Session) # todo(hugh): This will break once we remove the dual write # update the datsource_id=1 and this will pass again result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.TABLE, datasource_id=2, session=session_with_data + datasource_type=DatasourceType.SLTABLE, + datasource_id=2, + session=session_with_data, ) assert result.id == 2 diff --git a/tests/unit_tests/explore/utils_test.py b/tests/unit_tests/explore/utils_test.py index efcad7bd7388c..9ef92872177ee 100644 --- a/tests/unit_tests/explore/utils_test.py +++ b/tests/unit_tests/explore/utils_test.py @@ -23,14 +23,21 @@ ChartAccessDeniedError, ChartNotFoundError, ) -from superset.commands.exceptions import DatasourceNotFoundValidationError +from superset.commands.exceptions import ( + DatasourceNotFoundValidationError, + DatasourceTypeInvalidError, + OwnersNotFoundValidationError, + QueryNotFoundValidationError, +) from superset.datasets.commands.exceptions import ( DatasetAccessDeniedError, DatasetNotFoundError, ) +from superset.exceptions import SupersetSecurityException from superset.utils.core import DatasourceType dataset_find_by_id = "superset.datasets.dao.DatasetDAO.find_by_id" +query_find_by_id = "superset.queries.dao.QueryDAO.find_by_id" chart_find_by_id = "superset.charts.dao.ChartDAO.find_by_id" is_user_admin = "superset.explore.utils.is_user_admin" is_owner = "superset.explore.utils.is_owner" @@ -38,6 +45,10 @@ "superset.security.SupersetSecurityManager.can_access_datasource" ) can_access = "superset.security.SupersetSecurityManager.can_access" +raise_for_access = "superset.security.SupersetSecurityManager.raise_for_access" +query_datasources_by_name = ( + "superset.connectors.sqla.models.SqlaTable.query_datasources_by_name" +) def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None: @@ -67,6 +78,21 @@ def test_unsaved_chart_unknown_dataset_id( ) +def test_unsaved_chart_unknown_query_id( + mocker: MockFixture, app_context: AppContext +) -> None: + from superset.explore.utils import check_access as check_chart_access + + with raises(QueryNotFoundValidationError): + mocker.patch(query_find_by_id, return_value=None) + check_chart_access( + datasource_id=1, + chart_id=0, + actor=User(), + datasource_type=DatasourceType.QUERY, + ) + + def test_unsaved_chart_unauthorized_dataset( mocker: MockFixture, app_context: AppContext ) -> None: @@ -100,22 +126,6 @@ def test_unsaved_chart_authorized_dataset( ) -def test_saved_chart_no_datasource_type( - mocker: MockFixture, app_context: AppContext - - -) -> None: - from superset.connectors.sqla.models import SqlaTable - from superset.explore.utils import check_access as check_chart_access - - with raises(DatasourceNotFoundValidationError): - mocker.patch(dataset_find_by_id, return_value=SqlaTable()) - mocker.patch(can_access_datasource, return_value=True) - check_chart_access( - datasource_id=1, chart_id=1, actor=User(), datasource_type=None - ) - - def test_saved_chart_unknown_chart_id( mocker: MockFixture, app_context: AppContext ) -> None: @@ -223,3 +233,60 @@ def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> actor=User(), datasource_type=DatasourceType.TABLE, ) + + +def test_dataset_has_access(mocker: MockFixture, app_context: AppContext) -> None: + from superset.connectors.sqla.models import SqlaTable + from superset.explore.utils import check_datasource_access + + mocker.patch(dataset_find_by_id, return_value=SqlaTable()) + mocker.patch(can_access_datasource, return_value=True) + mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_owner, return_value=False) + mocker.patch(can_access, return_value=True) + assert ( + check_datasource_access( + datasource_id=1, + datasource_type=DatasourceType.TABLE, + ) + == True + ) + + +def test_query_has_access(mocker: MockFixture, app_context: AppContext) -> None: + from superset.explore.utils import check_datasource_access + from superset.models.sql_lab import Query + + mocker.patch(query_find_by_id, return_value=Query()) + mocker.patch(raise_for_access, return_value=True) + mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_owner, return_value=False) + mocker.patch(can_access, return_value=True) + assert ( + check_datasource_access( + datasource_id=1, + datasource_type=DatasourceType.QUERY, + ) + == True + ) + + +def test_query_no_access(mocker: MockFixture, app_context: AppContext) -> None: + from superset.connectors.sqla.models import SqlaTable + from superset.explore.utils import check_datasource_access + from superset.models.core import Database + from superset.models.sql_lab import Query + + with raises(SupersetSecurityException): + mocker.patch( + query_find_by_id, + return_value=Query(database=Database(), sql="select * from foo"), + ) + mocker.patch(query_datasources_by_name, return_value=[SqlaTable()]) + mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_owner, return_value=False) + mocker.patch(can_access, return_value=False) + check_datasource_access( + datasource_id=1, + datasource_type=DatasourceType.QUERY, + )