From c17eef2a620aa1a78eb6bc6b21b828c3673fedeb Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Mon, 9 May 2022 09:23:19 -0700 Subject: [PATCH] add datasource_type to delete command --- .../ExploreViewContainer.test.tsx | 4 +- superset/explore/form_data/commands/delete.py | 3 +- superset/explore/form_data/commands/get.py | 1 - superset/explore/permalink/commands/get.py | 1 - superset/utils/cache_manager.py | 36 +++++++++---- superset/views/core.py | 40 +++++---------- .../explore/form_data/api_tests.py | 1 - .../explore/permalink/api_tests.py | 4 +- .../utils/cache_manager_tests.py | 50 +++++++++++++++++++ 9 files changed, 93 insertions(+), 47 deletions(-) create mode 100644 tests/integration_tests/utils/cache_manager_tests.py diff --git a/superset-frontend/src/explore/components/ExploreViewContainer/ExploreViewContainer.test.tsx b/superset-frontend/src/explore/components/ExploreViewContainer/ExploreViewContainer.test.tsx index 7743997a35529..2260346968dd3 100644 --- a/superset-frontend/src/explore/components/ExploreViewContainer/ExploreViewContainer.test.tsx +++ b/superset-frontend/src/explore/components/ExploreViewContainer/ExploreViewContainer.test.tsx @@ -92,7 +92,7 @@ test('generates a new form_data param when none is available', async () => { expect(replaceState).toHaveBeenCalledWith( expect.anything(), undefined, - expect.stringMatching('dataset_id'), + expect.stringMatching('datasource_id'), ); replaceState.mockRestore(); }); @@ -109,7 +109,7 @@ test('generates a different form_data param when one is provided and is mounting expect(replaceState).toHaveBeenCalledWith( expect.anything(), undefined, - expect.stringMatching('dataset_id'), + expect.stringMatching('datasource_id'), ); replaceState.mockRestore(); }); diff --git a/superset/explore/form_data/commands/delete.py b/superset/explore/form_data/commands/delete.py index 1065c7234c483..5b6e5d7251c22 100644 --- a/superset/explore/form_data/commands/delete.py +++ b/superset/explore/form_data/commands/delete.py @@ -31,7 +31,6 @@ TemporaryCacheDeleteFailedError, ) from superset.temporary_cache.utils import cache_key -from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -56,7 +55,7 @@ def run(self) -> bool: raise TemporaryCacheAccessDeniedError() tab_id = self._cmd_params.tab_id contextual_key = cache_key( - session.get("_id"), tab_id, datasource_id, chart_id + session.get("_id"), tab_id, datasource_id, chart_id, datasource_type ) cache_manager.explore_form_data_cache.delete(contextual_key) return cache_manager.explore_form_data_cache.delete(key) diff --git a/superset/explore/form_data/commands/get.py b/superset/explore/form_data/commands/get.py index 6f8c451a33118..2d9a40773deb3 100644 --- a/superset/explore/form_data/commands/get.py +++ b/superset/explore/form_data/commands/get.py @@ -27,7 +27,6 @@ from superset.explore.utils import check_chart_access from superset.extensions import cache_manager from superset.temporary_cache.commands.exceptions import TemporaryCacheGetFailedError -from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) diff --git a/superset/explore/permalink/commands/get.py b/superset/explore/permalink/commands/get.py index f599c49707a58..548f2980965c4 100644 --- a/superset/explore/permalink/commands/get.py +++ b/superset/explore/permalink/commands/get.py @@ -28,7 +28,6 @@ from superset.key_value.commands.get import GetKeyValueCommand from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError from superset.key_value.utils import decode_permalink_id -from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) diff --git a/superset/utils/cache_manager.py b/superset/utils/cache_manager.py index cfaee5526b9d8..6399a94ab8942 100644 --- a/superset/utils/cache_manager.py +++ b/superset/utils/cache_manager.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. import logging +from typing import Any, Optional, Union from flask import Flask from flask_caching import Cache +from markupsafe import Markup from superset.utils.core import DatasourceType @@ -26,6 +28,27 @@ CACHE_IMPORT_PATH = "superset.extensions.metastore_cache.SupersetMetastoreCache" +class ExploreFormDataCache(Cache): + def get(self, *args: Any, **kwargs: Any) -> Optional[Union[str, Markup]]: + cache = self.cache.get(*args, **kwargs) + + if not cache: + return None + + # rename keys for existing cache based on new TemporaryExploreState model + if isinstance(cache, dict): + cache = { + ("datasource_id" if key == "dataset_id" else key): value + for (key, value) in cache.items() + } + # add default datasource_type if it doesn't exist + # temporarily defaulting to table until sqlatables are deprecated + if "datasource_type" not in cache: + cache["datasource_type"] = DatasourceType.TABLE + + return cache + + class CacheManager: def __init__(self) -> None: super().__init__() @@ -34,7 +57,7 @@ def __init__(self) -> None: self._data_cache = Cache() self._thumbnail_cache = Cache() self._filter_state_cache = Cache() - self._explore_form_data_cache = Cache() + self._explore_form_data_cache = ExploreFormDataCache() @staticmethod def _init_cache( @@ -95,13 +118,4 @@ def filter_state_cache(self) -> Cache: @property def explore_form_data_cache(self) -> Cache: - # rename keys for existing cache based on new TemporaryExploreState model - cache = { - ("datasource_id" if key == "dataset_id" else key): value - for (key, value) in self._explore_form_data_cache.items() - } - # add default datasource_type if it doesn't exist - # temporarily defaulting to table until sqlatables are deprecated - if "datasource_type" not in cache: - cache["datasource_type"] = DatasourceType.TABLE - return cache + return self._explore_form_data_cache diff --git a/superset/views/core.py b/superset/views/core.py index 9ced8b485ff3e..8d82128a58db7 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -320,9 +320,7 @@ def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-u def clean_fulfilled_requests(session: Session) -> None: for dar in session.query(DAR).all(): datasource = ConnectorRegistry.get_datasource( - dar.datasource_type, - dar.datasource_id, - session, + dar.datasource_type, dar.datasource_id, session ) if not datasource or security_manager.can_access_datasource(datasource): # Dataset does not exist anymore @@ -462,7 +460,7 @@ def get_raw_results(self, viz_obj: BaseViz) -> FlaskResponse: "data": payload["df"].to_dict("records"), "colnames": payload.get("colnames"), "coltypes": payload.get("coltypes"), - }, + } ) def get_samples(self, viz_obj: BaseViz) -> FlaskResponse: @@ -630,8 +628,7 @@ def explore_json( and not security_manager.can_access("can_csv", "Superset") ): return json_error_response( - _("You don't have the rights to ") + _("download as csv"), - status=403, + _("You don't have the rights to ") + _("download as csv"), status=403 ) form_data = get_form_data()[0] @@ -947,9 +944,7 @@ def filter( # pylint: disable=no-self-use """ # TODO: Cache endpoint by user, datasource and column datasource = ConnectorRegistry.get_datasource( - datasource_type, - datasource_id, - db.session, + datasource_type, datasource_id, db.session ) if not datasource: return json_error_response(DATASOURCE_MISSING_ERR) @@ -1418,10 +1413,7 @@ def get_user_activity_access_error(user_id: int) -> Optional[FlaskResponse]: try: security_manager.raise_for_user_activity_access(user_id) except SupersetSecurityException as ex: - return json_error_response( - ex.message, - status=403, - ) + return json_error_response(ex.message, status=403) return None @api @@ -1444,8 +1436,7 @@ def recent_activity( # pylint: disable=too-many-locals has_subject_title = or_( and_( - Dashboard.dashboard_title is not None, - Dashboard.dashboard_title != "", + Dashboard.dashboard_title is not None, Dashboard.dashboard_title != "" ), and_(Slice.slice_name is not None, Slice.slice_name != ""), ) @@ -1479,10 +1470,7 @@ def recent_activity( # pylint: disable=too-many-locals Slice.slice_name, ) .outerjoin(Dashboard, Dashboard.id == subqry.c.dashboard_id) - .outerjoin( - Slice, - Slice.id == subqry.c.slice_id, - ) + .outerjoin(Slice, Slice.id == subqry.c.slice_id) .filter(has_subject_title) .order_by(subqry.c.dttm.desc()) .limit(limit) @@ -1973,8 +1961,7 @@ def dashboard( @has_access @expose("/dashboard/p//", methods=["GET"]) def dashboard_permalink( # pylint: disable=no-self-use - self, - key: str, + self, key: str ) -> FlaskResponse: try: value = GetDashboardPermalinkCommand(g.user, key).run() @@ -2101,7 +2088,7 @@ def sqllab_viz(self) -> FlaskResponse: # pylint: disable=no-self-use @has_access @expose("/extra_table_metadata////") @event_logger.log_this - def extra_table_metadata( # pylint: disable=no-self-use + def extra_table_metadata( self, database_id: int, table_name: str, schema: str ) -> FlaskResponse: logger.warning( @@ -2294,7 +2281,7 @@ def stop_query(self) -> FlaskResponse: QueryStatus.TIMED_OUT, ]: logger.warning( - "Query with client_id could not be stopped: query already complete", + "Query with client_id could not be stopped: query already complete" ) return self.json_response("OK") @@ -2448,8 +2435,7 @@ def _set_http_status_into_Sql_lab_exception(ex: SqlLabException) -> None: ex.status = 403 def _create_response_from_execution_context( # pylint: disable=invalid-name, no-self-use - self, - command_result: CommandResult, + self, command_result: CommandResult ) -> FlaskResponse: status_code = 200 @@ -2540,9 +2526,7 @@ def fetch_datasource_metadata(self) -> FlaskResponse: # pylint: disable=no-self datasource_id, datasource_type = request.args["datasourceKey"].split("__") datasource = ConnectorRegistry.get_datasource( - datasource_type, - datasource_id, - db.session, + datasource_type, datasource_id, db.session ) # Check if datasource exists if not datasource: diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py index 51ce679947318..965920b256d60 100644 --- a/tests/integration_tests/explore/form_data/api_tests.py +++ b/tests/integration_tests/explore/form_data/api_tests.py @@ -100,7 +100,6 @@ def test_post(client, chart_id: int, datasource_id: int, datasource_type: str): "form_data": INITIAL_FORM_DATA, } resp = client.post("api/v1/explore/form_data", json=payload) - print(resp.data) assert resp.status_code == 201 diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index a44bc70a7b49a..b5228ab301b24 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -27,6 +27,7 @@ from superset.key_value.types import KeyValueResource from superset.key_value.utils import decode_permalink_id, encode_permalink_key from superset.models.slice import Slice +from superset.utils.core import DatasourceType from tests.integration_tests.base_tests import login from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.world_bank_dashboard import ( @@ -97,7 +98,8 @@ def test_get_missing_chart(client, chart, permalink_salt: str) -> None: value=pickle.dumps( { "chartId": chart_id, - "datasetId": chart.datasource.id, + "datasourceId": chart.datasource.id, + "datasourceType": DatasourceType.TABLE, "formData": { "slice_id": chart_id, "datasource": f"{chart.datasource.id}__{chart.datasource.type}", diff --git a/tests/integration_tests/utils/cache_manager_tests.py b/tests/integration_tests/utils/cache_manager_tests.py new file mode 100644 index 0000000000000..825f728eeeda0 --- /dev/null +++ b/tests/integration_tests/utils/cache_manager_tests.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from superset.extensions import cache_manager +from superset.utils.core import DatasourceType + + +def test_get_set_explore_form_data_cache(): + key = "12345" + data = {"foo": "bar", "datasource_type": "query"} + cache_manager.explore_form_data_cache.set(key, data) + assert cache_manager.explore_form_data_cache.get(key) == data + + +def test_get_same_context_twice(): + key = "12345" + data = {"foo": "bar", "datasource_type": "query"} + cache_manager.explore_form_data_cache.set(key, data) + assert cache_manager.explore_form_data_cache.get(key) == data + assert cache_manager.explore_form_data_cache.get(key) == data + + +def test_get_set_explore_form_data_cache_no_datasource_type(): + key = "12345" + data = {"foo": "bar"} + cache_manager.explore_form_data_cache.set(key, data) + # datasource_type should be added because it is not present + assert cache_manager.explore_form_data_cache.get(key) == { + "datasource_type": DatasourceType.TABLE, + **data, + } + + +def test_get_explore_form_data_cache_invalid_key(): + assert cache_manager.explore_form_data_cache.get("foo") == None