From 1530c3479237e0540fe5dda7bb14fbfc8291c88c Mon Sep 17 00:00:00 2001
From: Yongjie Zhao
Date: Wed, 1 Jun 2022 18:24:01 +0800
Subject: [PATCH] fix: failed samples should throw exception (#20228)

---
 .../DataTablesPane/components/SamplesPane.tsx | 11 ++--
 superset/common/utils/query_cache_manager.py  | 15 ++++++
 superset/datasets/api.py                      |  8 +--
 superset/datasets/commands/samples.py         | 11 +++-
 tests/integration_tests/datasets/api_tests.py | 52 +++++++++++++++++++
 5 files changed, 85 insertions(+), 12 deletions(-)

diff --git a/superset-frontend/src/explore/components/DataTablesPane/components/SamplesPane.tsx b/superset-frontend/src/explore/components/DataTablesPane/components/SamplesPane.tsx
index de3ef919f6b6c..be30e4e827aa8 100644
--- a/superset-frontend/src/explore/components/DataTablesPane/components/SamplesPane.tsx
+++ b/superset-frontend/src/explore/components/DataTablesPane/components/SamplesPane.tsx
@@ -17,7 +17,7 @@
  * under the License.
  */
 import React, { useState, useEffect, useMemo } from 'react';
-import { GenericDataType, styled, t } from '@superset-ui/core';
+import { ensureIsArray, GenericDataType, styled, t } from '@superset-ui/core';
 import Loading from 'src/components/Loading';
 import { EmptyStateMedium } from 'src/components/EmptyState';
 import TableView, { EmptyWrapperType } from 'src/components/TableView';
@@ -63,9 +63,9 @@ export const SamplesPane = ({
     setIsLoading(true);
     getDatasetSamples(datasource.id, queryForce)
       .then(response => {
-        setData(response.data);
-        setColnames(response.colnames);
-        setColtypes(response.coltypes);
+        setData(ensureIsArray(response.data));
+        setColnames(ensureIsArray(response.colnames));
+        setColtypes(ensureIsArray(response.coltypes));
         setResponseError('');
         cache.add(datasource);
         if (queryForce && actions) {
@@ -73,6 +73,9 @@
         }
       })
       .catch(error => {
+        setData([]);
+        setColnames([]);
+        setColtypes([]);
         setResponseError(`${error.name}: ${error.message}`);
       })
       .finally(() => {
diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py
index 92fb3561234f4..76aa5ddef32e3 100644
--- a/superset/common/utils/query_cache_manager.py
+++ b/superset/common/utils/query_cache_manager.py
@@ -187,3 +187,18 @@ def set(
         """
         if key:
             set_and_log_cache(_cache[region], key, value, timeout, datasource_uid)
+
+    @staticmethod
+    def delete(
+        key: Optional[str],
+        region: CacheRegion = CacheRegion.DEFAULT,
+    ) -> None:
+        if key:
+            _cache[region].delete(key)
+
+    @staticmethod
+    def has(
+        key: Optional[str],
+        region: CacheRegion = CacheRegion.DEFAULT,
+    ) -> bool:
+        return bool(_cache[region].get(key)) if key else False
diff --git a/superset/datasets/api.py b/superset/datasets/api.py
index 313634766c001..17e99959e9675 100644
--- a/superset/datasets/api.py
+++ b/superset/datasets/api.py
@@ -825,10 +825,4 @@ def samples(self, pk: int) -> Response:
         except DatasetForbiddenError:
             return self.response_403()
         except DatasetSamplesFailedError as ex:
-            logger.error(
-                "Error get dataset samples %s: %s",
-                self.__class__.__name__,
-                str(ex),
-                exc_info=True,
-            )
-            return self.response_422(message=str(ex))
+            return self.response_400(message=str(ex))
diff --git a/superset/datasets/commands/samples.py b/superset/datasets/commands/samples.py
index 4be2c6e90f850..79ac729be0801 100644
--- a/superset/datasets/commands/samples.py
+++ b/superset/datasets/commands/samples.py
@@ -22,7 +22,9 @@
 from superset.commands.base import BaseCommand
 from superset.common.chart_data import ChartDataResultType
 from superset.common.query_context_factory import QueryContextFactory
+from superset.common.utils.query_cache_manager import QueryCacheManager
 from superset.connectors.sqla.models import SqlaTable
+from superset.constants import CacheRegion
 from superset.datasets.commands.exceptions import (
     DatasetForbiddenError,
     DatasetNotFoundError,
@@ -30,6 +32,7 @@
 )
 from superset.datasets.dao import DatasetDAO
 from superset.exceptions import SupersetSecurityException
+from superset.utils.core import QueryStatus
 from superset.views.base import check_ownership
 
 logger = logging.getLogger(__name__)
@@ -58,7 +61,13 @@ def run(self) -> Dict[str, Any]:
         )
         results = qc_instance.get_payload()
         try:
-            return results["queries"][0]
+            sample_data = results["queries"][0]
+            error_msg = sample_data.get("error")
+            if sample_data.get("status") == QueryStatus.FAILED and error_msg:
+                cache_key = sample_data.get("cache_key")
+                QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
+                raise DatasetSamplesFailedError(error_msg)
+            return sample_data
         except (IndexError, KeyError) as exc:
             raise DatasetSamplesFailedError from exc
 
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index b426ddad71523..28bb617c17c19 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -27,7 +27,9 @@
 import yaml
 from sqlalchemy.sql import func
 
+from superset.common.utils.query_cache_manager import QueryCacheManager
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+from superset.constants import CacheRegion
 from superset.dao.exceptions import (
     DAOCreateFailedError,
     DAODeleteFailedError,
@@ -1883,6 +1885,8 @@ def test_get_dataset_samples(self):
         assert rv.status_code == 200
         assert "result" in rv_data
         assert rv_data["result"]["cached_dttm"] is not None
+        cache_key1 = rv_data["result"]["cache_key"]
+        assert QueryCacheManager.has(cache_key1, region=CacheRegion.DATA)
 
         # 2. should through cache
         uri2 = f"api/v1/dataset/{dataset.id}/samples?force=true"
@@ -1892,6 +1896,8 @@ def test_get_dataset_samples(self):
         rv2 = self.client.get(uri2)
         rv_data2 = json.loads(rv2.data)
         assert rv_data2["result"]["cached_dttm"] is None
+        cache_key2 = rv_data2["result"]["cache_key"]
+        assert QueryCacheManager.has(cache_key2, region=CacheRegion.DATA)
 
         # 3. data precision
         assert "colnames" in rv_data2["result"]
@@ -1903,3 +1909,49 @@ def test_get_dataset_samples(self):
             f' limit {self.app.config["SAMPLES_ROW_LIMIT"]}'
         ).to_dict(orient="records")
         assert eager_samples == rv_data2["result"]["data"]
+
+    @pytest.mark.usefixtures("create_datasets")
+    def test_get_dataset_samples_with_failed_cc(self):
+        dataset = self.get_fixture_datasets()[0]
+
+        self.login(username="admin")
+        failed_column = TableColumn(
+            column_name="DUMMY CC",
+            type="VARCHAR(255)",
+            table=dataset,
+            expression="INCORRECT SQL",
+        )
+        uri = f"api/v1/dataset/{dataset.id}/samples"
+        dataset.columns.append(failed_column)
+        rv = self.client.get(uri)
+        assert rv.status_code == 400
+        rv_data = json.loads(rv.data)
+        assert "message" in rv_data
+        if dataset.database.db_engine_spec.engine_name == "PostgreSQL":
+            assert "INCORRECT SQL" in rv_data.get("message")
+
+    def test_get_dataset_samples_on_virtual_dataset(self):
+        virtual_dataset = SqlaTable(
+            table_name="virtual_dataset",
+            sql=("SELECT 'foo' as foo, 'bar' as bar"),
+            database=get_example_database(),
+        )
+        TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset)
+        TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset)
+        SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset)
+
+        self.login(username="admin")
+        uri = f"api/v1/dataset/{virtual_dataset.id}/samples"
+        rv = self.client.get(uri)
+        assert rv.status_code == 200
+        rv_data = json.loads(rv.data)
+        cache_key = rv_data["result"]["cache_key"]
+        assert QueryCacheManager.has(cache_key, region=CacheRegion.DATA)
+
+        # remove original column in dataset
+        virtual_dataset.sql = "SELECT 'foo' as foo"
+        rv = self.client.get(uri)
+        assert rv.status_code == 400
+
+        db.session.delete(virtual_dataset)
+        db.session.commit()
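
A minimal sketch of how the new QueryCacheManager helpers introduced above behave, assuming a samples payload dict shaped like the one handled in SamplesDatasetCommand.run; the evict_failed_sample helper is hypothetical and not part of the patch:

from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.constants import CacheRegion


def evict_failed_sample(sample_data: dict) -> None:
    # Mirrors the cleanup added in SamplesDatasetCommand.run: a FAILED sample
    # query must not leave its payload in the data cache, or a later non-force
    # request would keep serving the failure from cache.
    cache_key = sample_data.get("cache_key")
    # delete() is a no-op for a falsy key, so no guard is needed here.
    QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
    # has() returns False for both missing entries and falsy keys.
    assert not QueryCacheManager.has(cache_key, region=CacheRegion.DATA)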