",
+ }
+
+ ImportAssetsCommand = mocker.patch("superset.importexport.api.ImportAssetsCommand")
+
+ root = Path("assets_export")
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ for path, contents in mocked_contents.items():
+ with bundle.open(str(root / path), "w") as fp:
+ fp.write(contents.encode())
+ buf.seek(0)
+
+ form_data = {
+ "bundle": (buf, "assets_export.zip"),
+ "passwords": json.dumps(
+ {"assets_export/databases/imported_database.yaml": "SECRET"}
+ ),
+ }
+ response = client.post(
+ "/api/v1/assets/import/", data=form_data, content_type="multipart/form-data"
+ )
+ assert response.status_code == 200
+ assert response.json == {"message": "OK"}
+
+ passwords = {"assets_export/databases/imported_database.yaml": "SECRET"}
+ ImportAssetsCommand.assert_called_with(mocked_contents, passwords=passwords)
+
+
+def test_import_assets_not_zip(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test error message when the upload is not a ZIP file.
+ """
+ buf = BytesIO(b"definitely_not_a_zip_file")
+ form_data = {
+ "bundle": (buf, "broken.txt"),
+ }
+ response = client.post(
+ "/api/v1/assets/import/", data=form_data, content_type="multipart/form-data"
+ )
+ assert response.status_code == 422
+ assert response.json == {
+ "errors": [
+ {
+ "message": "Not a ZIP file",
+ "error_type": "GENERIC_COMMAND_ERROR",
+ "level": "warning",
+ "extra": {
+ "issue_codes": [
+ {
+ "code": 1010,
+ "message": (
+ "Issue 1010 - Superset encountered an error while "
+ "running a command."
+ ),
+ }
+ ]
+ },
+ }
+ ]
+ }
+
+
+def test_import_assets_no_form_data(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test error message when the upload has no form data.
+ """
+ mocker.patch.object(security_manager, "has_access", return_value=True)
+
+ response = client.post("/api/v1/assets/import/", data="some_content")
+ assert response.status_code == 400
+ assert response.json == {
+ "errors": [
+ {
+ "message": "Request MIME type is not 'multipart/form-data'",
+ "error_type": "INVALID_PAYLOAD_FORMAT_ERROR",
+ "level": "error",
+ "extra": {
+ "issue_codes": [
+ {
+ "code": 1019,
+ "message": (
+ "Issue 1019 - The submitted payload has the incorrect "
+ "format."
+ ),
+ }
+ ]
+ },
+ }
+ ]
+ }
+
+
+def test_import_assets_incorrect_form_data(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test error message when the upload form data has the wrong key.
+ """
+ buf = BytesIO(b"definitely_not_a_zip_file")
+ form_data = {
+ "wrong": (buf, "broken.txt"),
+ }
+ response = client.post(
+ "/api/v1/assets/import/", data=form_data, content_type="multipart/form-data"
+ )
+ assert response.status_code == 400
+ assert response.json == {"message": "Arguments are not correct"}
+
+
+def test_import_assets_no_contents(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test error message when the ZIP bundle has no contents.
+ """
+ mocked_contents = {
+ "README.txt": "Something is wrong",
+ }
+
+ root = Path("assets_export")
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ for path, contents in mocked_contents.items():
+ with bundle.open(str(root / path), "w") as fp:
+ fp.write(contents.encode())
+ buf.seek(0)
+
+ form_data = {
+ "bundle": (buf, "assets_export.zip"),
+ "passwords": json.dumps(
+ {"assets_export/databases/imported_database.yaml": "SECRET"}
+ ),
+ }
+ response = client.post(
+ "/api/v1/assets/import/", data=form_data, content_type="multipart/form-data"
+ )
+ assert response.status_code == 400
+ assert response.json == {
+ "errors": [
+ {
+ "message": "No valid import files were found",
+ "error_type": "GENERIC_COMMAND_ERROR",
+ "level": "warning",
+ "extra": {
+ "issue_codes": [
+ {
+ "code": 1010,
+ "message": (
+ "Issue 1010 - Superset encountered an error while "
+ "running a command."
+ ),
+ }
+ ]
+ },
+ }
+ ]
+ }
diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py
new file mode 100644
index 0000000000000..13b3ae9e9c948
--- /dev/null
+++ b/tests/unit_tests/jinja_context_test.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+
+import json
+
+import pytest
+from pytest_mock import MockFixture
+
+from superset.datasets.commands.exceptions import DatasetNotFoundError
+from superset.jinja_context import dataset_macro, where_in
+
+
+def test_where_in() -> None:
+ """
+ Test the ``where_in`` Jinja2 filter.
+ """
+ assert where_in([1, "b", 3]) == "(1, 'b', 3)"
+ assert where_in([1, "b", 3], '"') == '(1, "b", 3)'
+ assert where_in(["O'Malley's"]) == "('O''Malley''s')"
+
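+# Usage sketch (an assumption based on the docstring above, not asserted by
+# this test): as a registered Jinja2 filter, ``where_in`` renders a list as a
+# SQL IN clause, e.g.
+#
+#   SELECT action FROM logs WHERE action IN {{ ['login', 'logout'] | where_in }}
+#   -- renders as: SELECT action FROM logs WHERE action IN ('login', 'logout')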
+
+def test_dataset_macro(mocker: MockFixture) -> None:
+ """
+ Test the ``dataset_macro`` macro.
+ """
+ # pylint: disable=import-outside-toplevel
+ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+ from superset.models.core import Database
+
+ columns = [
+ TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
+ TableColumn(column_name="num_boys", type="INTEGER"),
+ TableColumn(column_name="revenue", type="INTEGER"),
+ TableColumn(column_name="expenses", type="INTEGER"),
+ TableColumn(
+ column_name="profit", type="INTEGER", expression="revenue-expenses"
+ ),
+ ]
+ metrics = [
+ SqlMetric(metric_name="cnt", expression="COUNT(*)"),
+ ]
+
+ dataset = SqlaTable(
+ table_name="old_dataset",
+ columns=columns,
+ metrics=metrics,
+ main_dttm_col="ds",
+ default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ", # not used
+ database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
+ offset=-8,
+ description="This is the description",
+ is_featured=1,
+ cache_timeout=3600,
+ schema="my_schema",
+ sql=None,
+ params=json.dumps(
+ {
+ "remote_id": 64,
+ "database_name": "examples",
+ "import_time": 1606677834,
+ }
+ ),
+ perm=None,
+ filter_select_enabled=1,
+ fetch_values_predicate="foo IN (1, 2)",
+ is_sqllab_view=0, # no longer used?
+ template_params=json.dumps({"answer": "42"}),
+ schema_perm=None,
+ extra=json.dumps({"warning_markdown": "*WARNING*"}),
+ )
+ DatasetDAO = mocker.patch("superset.datasets.dao.DatasetDAO")
+ DatasetDAO.find_by_id.return_value = dataset
+
+ assert (
+ dataset_macro(1)
+ == """(SELECT ds AS ds,
+ num_boys AS num_boys,
+ revenue AS revenue,
+ expenses AS expenses,
+ revenue-expenses AS profit
+FROM my_schema.old_dataset) AS dataset_1"""
+ )
+
+ assert (
+ dataset_macro(1, include_metrics=True)
+ == """(SELECT ds AS ds,
+ num_boys AS num_boys,
+ revenue AS revenue,
+ expenses AS expenses,
+ revenue-expenses AS profit,
+ COUNT(*) AS cnt
+FROM my_schema.old_dataset
+GROUP BY ds,
+ num_boys,
+ revenue,
+ expenses,
+ revenue-expenses) AS dataset_1"""
+ )
+
+ assert (
+ dataset_macro(1, include_metrics=True, columns=["ds"])
+ == """(SELECT ds AS ds,
+ COUNT(*) AS cnt
+FROM my_schema.old_dataset
+GROUP BY ds) AS dataset_1"""
+ )
+
+ DatasetDAO.find_by_id.return_value = None
+ with pytest.raises(DatasetNotFoundError) as excinfo:
+ dataset_macro(1)
+ assert str(excinfo.value) == "Dataset 1 not found!"
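+
+# Usage sketch (an assumption; the template-facing name is not asserted here):
+# in a templated query the macro expands to the aliased subquery built above,
+# e.g.
+#
+#   SELECT * FROM {{ dataset(1) }} LIMIT 10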
diff --git a/tests/unit_tests/key_value/__init__.py b/tests/unit_tests/key_value/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/key_value/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/key_value/utils_test.py b/tests/unit_tests/key_value/utils_test.py
new file mode 100644
index 0000000000000..5d78f6361c02c
--- /dev/null
+++ b/tests/unit_tests/key_value/utils_test.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from uuid import UUID
+
+import pytest
+
+from superset.key_value.exceptions import KeyValueParseKeyError
+from superset.key_value.types import KeyValueResource
+
+RESOURCE = KeyValueResource.APP
+UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc")
+ID_KEY = 123
+
+
+def test_get_filter_uuid() -> None:
+ from superset.key_value.utils import get_filter
+
+ assert get_filter(resource=RESOURCE, key=UUID_KEY) == {
+ "resource": RESOURCE,
+ "uuid": UUID_KEY,
+ }
+
+
+def test_get_filter_id() -> None:
+ from superset.key_value.utils import get_filter
+
+ assert get_filter(resource=RESOURCE, key=ID_KEY) == {
+ "resource": RESOURCE,
+ "id": ID_KEY,
+ }
+
+
+def test_encode_permalink_id_valid() -> None:
+ from superset.key_value.utils import encode_permalink_key
+
+ salt = "abc"
+ assert encode_permalink_key(1, salt) == "AyBn4lm9qG8"
+
+
+def test_decode_permalink_id_invalid() -> None:
+ from superset.key_value.utils import decode_permalink_id
+
+ with pytest.raises(KeyValueParseKeyError):
+ decode_permalink_id("foo", "bar")
diff --git a/tests/unit_tests/models/__init__.py b/tests/unit_tests/models/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/models/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
new file mode 100644
index 0000000000000..f8534391d837e
--- /dev/null
+++ b/tests/unit_tests/models/core_test.py
@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=import-outside-toplevel
+
+from typing import List, Optional
+
+from pytest_mock import MockFixture
+from sqlalchemy.engine.reflection import Inspector
+
+
+def test_get_metrics(mocker: MockFixture) -> None:
+ """
+ Tests for ``get_metrics``.
+ """
+ from superset.db_engine_specs.base import MetricType
+ from superset.db_engine_specs.sqlite import SqliteEngineSpec
+ from superset.models.core import Database
+
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ assert database.get_metrics("table") == [
+ {
+ "expression": "COUNT(*)",
+ "metric_name": "count",
+ "metric_type": "count",
+ "verbose_name": "COUNT(*)",
+ }
+ ]
+
+ class CustomSqliteEngineSpec(SqliteEngineSpec):
+ @classmethod
+ def get_metrics(
+ cls,
+ database: Database,
+ inspector: Inspector,
+ table_name: str,
+ schema: Optional[str],
+ ) -> List[MetricType]:
+ return [
+ {
+ "expression": "COUNT(DISTINCT user_id)",
+ "metric_name": "count_distinct_user_id",
+ "metric_type": "count_distinct",
+ "verbose_name": "COUNT(DISTINCT user_id)",
+ },
+ ]
+
+ database.get_db_engine_spec = mocker.MagicMock(return_value=CustomSqliteEngineSpec)
+ assert database.get_metrics("table") == [
+ {
+ "expression": "COUNT(DISTINCT user_id)",
+ "metric_name": "count_distinct_user_id",
+ "metric_type": "count_distinct",
+ "verbose_name": "COUNT(DISTINCT user_id)",
+ },
+ ]
+
+
+def test_get_db_engine_spec(mocker: MockFixture) -> None:
+ """
+ Tests for ``get_db_engine_spec``.
+ """
+ from superset.db_engine_specs import BaseEngineSpec
+ from superset.models.core import Database
+
+ # pylint: disable=abstract-method
+ class PostgresDBEngineSpec(BaseEngineSpec):
+ """
+ A DB engine spec with drivers and a default driver.
+ """
+
+ engine = "postgresql"
+ engine_aliases = {"postgres"}
+ drivers = {
+ "psycopg2": "The default Postgres driver",
+ "asyncpg": "An async Postgres driver",
+ }
+ default_driver = "psycopg2"
+
+ # pylint: disable=abstract-method
+ class OldDBEngineSpec(BaseEngineSpec):
+ """
+ An old DB engine spec with neither drivers nor a default driver.
+ """
+
+ engine = "mysql"
+
+ load_engine_specs = mocker.patch("superset.db_engine_specs.load_engine_specs")
+ load_engine_specs.return_value = [
+ PostgresDBEngineSpec,
+ OldDBEngineSpec,
+ ]
+
+ assert (
+ Database(database_name="db", sqlalchemy_uri="postgresql://").db_engine_spec
+ == PostgresDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="postgresql+psycopg2://"
+ ).db_engine_spec
+ == PostgresDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="postgresql+asyncpg://"
+ ).db_engine_spec
+ == PostgresDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="postgresql+fancynewdriver://"
+ ).db_engine_spec
+ == PostgresDBEngineSpec
+ )
+ assert (
+ Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec
+ == OldDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="mysql+mysqlconnector://"
+ ).db_engine_spec
+ == OldDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="mysql+fancynewdriver://"
+ ).db_engine_spec
+ == OldDBEngineSpec
+ )
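+
+# The asserts above pin down the resolution rule: the URI scheme (or an alias
+# such as ``postgres``) selects the engine spec, and an unknown or missing
+# driver falls back to that engine's spec instead of failing.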
diff --git a/tests/unit_tests/notifications/email_tests.py b/tests/unit_tests/notifications/email_tests.py
new file mode 100644
index 0000000000000..4ce34b99cac4d
--- /dev/null
+++ b/tests/unit_tests/notifications/email_tests.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pandas as pd
+
+
+def test_render_description_with_html() -> None:
+ # `superset.models.helpers`, a dependency of following imports,
+ # requires app context
+ from superset.reports.models import ReportRecipients, ReportRecipientType
+ from superset.reports.notifications.base import NotificationContent
+ from superset.reports.notifications.email import EmailNotification
+
+ content = NotificationContent(
+ name="test alert",
+ embedded_data=pd.DataFrame(
+ {
+ "A": [1, 2, 3],
+ "B": [4, 5, 6],
+ "C": ["111", "222", '333'],
+ }
+ ),
+ description='<p>This is a test alert</p>',
+ header_data={
+ "notification_format": "PNG",
+ "notification_type": "Alert",
+ "owners": [1],
+ "notification_source": None,
+ "chart_id": None,
+ "dashboard_id": None,
+ },
+ )
+ email_body = (
+ EmailNotification(
+ recipient=ReportRecipients(type=ReportRecipientType.EMAIL), content=content
+ )
+ ._get_content()
+ .body
+ )
+ assert '<p>This is a test alert</p>' in email_body
+ assert '<td><a href="http://www.example.com">333</a></td>' in email_body
diff --git a/tests/unit_tests/pandas_postprocessing/__init__.py b/tests/unit_tests/pandas_postprocessing/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/pandas_postprocessing/test_aggregate.py b/tests/unit_tests/pandas_postprocessing/test_aggregate.py
new file mode 100644
index 0000000000000..69d42e36f06be
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_aggregate.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from superset.utils.pandas_postprocessing import aggregate
+from tests.unit_tests.fixtures.dataframes import categories_df
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
+def test_aggregate():
+ aggregates = {
+ "asc sum": {"column": "asc_idx", "operator": "sum"},
+ "asc q2": {
+ "column": "asc_idx",
+ "operator": "percentile",
+ "options": {"q": 75},
+ },
+ "desc q1": {
+ "column": "desc_idx",
+ "operator": "percentile",
+ "options": {"q": 25},
+ },
+ }
+ df = aggregate(df=categories_df, groupby=["constant"], aggregates=aggregates)
+ assert df.columns.tolist() == ["constant", "asc sum", "asc q2", "desc q1"]
+ assert series_to_list(df["asc sum"])[0] == 5050
+ assert series_to_list(df["asc q2"])[0] == 75
+ assert series_to_list(df["desc q1"])[0] == 25
diff --git a/tests/unit_tests/pandas_postprocessing/test_boxplot.py b/tests/unit_tests/pandas_postprocessing/test_boxplot.py
new file mode 100644
index 0000000000000..27dff0adeb894
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_boxplot.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils.core import PostProcessingBoxplotWhiskerType
+from superset.utils.pandas_postprocessing import boxplot
+from tests.unit_tests.fixtures.dataframes import names_df
+
+
+def test_boxplot_tukey():
+ df = boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.TUKEY,
+ metrics=["cars"],
+ )
+ columns = {column for column in df.columns}
+ assert columns == {
+ "cars__mean",
+ "cars__median",
+ "cars__q1",
+ "cars__q3",
+ "cars__max",
+ "cars__min",
+ "cars__count",
+ "cars__outliers",
+ "region",
+ }
+ assert len(df) == 4
+
+
+def test_boxplot_min_max():
+ df = boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.MINMAX,
+ metrics=["cars"],
+ )
+ columns = {column for column in df.columns}
+ assert columns == {
+ "cars__mean",
+ "cars__median",
+ "cars__q1",
+ "cars__q3",
+ "cars__max",
+ "cars__min",
+ "cars__count",
+ "cars__outliers",
+ "region",
+ }
+ assert len(df) == 4
+
+
+def test_boxplot_percentile():
+ df = boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ percentiles=[1, 99],
+ )
+ columns = {column for column in df.columns}
+ assert columns == {
+ "cars__mean",
+ "cars__median",
+ "cars__q1",
+ "cars__q3",
+ "cars__max",
+ "cars__min",
+ "cars__count",
+ "cars__outliers",
+ "region",
+ }
+ assert len(df) == 4
+
+
+def test_boxplot_percentile_incorrect_params():
+ with pytest.raises(InvalidPostProcessingError):
+ boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ )
+
+ with pytest.raises(InvalidPostProcessingError):
+ boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ percentiles=[10],
+ )
+
+ with pytest.raises(InvalidPostProcessingError):
+ boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ percentiles=[90, 10],
+ )
+
+ with pytest.raises(InvalidPostProcessingError):
+ boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ percentiles=[10, 90, 10],
+ )
+
+
+def test_boxplot_type_coercion():
+ df = names_df.copy()  # copy to avoid mutating the shared fixture
+ df["cars"] = df["cars"].astype(str)
+ df = boxplot(
+ df=df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.TUKEY,
+ metrics=["cars"],
+ )
+
+ columns = {column for column in df.columns}
+ assert columns == {
+ "cars__mean",
+ "cars__median",
+ "cars__q1",
+ "cars__q3",
+ "cars__max",
+ "cars__min",
+ "cars__count",
+ "cars__outliers",
+ "region",
+ }
+ assert len(df) == 4
diff --git a/tests/unit_tests/pandas_postprocessing/test_compare.py b/tests/unit_tests/pandas_postprocessing/test_compare.py
new file mode 100644
index 0000000000000..9da8a31535470
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_compare.py
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pandas as pd
+
+from superset.constants import PandasPostprocessingCompare as PPC
+from superset.utils import pandas_postprocessing as pp
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
+from tests.unit_tests.fixtures.dataframes import multiple_metrics_df, timeseries_df2
+
+
+def test_compare_should_not_side_effect():
+ _timeseries_df2 = timeseries_df2.copy()
+ pp.compare(
+ df=_timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type=PPC.DIFF,
+ )
+ assert _timeseries_df2.equals(timeseries_df2)
+
+
+def test_compare_diff():
+ # `difference` comparison
+ post_df = pp.compare(
+ df=timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type=PPC.DIFF,
+ )
+ """
+ label y z difference__y__z
+ 2019-01-01 x 2.0 2.0 0.0
+ 2019-01-02 y 2.0 4.0 -2.0
+ 2019-01-05 z 2.0 10.0 -8.0
+ 2019-01-07 q 2.0 8.0 -6.0
+ """
+ assert post_df.equals(
+ pd.DataFrame(
+ index=timeseries_df2.index,
+ data={
+ "label": ["x", "y", "z", "q"],
+ "y": [2.0, 2.0, 2.0, 2.0],
+ "z": [2.0, 4.0, 10.0, 8.0],
+ "difference__y__z": [0.0, -2.0, -8.0, -6.0],
+ },
+ )
+ )
+
+ # drop original columns
+ post_df = pp.compare(
+ df=timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type=PPC.DIFF,
+ drop_original_columns=True,
+ )
+ assert post_df.equals(
+ pd.DataFrame(
+ index=timeseries_df2.index,
+ data={
+ "label": ["x", "y", "z", "q"],
+ "difference__y__z": [0.0, -2.0, -8.0, -6.0],
+ },
+ )
+ )
+
+
+def test_compare_percentage():
+ # `percentage` comparison
+ post_df = pp.compare(
+ df=timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type=PPC.PCT,
+ )
+ """
+ label y z percentage__y__z
+ 2019-01-01 x 2.0 2.0 0.0
+ 2019-01-02 y 2.0 4.0 -0.50
+ 2019-01-05 z 2.0 10.0 -0.80
+ 2019-01-07 q 2.0 8.0 -0.75
+ """
+ assert post_df.equals(
+ pd.DataFrame(
+ index=timeseries_df2.index,
+ data={
+ "label": ["x", "y", "z", "q"],
+ "y": [2.0, 2.0, 2.0, 2.0],
+ "z": [2.0, 4.0, 10.0, 8.0],
+ "percentage__y__z": [0.0, -0.50, -0.80, -0.75],
+ },
+ )
+ )
+
+
+def test_compare_ratio():
+ # `ratio` comparison
+ post_df = pp.compare(
+ df=timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type=PPC.RAT,
+ )
+ """
+ label y z ratio__y__z
+ 2019-01-01 x 2.0 2.0 1.00
+ 2019-01-02 y 2.0 4.0 0.50
+ 2019-01-05 z 2.0 10.0 0.20
+ 2019-01-07 q 2.0 8.0 0.25
+ """
+ assert post_df.equals(
+ pd.DataFrame(
+ index=timeseries_df2.index,
+ data={
+ "label": ["x", "y", "z", "q"],
+ "y": [2.0, 2.0, 2.0, 2.0],
+ "z": [2.0, 4.0, 10.0, 8.0],
+ "ratio__y__z": [1.00, 0.50, 0.20, 0.25],
+ },
+ )
+ )
+
+
+def test_compare_multi_index_column():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ iterables = [["m1", "m2"], ["a", "b"], ["x", "y"]]
+ columns = pd.MultiIndex.from_product(iterables, names=[None, "level1", "level2"])
+ df = pd.DataFrame(index=index, columns=columns, data=1)
+ """
+ m1 m2
+ level1 a b a b
+ level2 x y x y x y x y
+ __timestamp
+ 2021-01-01 1 1 1 1 1 1 1 1
+ 2021-01-02 1 1 1 1 1 1 1 1
+ 2021-01-03 1 1 1 1 1 1 1 1
+ """
+ post_df = pp.compare(
+ df,
+ source_columns=["m1"],
+ compare_columns=["m2"],
+ compare_type=PPC.DIFF,
+ drop_original_columns=True,
+ )
+ flat_df = pp.flatten(post_df)
+ """
+ __timestamp difference__m1__m2, a, x difference__m1__m2, a, y difference__m1__m2, b, x difference__m1__m2, b, y
+ 0 2021-01-01 0 0 0 0
+ 1 2021-01-02 0 0 0 0
+ 2 2021-01-03 0 0 0 0
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "__timestamp": pd.to_datetime(
+ ["2021-01-01", "2021-01-02", "2021-01-03"]
+ ),
+ "difference__m1__m2, a, x": [0, 0, 0],
+ "difference__m1__m2, a, y": [0, 0, 0],
+ "difference__m1__m2, b, x": [0, 0, 0],
+ "difference__m1__m2, b, y": [0, 0, 0],
+ }
+ )
+ )
+
+
+def test_compare_after_pivot():
+ pivot_df = pp.pivot(
+ df=multiple_metrics_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={
+ "sum_metric": {"operator": "sum"},
+ "count_metric": {"operator": "sum"},
+ },
+ )
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1 2 5 6
+ 2019-01-02 3 4 7 8
+ """
+ compared_df = pp.compare(
+ pivot_df,
+ source_columns=["count_metric"],
+ compare_columns=["sum_metric"],
+ compare_type=PPC.DIFF,
+ drop_original_columns=True,
+ )
+ """
+ difference__count_metric__sum_metric
+ country UK US
+ dttm
+ 2019-01-01 -4 -4
+ 2019-01-02 -4 -4
+ """
+ flat_df = pp.flatten(compared_df)
+ """
+ dttm difference__count_metric__sum_metric, UK difference__count_metric__sum_metric, US
+ 0 2019-01-01 -4 -4
+ 1 2019-01-02 -4 -4
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(
+ ["difference__count_metric__sum_metric", "UK"]
+ ): [-4, -4],
+ FLAT_COLUMN_SEPARATOR.join(
+ ["difference__count_metric__sum_metric", "US"]
+ ): [-4, -4],
+ }
+ )
+ )
diff --git a/tests/unit_tests/pandas_postprocessing/test_contribution.py b/tests/unit_tests/pandas_postprocessing/test_contribution.py
new file mode 100644
index 0000000000000..7eb34c4d13f7b
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_contribution.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+import pytest
+from numpy import nan
+from numpy.testing import assert_array_equal
+from pandas import DataFrame
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
+from superset.utils.pandas_postprocessing import contribution
+
+
+def test_contribution():
+ df = DataFrame(
+ {
+ DTTM_ALIAS: [
+ datetime(2020, 7, 16, 14, 49),
+ datetime(2020, 7, 16, 14, 50),
+ datetime(2020, 7, 16, 14, 51),
+ ],
+ "a": [1, 3, nan],
+ "b": [1, 9, nan],
+ "c": [nan, nan, nan],
+ }
+ )
+ with pytest.raises(InvalidPostProcessingError, match="not numeric"):
+ contribution(df, columns=[DTTM_ALIAS])
+
+ with pytest.raises(InvalidPostProcessingError, match="same length"):
+ contribution(df, columns=["a"], rename_columns=["aa", "bb"])
+
+ # cell contribution across row
+ processed_df = contribution(
+ df,
+ orientation=PostProcessingContributionOrientation.ROW,
+ )
+ assert processed_df.columns.tolist() == [DTTM_ALIAS, "a", "b", "c"]
+ assert_array_equal(processed_df["a"].tolist(), [0.5, 0.25, nan])
+ assert_array_equal(processed_df["b"].tolist(), [0.5, 0.75, nan])
+ assert_array_equal(processed_df["c"].tolist(), [0, 0, nan])
+
+ # cell contribution across column without temporal column
+ df.pop(DTTM_ALIAS)
+ processed_df = contribution(
+ df, orientation=PostProcessingContributionOrientation.COLUMN
+ )
+ assert processed_df.columns.tolist() == ["a", "b", "c"]
+ assert_array_equal(processed_df["a"].tolist(), [0.25, 0.75, 0])
+ assert_array_equal(processed_df["b"].tolist(), [0.1, 0.9, 0])
+ assert_array_equal(processed_df["c"].tolist(), [nan, nan, nan])
+
+ # contribution only on selected columns
+ processed_df = contribution(
+ df,
+ orientation=PostProcessingContributionOrientation.COLUMN,
+ columns=["a"],
+ rename_columns=["pct_a"],
+ )
+ assert processed_df.columns.tolist() == ["a", "b", "c", "pct_a"]
+ assert_array_equal(processed_df["a"].tolist(), [1, 3, nan])
+ assert_array_equal(processed_df["b"].tolist(), [1, 9, nan])
+ assert_array_equal(processed_df["c"].tolist(), [nan, nan, nan])
+ assert processed_df["pct_a"].tolist() == [0.25, 0.75, 0]
diff --git a/tests/unit_tests/pandas_postprocessing/test_cum.py b/tests/unit_tests/pandas_postprocessing/test_cum.py
new file mode 100644
index 0000000000000..130e0602520a1
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_cum.py
@@ -0,0 +1,164 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pandas as pd
+import pytest
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils import pandas_postprocessing as pp
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
+from tests.unit_tests.fixtures.dataframes import (
+ multiple_metrics_df,
+ single_metric_df,
+ timeseries_df,
+)
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
+def test_cum_should_not_side_effect():
+ _timeseries_df = timeseries_df.copy()
+ pp.cum(
+ df=timeseries_df,
+ columns={"y": "y2"},
+ operator="sum",
+ )
+ assert _timeseries_df.equals(timeseries_df)
+
+
+def test_cum():
+ # create new column (cumsum)
+ post_df = pp.cum(
+ df=timeseries_df,
+ columns={"y": "y2"},
+ operator="sum",
+ )
+ assert post_df.columns.tolist() == ["label", "y", "y2"]
+ assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"]
+ assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
+ assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0]
+
+ # overwrite column (cumprod)
+ post_df = pp.cum(
+ df=timeseries_df,
+ columns={"y": "y"},
+ operator="prod",
+ )
+ assert post_df.columns.tolist() == ["label", "y"]
+ assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0]
+
+ # overwrite column (cummin)
+ post_df = pp.cum(
+ df=timeseries_df,
+ columns={"y": "y"},
+ operator="min",
+ )
+ assert post_df.columns.tolist() == ["label", "y"]
+ assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0]
+
+ # invalid operator
+ with pytest.raises(InvalidPostProcessingError):
+ pp.cum(
+ df=timeseries_df,
+ columns={"y": "y"},
+ operator="abc",
+ )
+
+
+def test_cum_after_pivot_with_single_metric():
+ pivot_df = pp.pivot(
+ df=single_metric_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={"sum_metric": {"operator": "sum"}},
+ )
+ """
+ sum_metric
+ country UK US
+ dttm
+ 2019-01-01 5 6
+ 2019-01-02 7 8
+ """
+ cum_df = pp.cum(df=pivot_df, operator="sum", columns={"sum_metric": "sum_metric"})
+ """
+ sum_metric
+ country UK US
+ dttm
+ 2019-01-01 5 6
+ 2019-01-02 12 14
+ """
+ cum_and_flat_df = pp.flatten(cum_df)
+ """
+ dttm sum_metric, UK sum_metric, US
+ 0 2019-01-01 5 6
+ 1 2019-01-02 12 14
+ """
+ assert cum_and_flat_df.equals(
+ pd.DataFrame(
+ {
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
+ }
+ )
+ )
+
+
+def test_cum_after_pivot_with_multiple_metrics():
+ pivot_df = pp.pivot(
+ df=multiple_metrics_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={
+ "sum_metric": {"operator": "sum"},
+ "count_metric": {"operator": "sum"},
+ },
+ )
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1 2 5 6
+ 2019-01-02 3 4 7 8
+ """
+ cum_df = pp.cum(
+ df=pivot_df,
+ operator="sum",
+ columns={"sum_metric": "sum_metric", "count_metric": "count_metric"},
+ )
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1 2 5 6
+ 2019-01-02 4 6 12 14
+ """
+ flat_df = pp.flatten(cum_df)
+ """
+ dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
+ 0 2019-01-01 1 2 5 6
+ 1 2019-01-02 4 6 12 14
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ {
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1, 4],
+ FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2, 6],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
+ }
+ )
+ )
diff --git a/tests/unit_tests/pandas_postprocessing/test_diff.py b/tests/unit_tests/pandas_postprocessing/test_diff.py
new file mode 100644
index 0000000000000..c77195bbf6d71
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_diff.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils.pandas_postprocessing import diff
+from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
+def test_diff():
+ # overwrite column
+ post_df = diff(df=timeseries_df, columns={"y": "y"})
+ assert post_df.columns.tolist() == ["label", "y"]
+ assert series_to_list(post_df["y"]) == [None, 1.0, 1.0, 1.0]
+
+ # add column
+ post_df = diff(df=timeseries_df, columns={"y": "y1"})
+ assert post_df.columns.tolist() == ["label", "y", "y1"]
+ assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
+ assert series_to_list(post_df["y1"]) == [None, 1.0, 1.0, 1.0]
+
+ # look ahead
+ post_df = diff(df=timeseries_df, columns={"y": "y1"}, periods=-1)
+ assert series_to_list(post_df["y1"]) == [-1.0, -1.0, -1.0, None]
+
+ # invalid column reference
+ with pytest.raises(InvalidPostProcessingError):
+ diff(
+ df=timeseries_df,
+ columns={"abc": "abc"},
+ )
+
+ # diff by columns
+ post_df = diff(df=timeseries_df2, columns={"y": "y", "z": "z"}, axis=1)
+ assert post_df.columns.tolist() == ["label", "y", "z"]
+ assert series_to_list(post_df["z"]) == [0.0, 2.0, 8.0, 6.0]
diff --git a/tests/unit_tests/pandas_postprocessing/test_flatten.py b/tests/unit_tests/pandas_postprocessing/test_flatten.py
new file mode 100644
index 0000000000000..fea84f7b9f5b0
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_flatten.py
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pandas as pd
+
+from superset.utils import pandas_postprocessing as pp
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
+from tests.unit_tests.fixtures.dataframes import timeseries_df
+
+
+def test_flat_should_not_change():
+ df = pd.DataFrame(
+ data={
+ "foo": [1, 2, 3],
+ "bar": [4, 5, 6],
+ }
+ )
+
+ assert pp.flatten(df).equals(df)
+
+
+def test_flat_should_not_reset_index():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
+
+ assert pp.flatten(df, reset_index=False).equals(df)
+
+
+def test_flat_should_flat_datetime_index():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
+
+ assert pp.flatten(df).equals(
+ pd.DataFrame(
+ {
+ "__timestamp": index,
+ "foo": [1, 2, 3],
+ "bar": [4, 5, 6],
+ }
+ )
+ )
+
+
+def test_flat_should_flat_multiple_index():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ iterables = [["foo", "bar"], [1, "two"]]
+ columns = pd.MultiIndex.from_product(iterables, names=["level1", "level2"])
+ df = pd.DataFrame(index=index, columns=columns, data=1)
+
+ assert pp.flatten(df).equals(
+ pd.DataFrame(
+ {
+ "__timestamp": index,
+ FLAT_COLUMN_SEPARATOR.join(["foo", "1"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["foo", "two"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["bar", "1"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["bar", "two"]): [1, 1, 1],
+ }
+ )
+ )
+
+
+def test_flat_should_drop_index_level():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ columns = pd.MultiIndex.from_arrays(
+ [["a"] * 3, ["b"] * 3, ["c", "d", "e"], ["ff", "ii", "gg"]],
+ names=["level1", "level2", "level3", "level4"],
+ )
+ df = pd.DataFrame(index=index, columns=columns, data=1)
+
+ # drop level by index
+ assert pp.flatten(df.copy(), drop_levels=(0, 1,)).equals(
+ pd.DataFrame(
+ {
+ "__timestamp": index,
+ FLAT_COLUMN_SEPARATOR.join(["c", "ff"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["d", "ii"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["e", "gg"]): [1, 1, 1],
+ }
+ )
+ )
+
+ # drop level by name
+ assert pp.flatten(df.copy(), drop_levels=("level1", "level2")).equals(
+ pd.DataFrame(
+ {
+ "__timestamp": index,
+ FLAT_COLUMN_SEPARATOR.join(["c", "ff"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["d", "ii"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["e", "gg"]): [1, 1, 1],
+ }
+ )
+ )
+
+ # only leave 1 level
+ assert pp.flatten(df.copy(), drop_levels=(0, 1, 2)).equals(
+ pd.DataFrame(
+ {
+ "__timestamp": index,
+ FLAT_COLUMN_SEPARATOR.join(["ff"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["ii"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["gg"]): [1, 1, 1],
+ }
+ )
+ )
+
+
+def test_flat_should_not_droplevel():
+ assert pp.flatten(timeseries_df, drop_levels=(0,)).equals(
+ pd.DataFrame(
+ {
+ "index": pd.to_datetime(
+ ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]
+ ),
+ "label": ["x", "y", "z", "q"],
+ "y": [1.0, 2.0, 3.0, 4.0],
+ }
+ )
+ )
+
+
+def test_flat_integer_column_name():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ columns = pd.MultiIndex.from_arrays(
+ [["a"] * 3, [100, 200, 300]],
+ names=["level1", "level2"],
+ )
+ df = pd.DataFrame(index=index, columns=columns, data=1)
+ assert pp.flatten(df, drop_levels=(0,)).equals(
+ pd.DataFrame(
+ {
+ "__timestamp": pd.to_datetime(
+ ["2021-01-01", "2021-01-02", "2021-01-03"]
+ ),
+ "100": [1, 1, 1],
+ "200": [1, 1, 1],
+ "300": [1, 1, 1],
+ }
+ )
+ )
+
+
+def test_escape_column_name():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ columns = pd.MultiIndex.from_arrays(
+ [
+ ["level1,value1", "level1,value2", "level1,value3"],
+ ["level2, value1", "level2, value2", "level2, value3"],
+ ],
+ names=["level1", "level2"],
+ )
+ df = pd.DataFrame(index=index, columns=columns, data=1)
+ assert list(pp.flatten(df).columns.values) == [
+ "__timestamp",
+ "level1\\,value1" + FLAT_COLUMN_SEPARATOR + "level2\\, value1",
+ "level1\\,value2" + FLAT_COLUMN_SEPARATOR + "level2\\, value2",
+ "level1\\,value3" + FLAT_COLUMN_SEPARATOR + "level2\\, value3",
+ ]
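+
+# Note: commas inside the original labels are backslash-escaped, so the
+# comma-based FLAT_COLUMN_SEPARATOR stays unambiguous in the flattened names.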
diff --git a/tests/unit_tests/pandas_postprocessing/test_geography.py b/tests/unit_tests/pandas_postprocessing/test_geography.py
new file mode 100644
index 0000000000000..6162f3c8a0b94
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_geography.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from superset.utils.pandas_postprocessing import (
+ geodetic_parse,
+ geohash_decode,
+ geohash_encode,
+)
+from tests.unit_tests.fixtures.dataframes import lonlat_df
+from tests.unit_tests.pandas_postprocessing.utils import round_floats, series_to_list
+
+
+def test_geohash_decode():
+ # decode lon/lat from geohash
+ post_df = geohash_decode(
+ df=lonlat_df[["city", "geohash"]],
+ geohash="geohash",
+ latitude="latitude",
+ longitude="longitude",
+ )
+ assert sorted(post_df.columns.tolist()) == sorted(
+ ["city", "geohash", "latitude", "longitude"]
+ )
+ assert round_floats(series_to_list(post_df["longitude"]), 6) == round_floats(
+ series_to_list(lonlat_df["longitude"]), 6
+ )
+ assert round_floats(series_to_list(post_df["latitude"]), 6) == round_floats(
+ series_to_list(lonlat_df["latitude"]), 6
+ )
+
+
+def test_geohash_encode():
+ # encode lon/lat into geohash
+ post_df = geohash_encode(
+ df=lonlat_df[["city", "latitude", "longitude"]],
+ latitude="latitude",
+ longitude="longitude",
+ geohash="geohash",
+ )
+ assert sorted(post_df.columns.tolist()) == sorted(
+ ["city", "geohash", "latitude", "longitude"]
+ )
+ assert series_to_list(post_df["geohash"]) == series_to_list(lonlat_df["geohash"])
+
+
+def test_geodetic_parse():
+ # parse geodetic string with altitude into lon/lat/altitude
+ post_df = geodetic_parse(
+ df=lonlat_df[["city", "geodetic"]],
+ geodetic="geodetic",
+ latitude="latitude",
+ longitude="longitude",
+ altitude="altitude",
+ )
+ assert sorted(post_df.columns.tolist()) == sorted(
+ ["city", "geodetic", "latitude", "longitude", "altitude"]
+ )
+ assert series_to_list(post_df["longitude"]) == series_to_list(
+ lonlat_df["longitude"]
+ )
+ assert series_to_list(post_df["latitude"]) == series_to_list(lonlat_df["latitude"])
+ assert series_to_list(post_df["altitude"]) == series_to_list(lonlat_df["altitude"])
+
+ # parse geodetic string into lon/lat
+ post_df = geodetic_parse(
+ df=lonlat_df[["city", "geodetic"]],
+ geodetic="geodetic",
+ latitude="latitude",
+ longitude="longitude",
+ )
+ assert sorted(post_df.columns.tolist()) == sorted(
+ ["city", "geodetic", "latitude", "longitude"]
+ )
+ assert series_to_list(post_df["longitude"]) == series_to_list(
+ lonlat_df["longitude"]
+ )
+ assert series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"])
diff --git a/tests/unit_tests/pandas_postprocessing/test_pivot.py b/tests/unit_tests/pandas_postprocessing/test_pivot.py
new file mode 100644
index 0000000000000..8efd203906077
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_pivot.py
@@ -0,0 +1,205 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+from pandas import DataFrame, to_datetime
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils.pandas_postprocessing import flatten, pivot
+from tests.unit_tests.fixtures.dataframes import categories_df
+from tests.unit_tests.pandas_postprocessing.utils import AGGREGATES_SINGLE
+
+
+def test_pivot_without_columns():
+ """
+ Make sure pivot without columns returns correct DataFrame
+ """
+ df = pivot(
+ df=categories_df,
+ index=["name"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+ assert df.columns.tolist() == ["idx_nulls"]
+ assert len(df) == 101
+ assert df["idx_nulls"].sum() == 1050
+
+
+def test_pivot_with_single_column():
+ """
+ Make sure pivot with single column returns correct DataFrame
+ """
+ df = pivot(
+ df=categories_df,
+ index=["name"],
+ columns=["category"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+ assert df.columns.tolist() == [
+ ("idx_nulls", "cat0"),
+ ("idx_nulls", "cat1"),
+ ("idx_nulls", "cat2"),
+ ]
+ assert len(df) == 101
+ assert df["idx_nulls"]["cat0"].sum() == 315
+
+ df = pivot(
+ df=categories_df,
+ index=["dept"],
+ columns=["category"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+ assert df.columns.tolist() == [
+ ("idx_nulls", "cat0"),
+ ("idx_nulls", "cat1"),
+ ("idx_nulls", "cat2"),
+ ]
+ assert len(df) == 5
+
+
+def test_pivot_with_multiple_columns():
+ """
+ Make sure pivot with multiple columns returns correct DataFrame
+ """
+ df = pivot(
+ df=categories_df,
+ index=["name"],
+ columns=["category", "dept"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+ df = flatten(df)
+ assert len(df.columns) == 1 + 3 * 5 # index + possible permutations
+
+
+def test_pivot_fill_values():
+ """
+ Make sure pivot with fill values returns correct DataFrame
+ """
+ df = pivot(
+ df=categories_df,
+ index=["name"],
+ columns=["category"],
+ metric_fill_value=1,
+ aggregates={"idx_nulls": {"operator": "sum"}},
+ )
+ assert df["idx_nulls"]["cat0"].sum() == 382
+
+
+def test_pivot_fill_column_values():
+ """
+ Make sure pivot with null column names returns correct DataFrame
+ """
+ df_copy = categories_df.copy()
+ df_copy["category"] = None
+ df = pivot(
+ df=df_copy,
+ index=["name"],
+ columns=["category"],
+ aggregates={"idx_nulls": {"operator": "sum"}},
+ )
+ assert len(df) == 101
+ assert df.columns.tolist() == [("idx_nulls", "")]
+
+
+def test_pivot_exceptions():
+ """
+ Make sure pivot raises correct Exceptions
+ """
+ # Missing index
+ with pytest.raises(TypeError):
+ pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE)
+
+ # invalid index reference
+ with pytest.raises(InvalidPostProcessingError):
+ pivot(
+ df=categories_df,
+ index=["abc"],
+ columns=["dept"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+
+ # invalid column reference
+ with pytest.raises(InvalidPostProcessingError):
+ pivot(
+ df=categories_df,
+ index=["dept"],
+ columns=["abc"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+
+ # invalid aggregate options
+ with pytest.raises(InvalidPostProcessingError):
+ pivot(
+ df=categories_df,
+ index=["name"],
+ columns=["category"],
+ aggregates={"idx_nulls": {}},
+ )
+
+
+def test_pivot_eliminate_cartesian_product_columns():
+ # single metric
+ mock_df = DataFrame(
+ {
+ "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
+ "a": [0, 1],
+ "b": [0, 1],
+ "metric": [9, np.NAN],
+ }
+ )
+
+ df = pivot(
+ df=mock_df,
+ index=["dttm"],
+ columns=["a", "b"],
+ aggregates={"metric": {"operator": "mean"}},
+ drop_missing_columns=False,
+ )
+ df = flatten(df)
+ assert list(df.columns) == ["dttm", "metric, 0, 0", "metric, 1, 1"]
+ assert np.isnan(df["metric, 1, 1"][0])
+
+ # multiple metrics
+ mock_df = DataFrame(
+ {
+ "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
+ "a": [0, 1],
+ "b": [0, 1],
+ "metric": [9, np.NAN],
+ "metric2": [10, 11],
+ }
+ )
+
+ df = pivot(
+ df=mock_df,
+ index=["dttm"],
+ columns=["a", "b"],
+ aggregates={
+ "metric": {"operator": "mean"},
+ "metric2": {"operator": "mean"},
+ },
+ drop_missing_columns=False,
+ )
+ df = flatten(df)
+ assert list(df.columns) == [
+ "dttm",
+ "metric, 0, 0",
+ "metric, 1, 1",
+ "metric2, 0, 0",
+ "metric2, 1, 1",
+ ]
+ assert np.isnan(df["metric, 1, 1"][0])
diff --git a/tests/unit_tests/pandas_postprocessing/test_prophet.py b/tests/unit_tests/pandas_postprocessing/test_prophet.py
new file mode 100644
index 0000000000000..6da3a7a591a3d
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_prophet.py
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from datetime import datetime
+from importlib.util import find_spec
+
+import pandas as pd
+import pytest
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils.core import DTTM_ALIAS
+from superset.utils.pandas_postprocessing import prophet
+from tests.unit_tests.fixtures.dataframes import prophet_df
+
+
+def test_prophet_valid():
+ pytest.importorskip("prophet")
+
+ df = prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
+ columns = {column for column in df.columns}
+ assert columns == {
+ DTTM_ALIAS,
+ "a__yhat",
+ "a__yhat_upper",
+ "a__yhat_lower",
+ "a",
+ "b__yhat",
+ "b__yhat_upper",
+ "b__yhat_lower",
+ "b",
+ }
+ assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31)
+ assert len(df) == 7
+
+ df = prophet(df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9)
+ assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
+ assert len(df) == 9
+
+ df = prophet(
+ df=pd.DataFrame(
+ {
+ "__timestamp": [datetime(2022, 1, 2), datetime(2022, 1, 9)],
+ "x": [1, 1],
+ }
+ ),
+ time_grain="P1W",
+ periods=1,
+ confidence_interval=0.9,
+ )
+
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 1, 16)
+ assert len(df) == 3
+
+ df = prophet(
+ df=pd.DataFrame(
+ {
+ "__timestamp": [datetime(2022, 1, 2), datetime(2022, 1, 9)],
+ "x": [1, 1],
+ }
+ ),
+ time_grain="1969-12-28T00:00:00Z/P1W",
+ periods=1,
+ confidence_interval=0.9,
+ )
+
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 1, 16)
+ assert len(df) == 3
+
+ df = prophet(
+ df=pd.DataFrame(
+ {
+ "__timestamp": [datetime(2022, 1, 3), datetime(2022, 1, 10)],
+ "x": [1, 1],
+ }
+ ),
+ time_grain="1969-12-29T00:00:00Z/P1W",
+ periods=1,
+ confidence_interval=0.9,
+ )
+
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 1, 17)
+ assert len(df) == 3
+
+ df = prophet(
+ df=pd.DataFrame(
+ {
+ "__timestamp": [datetime(2022, 1, 8), datetime(2022, 1, 15)],
+ "x": [1, 1],
+ }
+ ),
+ time_grain="P1W/1970-01-03T00:00:00Z",
+ periods=1,
+ confidence_interval=0.9,
+ )
+
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 1, 22)
+ assert len(df) == 3
+
+
+def test_prophet_valid_zero_periods():
+ pytest.importorskip("prophet")
+
+ df = prophet(df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9)
+    columns = set(df.columns)
+ assert columns == {
+ DTTM_ALIAS,
+ "a__yhat",
+ "a__yhat_upper",
+ "a__yhat_lower",
+ "a",
+ "b__yhat",
+ "b__yhat_upper",
+ "b__yhat_lower",
+ "b",
+ }
+ assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31)
+ assert len(df) == 4
+
+
+def test_prophet_import():
+ dynamic_module = find_spec("prophet")
+ if dynamic_module is None:
+ with pytest.raises(InvalidPostProcessingError):
+ prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
+
+
+def test_prophet_missing_temporal_column():
+ df = prophet_df.drop(DTTM_ALIAS, axis=1)
+
+ with pytest.raises(InvalidPostProcessingError):
+ prophet(
+ df=df,
+ time_grain="P1M",
+ periods=3,
+ confidence_interval=0.9,
+ )
+
+
+def test_prophet_incorrect_confidence_interval():
+ with pytest.raises(InvalidPostProcessingError):
+ prophet(
+ df=prophet_df,
+ time_grain="P1M",
+ periods=3,
+ confidence_interval=0.0,
+ )
+
+ with pytest.raises(InvalidPostProcessingError):
+ prophet(
+ df=prophet_df,
+ time_grain="P1M",
+ periods=3,
+ confidence_interval=1.0,
+ )
+
+
+def test_prophet_incorrect_periods():
+ with pytest.raises(InvalidPostProcessingError):
+ prophet(
+ df=prophet_df,
+ time_grain="P1M",
+ periods=-1,
+ confidence_interval=0.8,
+ )
+
+
+def test_prophet_incorrect_time_grain():
+ with pytest.raises(InvalidPostProcessingError):
+ prophet(
+ df=prophet_df,
+ time_grain="yearly",
+ periods=10,
+ confidence_interval=0.8,
+ )
diff --git a/tests/unit_tests/pandas_postprocessing/test_rename.py b/tests/unit_tests/pandas_postprocessing/test_rename.py
new file mode 100644
index 0000000000000..f49680a352618
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_rename.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pandas as pd
+import pytest
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils import pandas_postprocessing as pp
+from tests.unit_tests.fixtures.dataframes import categories_df
+
+
+def test_rename_should_not_side_effect():
+ _categories_df = categories_df.copy()
+ pp.rename(
+ df=_categories_df,
+ columns={
+ "constant": "constant_newname",
+ "category": "category_namename",
+ },
+ )
+ assert _categories_df.equals(categories_df)
+
+
+def test_rename():
+ new_categories_df = pp.rename(
+ df=categories_df,
+ columns={
+ "constant": "constant_newname",
+ "category": "category_newname",
+ },
+ )
+ assert list(new_categories_df.columns.values) == [
+ "constant_newname",
+ "category_newname",
+ "dept",
+ "name",
+ "asc_idx",
+ "desc_idx",
+ "idx_nulls",
+ ]
+ assert not new_categories_df.equals(categories_df)
+
+
+def test_should_inplace_rename():
+ _categories_df = categories_df.copy()
+ _categories_df_inplaced = pp.rename(
+ df=_categories_df,
+ columns={
+ "constant": "constant_newname",
+ "category": "category_namename",
+ },
+ inplace=True,
+ )
+ assert _categories_df_inplaced.equals(_categories_df)
+
+
+def test_should_rename_on_level():
+ iterables = [["m1", "m2"], ["a", "b"], ["x", "y"]]
+ columns = pd.MultiIndex.from_product(iterables, names=[None, "level1", "level2"])
+ df = pd.DataFrame(index=[0, 1, 2], columns=columns, data=1)
+ """
+ m1 m2
+ level1 a b a b
+ level2 x y x y x y x y
+ 0 1 1 1 1 1 1 1 1
+ 1 1 1 1 1 1 1 1 1
+ 2 1 1 1 1 1 1 1 1
+ """
+ post_df = pp.rename(
+ df=df,
+ columns={"m1": "new_m1"},
+ level=0,
+ )
+ assert post_df.columns.get_level_values(level=0).equals(
+ pd.Index(
+ [
+ "new_m1",
+ "new_m1",
+ "new_m1",
+ "new_m1",
+ "m2",
+ "m2",
+ "m2",
+ "m2",
+ ]
+ )
+ )
+
+
+def test_should_raise_exception_no_column():
+ with pytest.raises(InvalidPostProcessingError):
+ pp.rename(
+ df=categories_df,
+ columns={
+ "foobar": "foobar2",
+ },
+ )
+
+
+def test_should_raise_exception_duplication():
+ with pytest.raises(InvalidPostProcessingError):
+ pp.rename(
+ df=categories_df,
+ columns={
+ "constant": "category",
+ },
+ )
+
+
+def test_should_raise_exception_duplication_on_multiindex():
+ iterables = [["m1", "m2"], ["a", "b"], ["x", "y"]]
+ columns = pd.MultiIndex.from_product(iterables, names=[None, "level1", "level2"])
+ df = pd.DataFrame(index=[0, 1, 2], columns=columns, data=1)
+ """
+ m1 m2
+ level1 a b a b
+ level2 x y x y x y x y
+ 0 1 1 1 1 1 1 1 1
+ 1 1 1 1 1 1 1 1 1
+ 2 1 1 1 1 1 1 1 1
+ """
+
+    with pytest.raises(InvalidPostProcessingError):
+        pp.rename(
+            df=df,
+            columns={
+                "m1": "m2",
+            },
+            level=0,
+        )
+
+    with pytest.raises(InvalidPostProcessingError):
+        pp.rename(
+            df=df,
+            columns={
+                "a": "b",
+            },
+            level=1,
+        )
+
+
+def test_should_raise_exception_invalid_level():
+    with pytest.raises(InvalidPostProcessingError):
+        pp.rename(
+            df=categories_df,
+            columns={
+                "constant": "new_constant",
+            },
+            level=100,
+        )
+
+    with pytest.raises(InvalidPostProcessingError):
+        pp.rename(
+            df=categories_df,
+            columns={
+                "constant": "new_constant",
+            },
+            level="xxxxx",
+        )
+
+
+def test_should_return_df_empty_columns():
+ assert pp.rename(
+ df=categories_df,
+ columns={},
+ ).equals(categories_df)
diff --git a/tests/unit_tests/pandas_postprocessing/test_resample.py b/tests/unit_tests/pandas_postprocessing/test_resample.py
new file mode 100644
index 0000000000000..b1414c5fe8fdc
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_resample.py
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pandas as pd
+import pytest
+from pandas import to_datetime
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils import pandas_postprocessing as pp
+from tests.unit_tests.fixtures.dataframes import categories_df, timeseries_df
+
+
+def test_resample_should_not_side_effect():
+ _timeseries_df = timeseries_df.copy()
+ pp.resample(df=_timeseries_df, rule="1D", method="ffill")
+ assert _timeseries_df.equals(timeseries_df)
+
+
+def test_resample():
+ post_df = pp.resample(df=timeseries_df, rule="1D", method="ffill")
+ """
+ label y
+ 2019-01-01 x 1.0
+ 2019-01-02 y 2.0
+ 2019-01-03 y 2.0
+ 2019-01-04 y 2.0
+ 2019-01-05 z 3.0
+ 2019-01-06 z 3.0
+ 2019-01-07 q 4.0
+ """
+ assert post_df.equals(
+ pd.DataFrame(
+ index=pd.to_datetime(
+ [
+ "2019-01-01",
+ "2019-01-02",
+ "2019-01-03",
+ "2019-01-04",
+ "2019-01-05",
+ "2019-01-06",
+ "2019-01-07",
+ ]
+ ),
+ data={
+ "label": ["x", "y", "y", "y", "z", "z", "q"],
+ "y": [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0],
+ },
+ )
+ )
+
+
+def test_resample_zero_fill():
+ post_df = pp.resample(df=timeseries_df, rule="1D", method="asfreq", fill_value=0)
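+    """
+                label    y
+    2019-01-01      x  1.0
+    2019-01-02      y  2.0
+    2019-01-03      0  0.0
+    2019-01-04      0  0.0
+    2019-01-05      z  3.0
+    2019-01-06      0  0.0
+    2019-01-07      q  4.0
+    """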
+ assert post_df.equals(
+ pd.DataFrame(
+ index=pd.to_datetime(
+ [
+ "2019-01-01",
+ "2019-01-02",
+ "2019-01-03",
+ "2019-01-04",
+ "2019-01-05",
+ "2019-01-06",
+ "2019-01-07",
+ ]
+ ),
+ data={
+ "label": ["x", "y", 0, 0, "z", 0, "q"],
+ "y": [1.0, 2.0, 0, 0, 3.0, 0, 4.0],
+ },
+ )
+ )
+
+
+def test_resample_after_pivot():
+ df = pd.DataFrame(
+ data={
+ "__timestamp": pd.to_datetime(
+ [
+ "2022-01-13",
+ "2022-01-13",
+ "2022-01-13",
+ "2022-01-11",
+ "2022-01-11",
+ "2022-01-11",
+ ]
+ ),
+ "city": ["Chicago", "LA", "NY", "Chicago", "LA", "NY"],
+ "val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
+ }
+ )
+ pivot_df = pp.pivot(
+ df=df,
+ index=["__timestamp"],
+ columns=["city"],
+ aggregates={
+ "val": {"operator": "sum"},
+ },
+ )
+ """
+ val
+ city Chicago LA NY
+ __timestamp
+ 2022-01-11 3.0 2.0 1.0
+ 2022-01-13 6.0 5.0 4.0
+ """
+ resample_df = pp.resample(
+ df=pivot_df,
+ rule="1D",
+ method="asfreq",
+ fill_value=0,
+ )
+ """
+ val
+ city Chicago LA NY
+ __timestamp
+ 2022-01-11 3.0 2.0 1.0
+ 2022-01-12 0.0 0.0 0.0
+ 2022-01-13 6.0 5.0 4.0
+ """
+ flat_df = pp.flatten(resample_df)
+ """
+ __timestamp val, Chicago val, LA val, NY
+ 0 2022-01-11 3.0 2.0 1.0
+ 1 2022-01-12 0.0 0.0 0.0
+ 2 2022-01-13 6.0 5.0 4.0
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "__timestamp": pd.to_datetime(
+ ["2022-01-11", "2022-01-12", "2022-01-13"]
+ ),
+ "val, Chicago": [3.0, 0, 6.0],
+ "val, LA": [2.0, 0, 5.0],
+ "val, NY": [1.0, 0, 4.0],
+ }
+ )
+ )
+
+
+def test_resample_should_raise_ex():
+ with pytest.raises(InvalidPostProcessingError):
+ pp.resample(
+ df=categories_df,
+ rule="1D",
+ method="asfreq",
+ )
+
+ with pytest.raises(InvalidPostProcessingError):
+ pp.resample(
+ df=timeseries_df,
+ rule="1D",
+ method="foobar",
+ )
+
+
+def test_resample_linear():
+ df = pd.DataFrame(
+ index=to_datetime(["2019-01-01", "2019-01-05", "2019-01-08"]),
+ data={"label": ["a", "e", "j"], "y": [1.0, 5.0, 8.0]},
+ )
+ post_df = pp.resample(df=df, rule="1D", method="linear")
+ """
+ label y
+ 2019-01-01 a 1.0
+ 2019-01-02 NaN 2.0
+ 2019-01-03 NaN 3.0
+ 2019-01-04 NaN 4.0
+ 2019-01-05 e 5.0
+ 2019-01-06 NaN 6.0
+ 2019-01-07 NaN 7.0
+ 2019-01-08 j 8.0
+ """
+ assert post_df.equals(
+ pd.DataFrame(
+ index=pd.to_datetime(
+ [
+ "2019-01-01",
+ "2019-01-02",
+ "2019-01-03",
+ "2019-01-04",
+ "2019-01-05",
+ "2019-01-06",
+ "2019-01-07",
+ "2019-01-08",
+ ]
+ ),
+ data={
+ "label": ["a", np.NaN, np.NaN, np.NaN, "e", np.NaN, np.NaN, "j"],
+ "y": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
+ },
+ )
+ )
diff --git a/tests/unit_tests/pandas_postprocessing/test_rolling.py b/tests/unit_tests/pandas_postprocessing/test_rolling.py
new file mode 100644
index 0000000000000..b72a8bee44827
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_rolling.py
@@ -0,0 +1,226 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pandas as pd
+import pytest
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils import pandas_postprocessing as pp
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
+from tests.unit_tests.fixtures.dataframes import (
+ multiple_metrics_df,
+ single_metric_df,
+ timeseries_df,
+)
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
+def test_rolling_should_not_side_effect():
+ _timeseries_df = timeseries_df.copy()
+ pp.rolling(
+ df=timeseries_df,
+ columns={"y": "y"},
+ rolling_type="sum",
+ window=2,
+ min_periods=0,
+ )
+ assert _timeseries_df.equals(timeseries_df)
+
+
+def test_rolling():
+ # sum rolling type
+ post_df = pp.rolling(
+ df=timeseries_df,
+ columns={"y": "y"},
+ rolling_type="sum",
+ window=2,
+ min_periods=0,
+ )
+
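+    # window=2 with min_periods=0 sums each row with its predecessor; the
+    # first row has no predecessor, so it keeps its own value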
+ assert post_df.columns.tolist() == ["label", "y"]
+ assert series_to_list(post_df["y"]) == [1.0, 3.0, 5.0, 7.0]
+
+ # mean rolling type with alias
+ post_df = pp.rolling(
+ df=timeseries_df,
+ rolling_type="mean",
+ columns={"y": "y_mean"},
+ window=10,
+ min_periods=0,
+ )
+ assert post_df.columns.tolist() == ["label", "y", "y_mean"]
+ assert series_to_list(post_df["y_mean"]) == [1.0, 1.5, 2.0, 2.5]
+
+ # count rolling type
+ post_df = pp.rolling(
+ df=timeseries_df,
+ rolling_type="count",
+ columns={"y": "y"},
+ window=10,
+ min_periods=0,
+ )
+ assert post_df.columns.tolist() == ["label", "y"]
+ assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
+
+ # quantile rolling type
+ post_df = pp.rolling(
+ df=timeseries_df,
+ columns={"y": "q1"},
+ rolling_type="quantile",
+ rolling_type_options={"quantile": 0.25},
+ window=10,
+ min_periods=0,
+ )
+ assert post_df.columns.tolist() == ["label", "y", "q1"]
+ assert series_to_list(post_df["q1"]) == [1.0, 1.25, 1.5, 1.75]
+
+ # incorrect rolling type
+ with pytest.raises(InvalidPostProcessingError):
+ pp.rolling(
+ df=timeseries_df,
+ columns={"y": "y"},
+ rolling_type="abc",
+ window=2,
+ )
+
+ # incorrect rolling type options
+ with pytest.raises(InvalidPostProcessingError):
+ pp.rolling(
+ df=timeseries_df,
+ columns={"y": "y"},
+ rolling_type="quantile",
+ rolling_type_options={"abc": 123},
+ window=2,
+ )
+
+
+def test_rolling_should_empty_df():
+ pivot_df = pp.pivot(
+ df=single_metric_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={"sum_metric": {"operator": "sum"}},
+ )
+ rolling_df = pp.rolling(
+ df=pivot_df,
+ rolling_type="sum",
+ window=2,
+ min_periods=2,
+ columns={"sum_metric": "sum_metric"},
+ )
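+    # with min_periods=2 the first rows lack enough observations and are
+    # trimmed from the result, so the two-row input leaves an empty frame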
+ assert rolling_df.empty is True
+
+
+def test_rolling_after_pivot_with_single_metric():
+ pivot_df = pp.pivot(
+ df=single_metric_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={"sum_metric": {"operator": "sum"}},
+ )
+ """
+ sum_metric
+ country UK US
+ dttm
+ 2019-01-01 5 6
+ 2019-01-02 7 8
+ """
+ rolling_df = pp.rolling(
+ df=pivot_df,
+ columns={"sum_metric": "sum_metric"},
+ rolling_type="sum",
+ window=2,
+ min_periods=0,
+ )
+ """
+ sum_metric
+ country UK US
+ dttm
+ 2019-01-01 5.0 6.0
+ 2019-01-02 12.0 14.0
+ """
+ flat_df = pp.flatten(rolling_df)
+ """
+ dttm sum_metric, UK sum_metric, US
+ 0 2019-01-01 5.0 6.0
+ 1 2019-01-02 12.0 14.0
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
+ }
+ )
+ )
+
+
+def test_rolling_after_pivot_with_multiple_metrics():
+ pivot_df = pp.pivot(
+ df=multiple_metrics_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={
+ "sum_metric": {"operator": "sum"},
+ "count_metric": {"operator": "sum"},
+ },
+ )
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1 2 5 6
+ 2019-01-02 3 4 7 8
+ """
+ rolling_df = pp.rolling(
+ df=pivot_df,
+ columns={
+ "count_metric": "count_metric",
+ "sum_metric": "sum_metric",
+ },
+ rolling_type="sum",
+ window=2,
+ min_periods=0,
+ )
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1.0 2.0 5.0 6.0
+ 2019-01-02 4.0 6.0 12.0 14.0
+ """
+ flat_df = pp.flatten(rolling_df)
+ """
+ dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
+ 0 2019-01-01 1.0 2.0 5.0 6.0
+ 1 2019-01-02 4.0 6.0 12.0 14.0
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1.0, 4.0],
+ FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2.0, 6.0],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
+ }
+ )
+ )
diff --git a/tests/unit_tests/pandas_postprocessing/test_select.py b/tests/unit_tests/pandas_postprocessing/test_select.py
new file mode 100644
index 0000000000000..2ba126fc4c739
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_select.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils.pandas_postprocessing.select import select
+from tests.unit_tests.fixtures.dataframes import timeseries_df
+
+
+def test_select():
+ # reorder columns
+ post_df = select(df=timeseries_df, columns=["y", "label"])
+ assert post_df.columns.tolist() == ["y", "label"]
+
+ # one column
+ post_df = select(df=timeseries_df, columns=["label"])
+ assert post_df.columns.tolist() == ["label"]
+
+ # rename and select one column
+ post_df = select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
+ assert post_df.columns.tolist() == ["y1"]
+
+ # rename one and leave one unchanged
+ post_df = select(df=timeseries_df, rename={"y": "y1"})
+ assert post_df.columns.tolist() == ["label", "y1"]
+
+ # drop one column
+ post_df = select(df=timeseries_df, exclude=["label"])
+ assert post_df.columns.tolist() == ["y"]
+
+ # rename and drop one column
+ post_df = select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"])
+ assert post_df.columns.tolist() == ["y1"]
+
+ # invalid columns
+ with pytest.raises(InvalidPostProcessingError):
+ select(df=timeseries_df, columns=["abc"], rename={"abc": "qwerty"})
+
+ # select renamed column by new name
+ with pytest.raises(InvalidPostProcessingError):
+ select(df=timeseries_df, columns=["label_new"], rename={"label": "label_new"})
diff --git a/tests/unit_tests/pandas_postprocessing/test_sort.py b/tests/unit_tests/pandas_postprocessing/test_sort.py
new file mode 100644
index 0000000000000..e19da38efc1ec
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_sort.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+from dateutil.parser import parse
+
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils.pandas_postprocessing import sort
+from tests.unit_tests.fixtures.dataframes import categories_df, timeseries_df
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
+def test_sort():
+ df = sort(df=categories_df, by=["category", "asc_idx"], ascending=[True, False])
+ assert series_to_list(df["asc_idx"])[1] == 96
+
+ df = sort(df=categories_df.set_index("name"), is_sort_index=True)
+ assert df.index[0] == "person0"
+
+ df = sort(df=categories_df.set_index("name"), is_sort_index=True, ascending=False)
+ assert df.index[0] == "person99"
+
+ df = sort(df=categories_df.set_index("name"), by="asc_idx")
+ assert df["asc_idx"][0] == 0
+
+ df = sort(df=categories_df.set_index("name"), by="asc_idx", ascending=False)
+ assert df["asc_idx"][0] == 100
+
+ df = sort(df=timeseries_df, is_sort_index=True)
+ assert df.index[0] == parse("2019-01-01")
+
+ df = sort(df=timeseries_df, is_sort_index=True, ascending=False)
+ assert df.index[0] == parse("2019-01-07")
+
+ df = sort(df=timeseries_df)
+ assert df.equals(timeseries_df)
+
+    with pytest.raises(InvalidPostProcessingError):
+        sort(df=df, by="abc", ascending=False)
+
+    with pytest.raises(InvalidPostProcessingError):
+        sort(df=df, by=["abc", "def"])
diff --git a/tests/unit_tests/pandas_postprocessing/test_utils.py b/tests/unit_tests/pandas_postprocessing/test_utils.py
new file mode 100644
index 0000000000000..058cefcd6c72a
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_utils.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from superset.utils.pandas_postprocessing import escape_separator, unescape_separator
+
+
+def test_escape_separator():
+ assert escape_separator(r" hell \world ") == r" hell \world "
+ assert unescape_separator(r" hell \world ") == r" hell \world "
+
+ escape_string = escape_separator("hello, world")
+ assert escape_string == r"hello\, world"
+ assert unescape_separator(escape_string) == "hello, world"
+
+ escape_string = escape_separator("hello,world")
+ assert escape_string == r"hello\,world"
+ assert unescape_separator(escape_string) == "hello,world"
diff --git a/tests/unit_tests/pandas_postprocessing/utils.py b/tests/unit_tests/pandas_postprocessing/utils.py
new file mode 100644
index 0000000000000..07366b15774d1
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/utils.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import math
+from typing import Any, List, Optional
+
+from pandas import Series
+
+AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
+AGGREGATES_MULTIPLE = {
+ "idx_nulls": {"operator": "sum"},
+ "asc_idx": {"operator": "mean"},
+}
+
+
+def series_to_list(series: Series) -> List[Any]:
+ """
+    Convert a `Series` to a regular list, replacing NaN and infinite values
+    with None.
+
+ :param series: Series to convert
+ :return: list without nan or inf
+ """
+ return [
+ None
+ if not isinstance(val, str) and (math.isnan(val) or math.isinf(val))
+ else val
+ for val in series.tolist()
+ ]
+
+
+def round_floats(
+ floats: List[Optional[float]], precision: int
+) -> List[Optional[float]]:
+ """
+    Round a list of floats to a given precision.
+
+ :param floats: floats to round
+ :param precision: intended decimal precision
+ :return: rounded floats
+ """
+    return [round(val, precision) if val is not None else None for val in floats]
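+
+
+# illustrative examples (values invented for clarity):
+#   series_to_list(Series([1.0, float("nan"), "x"])) -> [1.0, None, "x"]
+#   round_floats([1.234, None], 2) -> [1.23, None]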
diff --git a/tests/unit_tests/result_set_test.py b/tests/unit_tests/result_set_test.py
new file mode 100644
index 0000000000000..331810bb1ed62
--- /dev/null
+++ b/tests/unit_tests/result_set_test.py
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=import-outside-toplevel, unused-argument
+
+
+import numpy as np
+import pandas as pd
+
+from superset.result_set import stringify_values
+
+
+def test_column_names_as_bytes() -> None:
+ """
+ Test that we can handle column names as bytes.
+ """
+ from superset.db_engine_specs.redshift import RedshiftEngineSpec
+ from superset.result_set import SupersetResultSet
+
+ data = (
+ [
+ "2016-01-26",
+ 392.002014,
+ 397.765991,
+ 390.575012,
+ 392.153015,
+ 392.153015,
+ 58147000,
+ ],
+ [
+ "2016-01-27",
+ 392.444,
+ 396.842987,
+ 391.782013,
+ 394.971985,
+ 394.971985,
+ 47424400,
+ ],
+ )
+ description = [
+ (b"date", 1043, None, None, None, None, None),
+ (b"open", 701, None, None, None, None, None),
+ (b"high", 701, None, None, None, None, None),
+ (b"low", 701, None, None, None, None, None),
+ (b"close", 701, None, None, None, None, None),
+ (b"adj close", 701, None, None, None, None, None),
+ (b"volume", 20, None, None, None, None, None),
+ ]
+ result_set = SupersetResultSet(data, description, RedshiftEngineSpec) # type: ignore
+
+ assert (
+ result_set.to_pandas_df().to_markdown()
+ == """
+| | date | open | high | low | close | adj close | volume |
+|---:|:-----------|--------:|--------:|--------:|--------:|------------:|---------:|
+| 0 | 2016-01-26 | 392.002 | 397.766 | 390.575 | 392.153 | 392.153 | 58147000 |
+| 1 | 2016-01-27 | 392.444 | 396.843 | 391.782 | 394.972 | 394.972 | 47424400 |
+ """.strip()
+ )
+
+
+def test_stringify_with_null_integers():
+ """
+ Test that we can safely handle type errors when an integer column has a null value
+ """
+
+ data = [
+ ("foo", "bar", pd.NA, None),
+ ("foo", "bar", pd.NA, True),
+ ("foo", "bar", pd.NA, None),
+ ]
+ numpy_dtype = [
+ ("id", "object"),
+ ("value", "object"),
+ ("num", "object"),
+ ("bool", "object"),
+ ]
+
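+    # a structured array lets each named column be pulled out by name and
+    # stringified independently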
+ array2 = np.array(data, dtype=numpy_dtype)
+ column_names = ["id", "value", "num", "bool"]
+
+ result_set = np.array([stringify_values(array2[column]) for column in column_names])
+
+ expected = np.array(
+ [
+ array(["foo", "foo", "foo"], dtype=object),
+ array(["bar", "bar", "bar"], dtype=object),
+ array([None, None, None], dtype=object),
+ array([None, "True", None], dtype=object),
+ ]
+ )
+
+ assert np.array_equal(result_set, expected)
+
+
+def test_stringify_with_null_timestamps():
+ """
+ Test that we can safely handle type errors when a timestamp column has a null value
+ """
+
+ data = [
+ ("foo", "bar", pd.NaT, None),
+ ("foo", "bar", pd.NaT, True),
+ ("foo", "bar", pd.NaT, None),
+ ]
+ numpy_dtype = [
+ ("id", "object"),
+ ("value", "object"),
+ ("num", "object"),
+ ("bool", "object"),
+ ]
+
+ array2 = np.array(data, dtype=numpy_dtype)
+ column_names = ["id", "value", "num", "bool"]
+
+ result_set = np.array([stringify_values(array2[column]) for column in column_names])
+
+ expected = np.array(
+ [
+ array(["foo", "foo", "foo"], dtype=object),
+ array(["bar", "bar", "bar"], dtype=object),
+ array([None, None, None], dtype=object),
+ array([None, "True", None], dtype=object),
+ ]
+ )
+
+ assert np.array_equal(result_set, expected)
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
new file mode 100644
index 0000000000000..29f45eab682a0
--- /dev/null
+++ b/tests/unit_tests/sql_lab_test.py
@@ -0,0 +1,220 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, too-many-locals
+
+import sqlparse
+from pytest_mock import MockerFixture
+from sqlalchemy.orm.session import Session
+
+from superset.utils.core import override_user
+
+
+def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
+ """
+ Simple test for `execute_sql_statement`.
+ """
+ from superset.sql_lab import execute_sql_statement
+
+ sql_statement = "SELECT 42 AS answer"
+
+ query = mocker.MagicMock()
+ query.limit = 1
+ query.select_as_cta_used = False
+ database = query.database
+ database.allow_dml = False
+ database.apply_limit_to_sql.return_value = "SELECT 42 AS answer LIMIT 2"
+ db_engine_spec = database.db_engine_spec
+ db_engine_spec.is_select_query.return_value = True
+ db_engine_spec.fetch_data.return_value = [(42,)]
+
+ session = mocker.MagicMock()
+ cursor = mocker.MagicMock()
+ SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")
+
+ execute_sql_statement(
+ sql_statement,
+ query,
+ session=session,
+ cursor=cursor,
+ log_params={},
+ apply_ctas=False,
+ )
+
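+    # Superset requests one extra row (limit + 1, here 2) so it can detect
+    # whether the result set was truncated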
+ database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
+ db_engine_spec.execute.assert_called_with(
+ cursor, "SELECT 42 AS answer LIMIT 2", async_=True
+ )
+ SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
+
+
+def test_execute_sql_statement_with_rls(
+ mocker: MockerFixture,
+) -> None:
+ """
+ Test for `execute_sql_statement` when an RLS rule is in place.
+ """
+ from superset.sql_lab import execute_sql_statement
+
+ sql_statement = "SELECT * FROM sales"
+
+ query = mocker.MagicMock()
+ query.limit = 100
+ query.select_as_cta_used = False
+ database = query.database
+ database.allow_dml = False
+ database.apply_limit_to_sql.return_value = (
+ "SELECT * FROM sales WHERE organization_id=42 LIMIT 101"
+ )
+ db_engine_spec = database.db_engine_spec
+ db_engine_spec.is_select_query.return_value = True
+ db_engine_spec.fetch_data.return_value = [(42,)]
+
+ session = mocker.MagicMock()
+ cursor = mocker.MagicMock()
+ SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")
+ mocker.patch(
+ "superset.sql_lab.insert_rls",
+ return_value=sqlparse.parse("SELECT * FROM sales WHERE organization_id=42")[0],
+ )
+ mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
+
+ execute_sql_statement(
+ sql_statement,
+ query,
+ session=session,
+ cursor=cursor,
+ log_params={},
+ apply_ctas=False,
+ )
+
+ database.apply_limit_to_sql.assert_called_with(
+ "SELECT * FROM sales WHERE organization_id=42",
+ 101,
+ force=True,
+ )
+ db_engine_spec.execute.assert_called_with(
+ cursor,
+ "SELECT * FROM sales WHERE organization_id=42 LIMIT 101",
+ async_=True,
+ )
+ SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
+
+
+def test_sql_lab_insert_rls(
+ mocker: MockerFixture,
+ session: Session,
+) -> None:
+ """
+ Integration test for `insert_rls`.
+ """
+ from flask_appbuilder.security.sqla.models import Role, User
+
+ from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
+ from superset.models.core import Database
+ from superset.models.sql_lab import Query
+ from superset.security.manager import SupersetSecurityManager
+ from superset.sql_lab import execute_sql_statement
+ from superset.utils.core import RowLevelSecurityFilterType
+
+ engine = session.connection().engine
+ Query.metadata.create_all(engine) # pylint: disable=no-member
+
+ connection = engine.raw_connection()
+ connection.execute("CREATE TABLE t (c INTEGER)")
+ for i in range(10):
+ connection.execute("INSERT INTO t VALUES (?)", (i,))
+
+ cursor = connection.cursor()
+
+ query = Query(
+ sql="SELECT c FROM t",
+ client_id="abcde",
+ database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"),
+ schema=None,
+ limit=5,
+ select_as_cta_used=False,
+ )
+ session.add(query)
+ session.commit()
+
+ admin = User(
+ first_name="Alice",
+ last_name="Doe",
+ email="adoe@example.org",
+ username="admin",
+ roles=[Role(name="Admin")],
+ )
+
+ # first without RLS
+ with override_user(admin):
+ superset_result_set = execute_sql_statement(
+ sql_statement=query.sql,
+ query=query,
+ session=session,
+ cursor=cursor,
+ log_params=None,
+ apply_ctas=False,
+ )
+ assert (
+ superset_result_set.to_pandas_df().to_markdown()
+ == """
+| | c |
+|---:|----:|
+| 0 | 0 |
+| 1 | 1 |
+| 2 | 2 |
+| 3 | 3 |
+| 4 | 4 |""".strip()
+ )
+ assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"
+
+ # now with RLS
+ rls = RowLevelSecurityFilter(
+ name="sqllab_rls1",
+ filter_type=RowLevelSecurityFilterType.REGULAR,
+ tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
+ roles=[admin.roles[0]],
+ group_key=None,
+ clause="c > 5",
+ )
+ session.add(rls)
+ session.flush()
+ mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
+ mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
+
+ with override_user(admin):
+ superset_result_set = execute_sql_statement(
+ sql_statement=query.sql,
+ query=query,
+ session=session,
+ cursor=cursor,
+ log_params=None,
+ apply_ctas=False,
+ )
+ assert (
+ superset_result_set.to_pandas_df().to_markdown()
+ == """
+| | c |
+|---:|----:|
+| 0 | 6 |
+| 1 | 7 |
+| 2 | 8 |
+| 3 | 9 |""".strip()
+ )
+ assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
new file mode 100644
index 0000000000000..ba3da69aaefaf
--- /dev/null
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -0,0 +1,1512 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
+
+import unittest
+from typing import Optional, Set
+
+import pytest
+import sqlparse
+from pytest_mock import MockerFixture
+from sqlalchemy import text
+from sqlparse.sql import Identifier, Token, TokenList
+from sqlparse.tokens import Name
+
+from superset.exceptions import QueryClauseValidationException
+from superset.sql_parse import (
+ add_table_name,
+ extract_table_references,
+ get_rls_for_table,
+ has_table_query,
+ insert_rls,
+ ParsedQuery,
+ sanitize_clause,
+ strip_comments_from_sql,
+ Table,
+)
+
+
+def extract_tables(query: str) -> Set[Table]:
+ """
+ Helper function to extract tables referenced in a query.
+ """
+ return ParsedQuery(query).tables
+
+
+def test_table() -> None:
+ """
+ Test the ``Table`` class and its string conversion.
+
+ Special characters in the table, schema, or catalog name should be escaped correctly.
+ """
+ assert str(Table("tbname")) == "tbname"
+ assert str(Table("tbname", "schemaname")) == "schemaname.tbname"
+ assert (
+ str(Table("tbname", "schemaname", "catalogname"))
+ == "catalogname.schemaname.tbname"
+ )
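+    # '.', '/' and newline are percent-encoded (%2E, %2F, %0A) so the
+    # fully-qualified name stays unambiguous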
+ assert (
+ str(Table("table.name", "schema/name", "catalog\nname"))
+ == "catalog%0Aname.schema%2Fname.table%2Ename"
+ )
+
+
+def test_extract_tables() -> None:
+ """
+ Test that referenced tables are parsed correctly from the SQL.
+ """
+ assert extract_tables("SELECT * FROM tbname") == {Table("tbname")}
+ assert extract_tables("SELECT * FROM tbname foo") == {Table("tbname")}
+ assert extract_tables("SELECT * FROM tbname AS foo") == {Table("tbname")}
+
+ # underscore
+ assert extract_tables("SELECT * FROM tb_name") == {Table("tb_name")}
+
+ # quotes
+ assert extract_tables('SELECT * FROM "tbname"') == {Table("tbname")}
+
+ # unicode
+ assert extract_tables('SELECT * FROM "tb_name" WHERE city = "Lübeck"') == {
+ Table("tb_name")
+ }
+
+ # columns
+ assert extract_tables("SELECT field1, field2 FROM tb_name") == {Table("tb_name")}
+ assert extract_tables("SELECT t1.f1, t2.f2 FROM t1, t2") == {
+ Table("t1"),
+ Table("t2"),
+ }
+
+ # named table
+ assert extract_tables("SELECT a.date, a.field FROM left_table a LIMIT 10") == {
+ Table("left_table")
+ }
+
+ # reverse select
+ assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
+
+
+def test_extract_tables_subselect() -> None:
+ """
+ Test that tables inside subselects are parsed correctly.
+ """
+ assert (
+ extract_tables(
+ """
+SELECT sub.*
+FROM (
+ SELECT *
+ FROM s1.t1
+ WHERE day_of_week = 'Friday'
+ ) sub, s2.t2
+WHERE sub.resolution = 'NONE'
+"""
+ )
+ == {Table("t1", "s1"), Table("t2", "s2")}
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT sub.*
+FROM (
+ SELECT *
+ FROM s1.t1
+ WHERE day_of_week = 'Friday'
+) sub
+WHERE sub.resolution = 'NONE'
+"""
+ )
+ == {Table("t1", "s1")}
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT * FROM t1
+WHERE s11 > ANY (
+ SELECT COUNT(*) /* no hint */ FROM t2
+ WHERE NOT EXISTS (
+ SELECT * FROM t3
+ WHERE ROW(5*t2.s1,77)=(
+ SELECT 50,11*s1 FROM t4
+ )
+ )
+)
+"""
+ )
+ == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
+ )
+
+
+def test_extract_tables_select_in_expression() -> None:
+ """
+ Test that parser works with ``SELECT``s used as expressions.
+ """
+ assert extract_tables("SELECT f1, (SELECT count(1) FROM t2) FROM t1") == {
+ Table("t1"),
+ Table("t2"),
+ }
+ assert extract_tables("SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1") == {
+ Table("t1"),
+ Table("t2"),
+ }
+
+
+def test_extract_tables_parenthesis() -> None:
+ """
+ Test that parenthesis are parsed correctly.
+ """
+ assert extract_tables("SELECT f1, (x + y) AS f2 FROM t1") == {Table("t1")}
+
+
+def test_extract_tables_with_schema() -> None:
+ """
+ Test that schemas are parsed correctly.
+ """
+ assert extract_tables("SELECT * FROM schemaname.tbname") == {
+ Table("tbname", "schemaname")
+ }
+ assert extract_tables('SELECT * FROM "schemaname"."tbname"') == {
+ Table("tbname", "schemaname")
+ }
+ assert extract_tables('SELECT * FROM "schemaname"."tbname" foo') == {
+ Table("tbname", "schemaname")
+ }
+ assert extract_tables('SELECT * FROM "schemaname"."tbname" AS foo') == {
+ Table("tbname", "schemaname")
+ }
+
+
+def test_extract_tables_union() -> None:
+ """
+ Test that ``UNION`` queries work as expected.
+ """
+ assert extract_tables("SELECT * FROM t1 UNION SELECT * FROM t2") == {
+ Table("t1"),
+ Table("t2"),
+ }
+ assert extract_tables("SELECT * FROM t1 UNION ALL SELECT * FROM t2") == {
+ Table("t1"),
+ Table("t2"),
+ }
+ assert extract_tables("SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2") == {
+ Table("t1"),
+ Table("t2"),
+ }
+
+
+def test_extract_tables_select_from_values() -> None:
+ """
+ Test that selecting from values returns no tables.
+ """
+ assert extract_tables("SELECT * FROM VALUES (13, 42)") == set()
+
+
+def test_extract_tables_select_array() -> None:
+ """
+ Test that queries selecting arrays work as expected.
+ """
+ assert (
+ extract_tables(
+ """
+SELECT ARRAY[1, 2, 3] AS my_array
+FROM t1 LIMIT 10
+"""
+ )
+ == {Table("t1")}
+ )
+
+
+def test_extract_tables_select_if() -> None:
+ """
+ Test that queries with an ``IF`` work as expected.
+ """
+ assert (
+ extract_tables(
+ """
+SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
+FROM t1 LIMIT 10
+"""
+ )
+ == {Table("t1")}
+ )
+
+
+def test_extract_tables_with_catalog() -> None:
+ """
+ Test that catalogs are parsed correctly.
+ """
+ assert extract_tables("SELECT * FROM catalogname.schemaname.tbname") == {
+ Table("tbname", "schemaname", "catalogname")
+ }
+
+
+def test_extract_tables_illdefined() -> None:
+ """
+ Test that ill-defined tables return an empty set.
+ """
+ assert extract_tables("SELECT * FROM schemaname.") == set()
+ assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
+ assert extract_tables("SELECT * FROM catalogname..") == set()
+ assert extract_tables("SELECT * FROM catalogname..tbname") == set()
+
+
+def test_extract_tables_show_tables_from() -> None:
+ """
+ Test ``SHOW TABLES FROM``.
+ """
+ assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
+
+
+def test_extract_tables_show_columns_from() -> None:
+ """
+ Test ``SHOW COLUMNS FROM``.
+ """
+ assert extract_tables("SHOW COLUMNS FROM t1") == {Table("t1")}
+
+
+def test_extract_tables_where_subquery() -> None:
+ """
+ Test that tables in a ``WHERE`` subquery are parsed correctly.
+ """
+ assert (
+ extract_tables(
+ """
+SELECT name
+FROM t1
+WHERE regionkey = (SELECT max(regionkey) FROM t2)
+"""
+ )
+ == {Table("t1"), Table("t2")}
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT name
+FROM t1
+WHERE regionkey IN (SELECT regionkey FROM t2)
+"""
+ )
+ == {Table("t1"), Table("t2")}
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT name
+FROM t1
+WHERE regionkey EXISTS (SELECT regionkey FROM t2)
+"""
+ )
+ == {Table("t1"), Table("t2")}
+ )
+
+
+def test_extract_tables_describe() -> None:
+ """
+ Test ``DESCRIBE``.
+ """
+ assert extract_tables("DESCRIBE t1") == {Table("t1")}
+
+
+def test_extract_tables_show_partitions() -> None:
+ """
+ Test ``SHOW PARTITIONS``.
+ """
+ assert (
+ extract_tables(
+ """
+SHOW PARTITIONS FROM orders
+WHERE ds >= '2013-01-01' ORDER BY ds DESC
+"""
+ )
+ == {Table("orders")}
+ )
+
+
+def test_extract_tables_join() -> None:
+ """
+ Test joins.
+ """
+ assert extract_tables("SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;") == {
+ Table("t1"),
+ Table("t2"),
+ }
+
+ assert (
+ extract_tables(
+ """
+SELECT a.date, b.name
+FROM left_table a
+JOIN (
+ SELECT
+ CAST((b.year) as VARCHAR) date,
+ name
+ FROM right_table
+) b
+ON a.date = b.date
+"""
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT a.date, b.name
+FROM left_table a
+LEFT INNER JOIN (
+ SELECT
+ CAST((b.year) as VARCHAR) date,
+ name
+ FROM right_table
+) b
+ON a.date = b.date
+"""
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT a.date, b.name
+FROM left_table a
+RIGHT OUTER JOIN (
+ SELECT
+ CAST((b.year) as VARCHAR) date,
+ name
+ FROM right_table
+) b
+ON a.date = b.date
+"""
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT a.date, b.name
+FROM left_table a
+FULL OUTER JOIN (
+ SELECT
+ CAST((b.year) as VARCHAR) date,
+ name
+ FROM right_table
+) b
+ON a.date = b.date
+"""
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
+
+
+def test_extract_tables_semi_join() -> None:
+ """
+ Test ``LEFT SEMI JOIN``.
+ """
+ assert (
+ extract_tables(
+ """
+SELECT a.date, b.name
+FROM left_table a
+LEFT SEMI JOIN (
+ SELECT
+ CAST((b.year) as VARCHAR) date,
+ name
+ FROM right_table
+) b
+ON a.data = b.date
+"""
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
+
+
+def test_extract_tables_combinations() -> None:
+ """
+ Test a complex case with nested queries.
+ """
+ assert (
+ extract_tables(
+ """
+SELECT * FROM t1
+WHERE s11 > ANY (
+ SELECT * FROM t1 UNION ALL SELECT * FROM (
+ SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a
+ ) tmp_join
+ WHERE NOT EXISTS (
+ SELECT * FROM t3
+ WHERE ROW(5*t3.s1,77)=(
+ SELECT 50,11*s1 FROM t4
+ )
+ )
+)
+"""
+ )
+ == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT * FROM (
+ SELECT * FROM (
+ SELECT * FROM (
+ SELECT * FROM EmployeeS
+ ) AS S1
+ ) AS S2
+) AS S3
+"""
+ )
+ == {Table("EmployeeS")}
+ )
+
+
+def test_extract_tables_with() -> None:
+ """
+ Test ``WITH``.
+ """
+ assert (
+ extract_tables(
+ """
+WITH
+ x AS (SELECT a FROM t1),
+ y AS (SELECT a AS b FROM t2),
+ z AS (SELECT b AS c FROM t3)
+SELECT c FROM z
+"""
+ )
+ == {Table("t1"), Table("t2"), Table("t3")}
+ )
+
+ assert (
+ extract_tables(
+ """
+WITH
+ x AS (SELECT a FROM t1),
+ y AS (SELECT a AS b FROM x),
+ z AS (SELECT b AS c FROM y)
+SELECT c FROM z
+"""
+ )
+ == {Table("t1")}
+ )
+
+
+def test_extract_tables_reusing_aliases() -> None:
+ """
+ Test that the parser follows aliases.
+ """
+ assert (
+ extract_tables(
+ """
+with q1 as ( select key from q2 where key = '5'),
+q2 as ( select key from src where key = '5')
+select * from (select key from q1) a
+"""
+ )
+ == {Table("src")}
+ )
+
+
+def test_extract_tables_multistatement() -> None:
+ """
+ Test that the parser works with multiple statements.
+ """
+ assert extract_tables("SELECT * FROM t1; SELECT * FROM t2") == {
+ Table("t1"),
+ Table("t2"),
+ }
+ assert extract_tables("SELECT * FROM t1; SELECT * FROM t2;") == {
+ Table("t1"),
+ Table("t2"),
+ }
+
+
+def test_extract_tables_complex() -> None:
+ """
+ Test a few complex queries.
+ """
+ assert (
+ extract_tables(
+ """
+SELECT sum(m_examples) AS "sum__m_example"
+FROM (
+ SELECT
+ COUNT(DISTINCT id_userid) AS m_examples,
+ some_more_info
+ FROM my_b_table b
+ JOIN my_t_table t ON b.ds=t.ds
+ JOIN my_l_table l ON b.uid=l.uid
+ WHERE
+ b.rid IN (
+ SELECT other_col
+ FROM inner_table
+ )
+ AND l.bla IN ('x', 'y')
+ GROUP BY 2
+ ORDER BY 2 ASC
+) AS "meh"
+ORDER BY "sum__m_example" DESC
+LIMIT 10;
+"""
+ )
+ == {
+ Table("my_l_table"),
+ Table("my_b_table"),
+ Table("my_t_table"),
+ Table("inner_table"),
+ }
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT *
+FROM table_a AS a, table_b AS b, table_c as c
+WHERE a.id = b.id and b.id = c.id
+"""
+ )
+ == {Table("table_a"), Table("table_b"), Table("table_c")}
+ )
+
+ assert (
+ extract_tables(
+ """
+SELECT somecol AS somecol
+FROM (
+ WITH bla AS (
+ SELECT col_a
+ FROM a
+ WHERE
+ 1=1
+ AND column_of_choice NOT IN (
+ SELECT interesting_col
+ FROM b
+ )
+ ),
+ rb AS (
+ SELECT yet_another_column
+ FROM (
+ SELECT a
+ FROM c
+ GROUP BY the_other_col
+ ) not_table
+ LEFT JOIN bla foo
+ ON foo.prop = not_table.bad_col0
+ WHERE 1=1
+ GROUP BY
+ not_table.bad_col1 ,
+ not_table.bad_col2 ,
+ ORDER BY not_table.bad_col_3 DESC ,
+ not_table.bad_col4 ,
+ not_table.bad_col5
+ )
+ SELECT random_col
+ FROM d
+ WHERE 1=1
+ UNION ALL SELECT even_more_cols
+ FROM e
+ WHERE 1=1
+ UNION ALL SELECT lets_go_deeper
+ FROM f
+ WHERE 1=1
+ WHERE 2=2
+ GROUP BY last_col
+ LIMIT 50000
+)
+"""
+ )
+ == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
+ )
+
+
+def test_extract_tables_mixed_from_clause() -> None:
+ """
+ Test that the parser handles a ``FROM`` clause with table and subselect.
+ """
+ assert (
+ extract_tables(
+ """
+SELECT *
+FROM table_a AS a, (select * from table_b) AS b, table_c as c
+WHERE a.id = b.id and b.id = c.id
+"""
+ )
+ == {Table("table_a"), Table("table_b"), Table("table_c")}
+ )
+
+
+def test_extract_tables_nested_select() -> None:
+ """
+ Test that the parser handles selects inside functions.
+ """
+ assert (
+ extract_tables(
+ """
+select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
+from INFORMATION_SCHEMA.COLUMNS
+WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
+"""
+ )
+ == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+ )
+
+ assert (
+ extract_tables(
+ """
+select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
+from INFORMATION_SCHEMA.COLUMNS
+WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
+"""
+ )
+ == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+ )
+
+
+def test_extract_tables_complex_cte_with_prefix() -> None:
+ """
+ Test that the parser handles CTEs with prefixes.
+ """
+ assert (
+ extract_tables(
+ """
+WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
+AS (
+ SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
+ FROM SalesOrderHeader
+ WHERE SalesPersonID IS NOT NULL
+)
+SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
+FROM CTE__test
+GROUP BY SalesYear, SalesPersonID
+ORDER BY SalesPersonID, SalesYear;
+"""
+ )
+ == {Table("SalesOrderHeader")}
+ )
+
+
+def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
+ """
+ Test that aliases that are keywords are parsed correctly.
+ """
+ assert (
+ extract_tables(
+ """
+WITH
+ f AS (SELECT * FROM foo),
+ match AS (SELECT * FROM f)
+SELECT * FROM match
+"""
+ )
+ == {Table("foo")}
+ )
+
+
+def test_update() -> None:
+ """
+ Test that ``UPDATE`` is not detected as ``SELECT``.
+ """
+ assert ParsedQuery("UPDATE t1 SET col1 = NULL").is_select() is False
+
+
+def test_set() -> None:
+ """
+ Test that ``SET`` is detected correctly.
+ """
+ query = ParsedQuery(
+ """
+-- comment
+SET hivevar:desc='Legislators';
+"""
+ )
+ assert query.is_set() is True
+ assert query.is_select() is False
+
+ assert ParsedQuery("set hivevar:desc='bla'").is_set() is True
+ assert ParsedQuery("SELECT 1").is_set() is False
+
+
+def test_show() -> None:
+ """
+ Test that ``SHOW`` is detected correctly.
+ """
+ query = ParsedQuery(
+ """
+-- comment
+SHOW LOCKS test EXTENDED;
+-- comment
+"""
+ )
+ assert query.is_show() is True
+ assert query.is_select() is False
+
+ assert ParsedQuery("SHOW TABLES").is_show() is True
+ assert ParsedQuery("shOw TABLES").is_show() is True
+ assert ParsedQuery("show TABLES").is_show() is True
+ assert ParsedQuery("SELECT 1").is_show() is False
+
+
+def test_is_explain() -> None:
+ """
+ Test that ``EXPLAIN`` is detected correctly.
+ """
+ assert ParsedQuery("EXPLAIN SELECT 1").is_explain() is True
+ assert ParsedQuery("EXPLAIN SELECT 1").is_select() is False
+
+ assert (
+ ParsedQuery(
+ """
+-- comment
+EXPLAIN select * from table
+-- comment 2
+"""
+ ).is_explain()
+ is True
+ )
+
+ assert (
+ ParsedQuery(
+ """
+-- comment
+EXPLAIN select * from table
+where col1 = 'something'
+-- comment 2
+
+-- comment 3
+EXPLAIN select * from table
+where col1 = 'something'
+-- comment 4
+"""
+ ).is_explain()
+ is True
+ )
+
+ assert (
+ ParsedQuery(
+ """
+-- This is a comment
+ -- this is another comment but with a space in the front
+EXPLAIN SELECT * FROM TABLE
+"""
+ ).is_explain()
+ is True
+ )
+
+ assert (
+ ParsedQuery(
+ """
+/* This is a comment
+ with stars instead */
+EXPLAIN SELECT * FROM TABLE
+"""
+ ).is_explain()
+ is True
+ )
+
+ assert (
+ ParsedQuery(
+ """
+-- comment
+select * from table
+where col1 = 'something'
+-- comment 2
+"""
+ ).is_explain()
+ is False
+ )
+
+
+def test_is_valid_ctas() -> None:
+ """
+ Test if a query is a valid CTAS.
+
+ A valid CTAS has a ``SELECT`` as its last statement.
+ """
+ assert (
+ ParsedQuery("SELECT * FROM table", strip_comments=True).is_valid_ctas() is True
+ )
+
+ assert (
+ ParsedQuery(
+ """
+-- comment
+SELECT * FROM table
+-- comment 2
+""",
+ strip_comments=True,
+ ).is_valid_ctas()
+ is True
+ )
+
+ assert (
+ ParsedQuery(
+ """
+-- comment
+SET @value = 42;
+SELECT @value as foo;
+-- comment 2
+""",
+ strip_comments=True,
+ ).is_valid_ctas()
+ is True
+ )
+
+ assert (
+ ParsedQuery(
+ """
+-- comment
+EXPLAIN SELECT * FROM table
+-- comment 2
+""",
+ strip_comments=True,
+ ).is_valid_ctas()
+ is False
+ )
+
+ assert (
+ ParsedQuery(
+ """
+SELECT * FROM table;
+INSERT INTO TABLE (foo) VALUES (42);
+""",
+ strip_comments=True,
+ ).is_valid_ctas()
+ is False
+ )
+
+
+def test_is_valid_cvas() -> None:
+ """
+ Test if a query is a valid CVAS.
+
+ A valid CVAS has a single ``SELECT`` statement.
+ """
+ assert (
+ ParsedQuery("SELECT * FROM table", strip_comments=True).is_valid_cvas() is True
+ )
+
+ assert (
+ ParsedQuery(
+ """
+-- comment
+SELECT * FROM table
+-- comment 2
+""",
+ strip_comments=True,
+ ).is_valid_cvas()
+ is True
+ )
+
+ assert (
+ ParsedQuery(
+ """
+-- comment
+SET @value = 42;
+SELECT @value as foo;
+-- comment 2
+""",
+ strip_comments=True,
+ ).is_valid_cvas()
+ is False
+ )
+
+ assert (
+ ParsedQuery(
+ """
+-- comment
+EXPLAIN SELECT * FROM table
+-- comment 2
+""",
+ strip_comments=True,
+ ).is_valid_cvas()
+ is False
+ )
+
+ assert (
+ ParsedQuery(
+ """
+SELECT * FROM table;
+INSERT INTO TABLE (foo) VALUES (42);
+""",
+ strip_comments=True,
+ ).is_valid_cvas()
+ is False
+ )
+
+
+def test_is_select_cte_with_comments() -> None:
+ """
+    Some CTEs with comments are not correctly identified as SELECT statements.
+ """
+ sql = ParsedQuery(
+ """WITH blah AS
+ (SELECT * FROM core_dev.manager_team),
+
+blah2 AS
+ (SELECT * FROM core_dev.manager_workspace)
+
+SELECT * FROM blah
+INNER JOIN blah2 ON blah2.team_id = blah.team_id"""
+ )
+ assert sql.is_select()
+
+ sql = ParsedQuery(
+ """WITH blah AS
+/*blahblahbalh*/
+ (SELECT * FROM core_dev.manager_team),
+--blahblahbalh
+
+blah2 AS
+ (SELECT * FROM core_dev.manager_workspace)
+
+SELECT * FROM blah
+INNER JOIN blah2 ON blah2.team_id = blah.team_id"""
+ )
+ assert sql.is_select()
+
+
+def test_cte_is_select() -> None:
+ """
+    Some CTEs are not correctly identified as SELECT statements.
+ """
+ # `AS(` gets parsed as a function
+ sql = ParsedQuery(
+ """WITH foo AS(
+SELECT
+ FLOOR(__time TO WEEK) AS "week",
+ name,
+ COUNT(DISTINCT user_id) AS "unique_users"
+FROM "druid"."my_table"
+GROUP BY 1,2
+)
+SELECT
+ f.week,
+ f.name,
+ f.unique_users
+FROM foo f"""
+ )
+ assert sql.is_select()
+
+
+def test_unknown_select() -> None:
+ """
+    Test ``is_select`` on CTE queries whose type sqlparse has historically
+    struggled to identify.
+ """
+ sql = "WITH foo AS(SELECT 1) SELECT 1"
+ assert sqlparse.parse(sql)[0].get_type() == "SELECT"
+ assert ParsedQuery(sql).is_select()
+
+ sql = "WITH foo AS(SELECT 1) INSERT INTO my_table (a) VALUES (1)"
+ assert sqlparse.parse(sql)[0].get_type() == "INSERT"
+ assert not ParsedQuery(sql).is_select()
+
+ sql = "WITH foo AS(SELECT 1) DELETE FROM my_table"
+ assert sqlparse.parse(sql)[0].get_type() == "DELETE"
+ assert not ParsedQuery(sql).is_select()
+
+
+def test_get_query_with_new_limit_comment() -> None:
+ """
+ Test that limit is applied correctly.
+ """
+ query = ParsedQuery("SELECT * FROM birth_names -- SOME COMMENT")
+ assert query.set_or_update_query_limit(1000) == (
+ "SELECT * FROM birth_names -- SOME COMMENT\nLIMIT 1000"
+ )
+
+
+def test_get_query_with_new_limit_comment_with_limit() -> None:
+ """
+ Test that limits in comments are ignored.
+ """
+ query = ParsedQuery("SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555")
+ assert query.set_or_update_query_limit(1000) == (
+ "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555\nLIMIT 1000"
+ )
+
+
+def test_get_query_with_new_limit_lower() -> None:
+ """
+ Test that lower limits are not replaced.
+ """
+ query = ParsedQuery("SELECT * FROM birth_names LIMIT 555")
+ assert query.set_or_update_query_limit(1000) == (
+ "SELECT * FROM birth_names LIMIT 555"
+ )
+
+
+def test_get_query_with_new_limit_upper() -> None:
+ """
+ Test that higher limits are replaced.
+ """
+ query = ParsedQuery("SELECT * FROM birth_names LIMIT 2000")
+ assert query.set_or_update_query_limit(1000) == (
+ "SELECT * FROM birth_names LIMIT 1000"
+ )
+
+
+def test_basic_breakdown_statements() -> None:
+ """
+ Test that multiple statements are parsed correctly.
+ """
+ query = ParsedQuery(
+ """
+SELECT * FROM birth_names;
+SELECT * FROM birth_names LIMIT 1;
+"""
+ )
+ assert query.get_statements() == [
+ "SELECT * FROM birth_names",
+ "SELECT * FROM birth_names LIMIT 1",
+ ]
+
+
+def test_messy_breakdown_statements() -> None:
+ """
+    Test that messy multiple statements are parsed correctly.
+ """
+ query = ParsedQuery(
+ """
+SELECT 1;\t\n\n\n \t
+\t\nSELECT 2;
+SELECT * FROM birth_names;;;
+SELECT * FROM birth_names LIMIT 1
+"""
+ )
+ assert query.get_statements() == [
+ "SELECT 1",
+ "SELECT 2",
+ "SELECT * FROM birth_names",
+ "SELECT * FROM birth_names LIMIT 1",
+ ]
+
+
+def test_sqlparse_formatting() -> None:
+ """
+ Test that ``from_unixtime`` is formatted correctly.
+ """
+ assert sqlparse.format(
+ "SELECT extract(HOUR from from_unixtime(hour_ts) "
+ "AT TIME ZONE 'America/Los_Angeles') from table",
+ reindent=True,
+ ) == (
+ "SELECT extract(HOUR\n from from_unixtime(hour_ts) "
+ "AT TIME ZONE 'America/Los_Angeles')\nfrom table"
+ )
+
+
+def test_strip_comments_from_sql() -> None:
+ """
+ Test that comments are stripped out correctly.
+ """
+ assert (
+ strip_comments_from_sql("SELECT col1, col2 FROM table1")
+ == "SELECT col1, col2 FROM table1"
+ )
+ assert (
+ strip_comments_from_sql("SELECT col1, col2 FROM table1\n-- comment")
+ == "SELECT col1, col2 FROM table1\n"
+ )
+ assert (
+ strip_comments_from_sql("SELECT '--abc' as abc, col2 FROM table1\n")
+ == "SELECT '--abc' as abc, col2 FROM table1"
+ )
+
+
+def test_sanitize_clause_valid() -> None:
+ # regular clauses
+ assert sanitize_clause("col = 1") == "col = 1"
+ assert sanitize_clause("1=\t\n1") == "1=\t\n1"
+ assert sanitize_clause("(col = 1)") == "(col = 1)"
+ assert sanitize_clause("(col1 = 1) AND (col2 = 2)") == "(col1 = 1) AND (col2 = 2)"
+ assert sanitize_clause("col = 'abc' -- comment") == "col = 'abc' -- comment\n"
+
+    # Valid literal values that could be flagged as invalid by a naive query parser
+ assert (
+ sanitize_clause("col = 'col1 = 1) AND (col2 = 2'")
+ == "col = 'col1 = 1) AND (col2 = 2'"
+ )
+ assert sanitize_clause("col = 'select 1; select 2'") == "col = 'select 1; select 2'"
+ assert sanitize_clause("col = 'abc -- comment'") == "col = 'abc -- comment'"
+
+
+def test_sanitize_clause_closing_unclosed() -> None:
+ with pytest.raises(QueryClauseValidationException):
+ sanitize_clause("col1 = 1) AND (col2 = 2)")
+
+
+def test_sanitize_clause_unclosed() -> None:
+ with pytest.raises(QueryClauseValidationException):
+ sanitize_clause("(col1 = 1) AND (col2 = 2")
+
+
+def test_sanitize_clause_closing_and_unclosed() -> None:
+ with pytest.raises(QueryClauseValidationException):
+ sanitize_clause("col1 = 1) AND (col2 = 2")
+
+
+def test_sanitize_clause_closing_and_unclosed_nested() -> None:
+ with pytest.raises(QueryClauseValidationException):
+ sanitize_clause("(col1 = 1)) AND ((col2 = 2)")
+
+
+def test_sanitize_clause_multiple() -> None:
+ with pytest.raises(QueryClauseValidationException):
+ sanitize_clause("TRUE; SELECT 1")
+
+
+def test_sqlparse_issue_652() -> None:
+ stmt = sqlparse.parse(r"foo = '\' AND bar = 'baz'")[0]
+ assert len(stmt.tokens) == 5
+ assert str(stmt.tokens[0]) == "foo = '\\'"
+
+
+@pytest.mark.parametrize(
+ "sql,expected",
+ [
+ ("SELECT * FROM table", True),
+ ("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
+ ("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
+ ("COUNT(*)", False),
+ ("SELECT a FROM (SELECT 1 AS a)", False),
+ ("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
+ ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
+ ("SELECT * FROM other_table", True),
+ ("extract(HOUR from from_unixtime(hour_ts)", False),
+ ("(SELECT * FROM table)", True),
+ ("(SELECT COUNT(DISTINCT name) from birth_names)", True),
+ ],
+)
+def test_has_table_query(sql: str, expected: bool) -> None:
+ """
+ Test if a given statement queries a table.
+
+ This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing
+ row-level security.
+ """
+ statement = sqlparse.parse(sql)[0]
+ assert has_table_query(statement) == expected
+
+
+@pytest.mark.parametrize(
+ "sql,table,rls,expected",
+ [
+ # Basic test: append RLS (some_table.id=42) to an existing WHERE clause.
+ (
+ "SELECT * FROM some_table WHERE 1=1",
+ "some_table",
+ "id=42",
+ "SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42",
+ ),
+        # Any existing predicates MUST be wrapped in parentheses because AND has
+        # higher precedence than OR. If the RLS is `1=0` and we didn't add
+        # parentheses, a user could bypass it by crafting a query with
+        # `WHERE TRUE OR FALSE`, since `WHERE TRUE OR FALSE AND 1=0` evaluates to
+        # `WHERE TRUE OR (FALSE AND 1=0)`.
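+        # Illustrative truth table (not part of the test data):
+        #   WHERE TRUE OR FALSE AND 1=0   -> TRUE  (RLS bypassed)
+        #   WHERE (TRUE OR FALSE) AND 1=0 -> FALSE (RLS enforced)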
+ (
+ "SELECT * FROM some_table WHERE TRUE OR FALSE",
+ "some_table",
+ "1=0",
+ "SELECT * FROM some_table WHERE ( TRUE OR FALSE) AND 1=0",
+ ),
+ # Here "table" is a reserved word; since sqlparse is too aggressive when
+ # characterizing reserved words we need to support them even when not quoted.
+ (
+ "SELECT * FROM table WHERE 1=1",
+ "table",
+ "id=42",
+ "SELECT * FROM table WHERE ( 1=1) AND table.id=42",
+ ),
+ # RLS is only applied to queries reading from the associated table.
+ (
+ "SELECT * FROM table WHERE 1=1",
+ "other_table",
+ "id=42",
+ "SELECT * FROM table WHERE 1=1",
+ ),
+ (
+ "SELECT * FROM other_table WHERE 1=1",
+ "table",
+ "id=42",
+ "SELECT * FROM other_table WHERE 1=1",
+ ),
+ # If there's no pre-existing WHERE clause we create one.
+ (
+ "SELECT * FROM table",
+ "table",
+ "id=42",
+ "SELECT * FROM table WHERE table.id=42",
+ ),
+ (
+ "SELECT * FROM some_table",
+ "some_table",
+ "id=42",
+ "SELECT * FROM some_table WHERE some_table.id=42",
+ ),
+ (
+ "SELECT * FROM table ORDER BY id",
+ "table",
+ "id=42",
+ "SELECT * FROM table WHERE table.id=42 ORDER BY id",
+ ),
+ (
+ "SELECT * FROM some_table;",
+ "some_table",
+ "id=42",
+ "SELECT * FROM some_table WHERE some_table.id=42 ;",
+ ),
+ (
+ "SELECT * FROM some_table ;",
+ "some_table",
+ "id=42",
+ "SELECT * FROM some_table WHERE some_table.id=42 ;",
+ ),
+ (
+ "SELECT * FROM some_table ",
+ "some_table",
+ "id=42",
+ "SELECT * FROM some_table WHERE some_table.id=42",
+ ),
+        # We add the RLS even if it's already present, to be conservative. It should
+        # have no impact on the query, and it's easier than testing whether the RLS
+        # is already present (it could be hidden inside an OR clause, for example).
+ (
+ "SELECT * FROM table WHERE 1=1 AND table.id=42",
+ "table",
+ "id=42",
+ "SELECT * FROM table WHERE ( 1=1 AND table.id=42) AND table.id=42",
+ ),
+ (
+ (
+ "SELECT * FROM table JOIN other_table ON "
+ "table.id = other_table.id AND other_table.id=42"
+ ),
+ "other_table",
+ "id=42",
+ (
+ "SELECT * FROM table JOIN other_table ON other_table.id=42 "
+ "AND ( table.id = other_table.id AND other_table.id=42 )"
+ ),
+ ),
+ (
+ "SELECT * FROM table WHERE 1=1 AND id=42",
+ "table",
+ "id=42",
+ "SELECT * FROM table WHERE ( 1=1 AND id=42) AND table.id=42",
+ ),
+ # For joins we apply the RLS to the ON clause, since it's easier and prevents
+ # leaking information about number of rows on OUTER JOINs.
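+        # Roughly: with the RLS in ON, restricted rows in a LEFT OUTER JOIN
+        # surface as NULLs just like non-matching rows, instead of changing the
+        # shape of the result.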
+ (
+ "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
+ "other_table",
+ "id=42",
+ (
+ "SELECT * FROM table JOIN other_table ON other_table.id=42 "
+ "AND ( table.id = other_table.id )"
+ ),
+ ),
+ (
+ (
+ "SELECT * FROM table JOIN other_table ON table.id = other_table.id "
+ "WHERE 1=1"
+ ),
+ "other_table",
+ "id=42",
+ (
+ "SELECT * FROM table JOIN other_table ON other_table.id=42 "
+ "AND ( table.id = other_table.id ) WHERE 1=1"
+ ),
+ ),
+ # Subqueries also work, as expected.
+ (
+ "SELECT * FROM (SELECT * FROM other_table)",
+ "other_table",
+ "id=42",
+ "SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42 )",
+ ),
+ # As well as UNION.
+ (
+ "SELECT * FROM table UNION ALL SELECT * FROM other_table",
+ "table",
+ "id=42",
+ "SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table",
+ ),
+ (
+ "SELECT * FROM table UNION ALL SELECT * FROM other_table",
+ "other_table",
+ "id=42",
+ (
+ "SELECT * FROM table UNION ALL "
+ "SELECT * FROM other_table WHERE other_table.id=42"
+ ),
+ ),
+        # When comparing fully qualified table names (e.g., schema.table) to simple
+        # names (e.g., table) we are also conservative, assuming the schema is the
+        # same, since we don't have information on the default schema.
+ (
+ "SELECT * FROM schema.table_name",
+ "table_name",
+ "id=42",
+ "SELECT * FROM schema.table_name WHERE table_name.id=42",
+ ),
+ (
+ "SELECT * FROM schema.table_name",
+ "schema.table_name",
+ "id=42",
+ "SELECT * FROM schema.table_name WHERE schema.table_name.id=42",
+ ),
+ (
+ "SELECT * FROM table_name",
+ "schema.table_name",
+ "id=42",
+ "SELECT * FROM table_name WHERE schema.table_name.id=42",
+ ),
+ ],
+)
+def test_insert_rls(
+ mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
+) -> None:
+ """
+ Insert into a statement a given RLS condition associated with a table.
+ """
+ condition = sqlparse.parse(rls)[0]
+ add_table_name(condition, table)
+
+ # pylint: disable=unused-argument
+ def get_rls_for_table(
+ candidate: Token,
+ database_id: int,
+ default_schema: str,
+ ) -> Optional[TokenList]:
+ """
+ Return the RLS ``condition`` if ``candidate`` matches ``table``.
+ """
+ # compare ignoring schema
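+        # (zip() stops at the shorter reversed part list, so "some_table"
+        # matches "my_schema.some_table")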
+ for left, right in zip(str(candidate).split(".")[::-1], table.split(".")[::-1]):
+ if left != right:
+ return None
+ return condition
+
+ mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table)
+
+ statement = sqlparse.parse(sql)[0]
+ assert (
+ str(
+ insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
+ ).strip()
+ == expected.strip()
+ )
+
+
+@pytest.mark.parametrize(
+ "rls,table,expected",
+ [
+ ("id=42", "users", "users.id=42"),
+ ("users.id=42", "users", "users.id=42"),
+ ("schema.users.id=42", "users", "schema.users.id=42"),
+ ("false", "users", "false"),
+ ],
+)
+def test_add_table_name(rls: str, table: str, expected: str) -> None:
+ condition = sqlparse.parse(rls)[0]
+ add_table_name(condition, table)
+ assert str(condition) == expected
+
+
+def test_get_rls_for_table(mocker: MockerFixture) -> None:
+ """
+ Tests for ``get_rls_for_table``.
+ """
+ candidate = Identifier([Token(Name, "some_table")])
+ db = mocker.patch("superset.db")
+ dataset = db.session.query().filter().one_or_none()
+ dataset.__str__.return_value = "some_table"
+
+ dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
+ assert (
+ str(get_rls_for_table(candidate, 1, "public"))
+ == "some_table.organization_id = 1"
+ )
+
+ dataset.get_sqla_row_level_filters.return_value = [
+ text("organization_id = 1"),
+ text("foo = 'bar'"),
+ ]
+ assert (
+ str(get_rls_for_table(candidate, 1, "public"))
+ == "some_table.organization_id = 1 AND some_table.foo = 'bar'"
+ )
+
+ dataset.get_sqla_row_level_filters.return_value = []
+ assert get_rls_for_table(candidate, 1, "public") is None
+
+
+def test_extract_table_references(mocker: MockerFixture) -> None:
+ """
+ Test the ``extract_table_references`` helper function.
+ """
+ assert extract_table_references("SELECT 1", "trino") == set()
+ assert extract_table_references("SELECT 1 FROM some_table", "trino") == {
+ Table(table="some_table", schema=None, catalog=None)
+ }
+ assert extract_table_references("SELECT {{ jinja }} FROM some_table", "trino") == {
+ Table(table="some_table", schema=None, catalog=None)
+ }
+ assert extract_table_references(
+ "SELECT 1 FROM some_catalog.some_schema.some_table", "trino"
+ ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
+
+ # with identifier quotes
+ assert extract_table_references(
+ "SELECT 1 FROM `some_catalog`.`some_schema`.`some_table`", "mysql"
+ ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
+ assert extract_table_references(
+ 'SELECT 1 FROM "some_catalog".some_schema."some_table"', "trino"
+ ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
+
+ assert extract_table_references(
+ "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
+ "trino",
+ ) == {
+ Table(table="some_table", schema=None, catalog=None),
+ Table(table="other_table", schema=None, catalog=None),
+ }
+
+ # test falling back to sqlparse
+ logger = mocker.patch("superset.sql_parse.logger")
+ sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
+ assert extract_table_references(
+ sql,
+ "trino",
+ ) == {Table(table="other_table", schema=None, catalog=None)}
+ logger.warning.assert_called_once()
+
+ logger = mocker.patch("superset.migrations.shared.utils.logger")
+ sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
+ assert extract_table_references(sql, "trino", show_warning=False) == {
+ Table(table="other_table", schema=None, catalog=None)
+ }
+ logger.warning.assert_not_called()
diff --git a/tests/unit_tests/tables/__init__.py b/tests/unit_tests/tables/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/tables/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/tables/test_models.py b/tests/unit_tests/tables/test_models.py
new file mode 100644
index 0000000000000..7705dba6aa09d
--- /dev/null
+++ b/tests/unit_tests/tables/test_models.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=import-outside-toplevel, unused-argument
+
+from sqlalchemy.orm.session import Session
+
+
+def test_table_model(session: Session) -> None:
+ """
+ Test basic attributes of a ``Table``.
+ """
+ from superset.columns.models import Column
+ from superset.models.core import Database
+ from superset.tables.models import Table
+
+ engine = session.get_bind()
+ Table.metadata.create_all(engine) # pylint: disable=no-member
+
+ table = Table(
+ name="my_table",
+ schema="my_schema",
+ catalog="my_catalog",
+ database=Database(database_name="my_database", sqlalchemy_uri="test://"),
+ columns=[
+ Column(
+ name="ds",
+ type="TIMESTAMP",
+ expression="ds",
+ )
+ ],
+ )
+ session.add(table)
+ session.flush()
+
+ assert table.id == 1
+ assert table.uuid is not None
+ assert table.database_id == 1
+ assert table.catalog == "my_catalog"
+ assert table.schema == "my_schema"
+ assert table.name == "my_table"
+ assert [column.name for column in table.columns] == ["ds"]
diff --git a/tests/unit_tests/tasks/__init__.py b/tests/unit_tests/tasks/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/tasks/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/tasks/test_cron_util.py b/tests/unit_tests/tasks/test_cron_util.py
new file mode 100644
index 0000000000000..d0f9ae21705e2
--- /dev/null
+++ b/tests/unit_tests/tasks/test_cron_util.py
@@ -0,0 +1,212 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import List
+
+import pytest
+from freezegun import freeze_time
+from freezegun.api import FakeDatetime # type: ignore
+
+from superset.tasks.cron_util import cron_schedule_window
+
+
+@pytest.mark.parametrize(
+ "current_dttm, cron, expected",
+ [
+ ("2020-01-01T08:59:01Z", "0 1 * * *", []),
+ (
+ "2020-01-01T08:59:02Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 9, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-01-01T08:59:59Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 9, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-01-01T09:00:00Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 9, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ ("2020-01-01T09:00:01Z", "0 1 * * *", []),
+ ],
+)
+def test_cron_schedule_window_los_angeles(
+ current_dttm: str, cron: str, expected: List[FakeDatetime]
+) -> None:
+ """
+ Reports scheduler: Test cron schedule window for "America/Los_Angeles"
+ """
+
+ with freeze_time(current_dttm):
+ datetimes = cron_schedule_window(cron, "America/Los_Angeles")
+ assert (
+            [dttm.strftime("%A, %d %B %Y, %H:%M:%S") for dttm in datetimes]
+ == expected
+ )
+
+
+@pytest.mark.parametrize(
+ "current_dttm, cron, expected",
+ [
+ ("2020-01-01T00:59:01Z", "0 1 * * *", []),
+ (
+ "2020-01-01T00:59:02Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 1, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-01-01T00:59:59Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 1, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-01-01T01:00:00Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 1, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ ("2020-01-01T01:00:01Z", "0 1 * * *", []),
+ ],
+)
+def test_cron_schedule_window_invalid_timezone(
+ current_dttm: str, cron: str, expected: List[FakeDatetime]
+) -> None:
+ """
+ Reports scheduler: Test cron schedule window for "invalid timezone"
+ """
+
+ with freeze_time(current_dttm):
+ datetimes = cron_schedule_window(cron, "invalid timezone")
+ # it should default to UTC
+ assert (
+            [dttm.strftime("%A, %d %B %Y, %H:%M:%S") for dttm in datetimes]
+ == expected
+ )
+
+
+@pytest.mark.parametrize(
+ "current_dttm, cron, expected",
+ [
+ ("2020-01-01T05:59:01Z", "0 1 * * *", []),
+ (
+ "2020-01-01T05:59:02Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 6, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-01-01T5:59:59Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 6, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-01-01T6:00:00",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 6, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ ("2020-01-01T6:00:01Z", "0 1 * * *", []),
+ ],
+)
+def test_cron_schedule_window_new_york(
+ current_dttm: str, cron: str, expected: List[FakeDatetime]
+) -> None:
+ """
+ Reports scheduler: Test cron schedule window for "America/New_York"
+ """
+
+ with freeze_time(current_dttm, tz_offset=0):
+ datetimes = cron_schedule_window(cron, "America/New_York")
+ assert (
+            [dttm.strftime("%A, %d %B %Y, %H:%M:%S") for dttm in datetimes]
+ == expected
+ )
+
+
+@pytest.mark.parametrize(
+ "current_dttm, cron, expected",
+ [
+ ("2020-01-01T06:59:01Z", "0 1 * * *", []),
+ (
+ "2020-01-01T06:59:02Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 7, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-01-01T06:59:59Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 7, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-01-01T07:00:00",
+ "0 1 * * *",
+ [FakeDatetime(2020, 1, 1, 7, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ ("2020-01-01T07:00:01Z", "0 1 * * *", []),
+ ],
+)
+def test_cron_schedule_window_chicago(
+ current_dttm: str, cron: str, expected: List[FakeDatetime]
+) -> None:
+ """
+ Reports scheduler: Test cron schedule window for "America/Chicago"
+ """
+
+ with freeze_time(current_dttm, tz_offset=0):
+ datetimes = cron_schedule_window(cron, "America/Chicago")
+ assert (
+            [dttm.strftime("%A, %d %B %Y, %H:%M:%S") for dttm in datetimes]
+ == expected
+ )
+
+
+@pytest.mark.parametrize(
+ "current_dttm, cron, expected",
+ [
+ ("2020-07-01T05:59:01Z", "0 1 * * *", []),
+ (
+ "2020-07-01T05:59:02Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 7, 1, 6, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-07-01T05:59:59Z",
+ "0 1 * * *",
+ [FakeDatetime(2020, 7, 1, 6, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ (
+ "2020-07-01T06:00:00",
+ "0 1 * * *",
+ [FakeDatetime(2020, 7, 1, 6, 0).strftime("%A, %d %B %Y, %H:%M:%S")],
+ ),
+ ("2020-07-01T06:00:01Z", "0 1 * * *", []),
+ ],
+)
+def test_cron_schedule_window_chicago_daylight(
+ current_dttm: str, cron: str, expected: List[FakeDatetime]
+) -> None:
+ """
+ Reports scheduler: Test cron schedule window for "America/Chicago"
+ """
+
+ with freeze_time(current_dttm, tz_offset=0):
+ datetimes = cron_schedule_window(cron, "America/Chicago")
+ assert (
+            [dttm.strftime("%A, %d %B %Y, %H:%M:%S") for dttm in datetimes]
+ == expected
+ )
diff --git a/tests/unit_tests/tasks/test_utils.py b/tests/unit_tests/tasks/test_utils.py
new file mode 100644
index 0000000000000..7854717201229
--- /dev/null
+++ b/tests/unit_tests/tasks/test_utils.py
@@ -0,0 +1,323 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from contextlib import nullcontext
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+import pytest
+from flask_appbuilder.security.sqla.models import User
+
+from superset.tasks.exceptions import ExecutorNotFoundError
+from superset.tasks.types import ExecutorType
+
+SELENIUM_USER_ID = 1234
+SELENIUM_USERNAME = "admin"
+
+
+def _get_users(
+ params: Optional[Union[int, List[int]]]
+) -> Optional[Union[User, List[User]]]:
+ if params is None:
+ return None
+ if isinstance(params, int):
+ return User(id=params, username=str(params))
+ return [User(id=user, username=str(user)) for user in params]
+
+
+@dataclass
+class ModelConfig:
+ owners: List[int]
+ creator: Optional[int] = None
+ modifier: Optional[int] = None
+
+
+class ModelType(int, Enum):
+ DASHBOARD = 1
+ CHART = 2
+ REPORT_SCHEDULE = 3
+
+
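+# The cases below suggest executor_types are tried in order; for OWNER the
+# lookup appears to prefer the modifier if they are also an owner, then the
+# creator, while SELENIUM always resolves to the fixed selenium user.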
+@pytest.mark.parametrize(
+ "model_type,executor_types,model_config,current_user,expected_result",
+ [
+ (
+ ModelType.REPORT_SCHEDULE,
+ [ExecutorType.SELENIUM],
+ ModelConfig(
+ owners=[1, 2],
+ creator=3,
+ modifier=4,
+ ),
+ None,
+ (ExecutorType.SELENIUM, SELENIUM_USER_ID),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.CREATOR,
+ ExecutorType.CREATOR_OWNER,
+ ExecutorType.OWNER,
+ ExecutorType.MODIFIER,
+ ExecutorType.MODIFIER_OWNER,
+ ExecutorType.SELENIUM,
+ ],
+ ModelConfig(owners=[]),
+ None,
+ (ExecutorType.SELENIUM, SELENIUM_USER_ID),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.CREATOR,
+ ExecutorType.CREATOR_OWNER,
+ ExecutorType.OWNER,
+ ExecutorType.MODIFIER,
+ ExecutorType.MODIFIER_OWNER,
+ ExecutorType.SELENIUM,
+ ],
+ ModelConfig(owners=[], modifier=1),
+ None,
+ (ExecutorType.MODIFIER, 1),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.CREATOR,
+ ExecutorType.CREATOR_OWNER,
+ ExecutorType.OWNER,
+ ExecutorType.MODIFIER,
+ ExecutorType.MODIFIER_OWNER,
+ ExecutorType.SELENIUM,
+ ],
+ ModelConfig(owners=[2], modifier=1),
+ None,
+ (ExecutorType.OWNER, 2),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.CREATOR,
+ ExecutorType.CREATOR_OWNER,
+ ExecutorType.OWNER,
+ ExecutorType.MODIFIER,
+ ExecutorType.MODIFIER_OWNER,
+ ExecutorType.SELENIUM,
+ ],
+ ModelConfig(owners=[2], creator=3, modifier=1),
+ None,
+ (ExecutorType.CREATOR, 3),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.OWNER,
+ ],
+ ModelConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=3, modifier=4),
+ None,
+ (ExecutorType.OWNER, 4),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.OWNER,
+ ],
+ ModelConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=3, modifier=8),
+ None,
+ (ExecutorType.OWNER, 3),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.MODIFIER_OWNER,
+ ],
+ ModelConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=8, modifier=9),
+ None,
+ ExecutorNotFoundError(),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.MODIFIER_OWNER,
+ ],
+ ModelConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=8, modifier=4),
+ None,
+ (ExecutorType.MODIFIER_OWNER, 4),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.CREATOR_OWNER,
+ ],
+ ModelConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=8, modifier=9),
+ None,
+ ExecutorNotFoundError(),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.CREATOR_OWNER,
+ ],
+ ModelConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=4, modifier=8),
+ None,
+ (ExecutorType.CREATOR_OWNER, 4),
+ ),
+ (
+ ModelType.REPORT_SCHEDULE,
+ [
+ ExecutorType.CURRENT_USER,
+ ],
+ ModelConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=4, modifier=8),
+ None,
+ ExecutorNotFoundError(),
+ ),
+ (
+ ModelType.DASHBOARD,
+ [
+ ExecutorType.CURRENT_USER,
+ ],
+ ModelConfig(owners=[1], creator=2, modifier=3),
+ 4,
+ (ExecutorType.CURRENT_USER, 4),
+ ),
+ (
+ ModelType.DASHBOARD,
+ [
+ ExecutorType.SELENIUM,
+ ],
+ ModelConfig(owners=[1], creator=2, modifier=3),
+ 4,
+ (ExecutorType.SELENIUM, SELENIUM_USER_ID),
+ ),
+ (
+ ModelType.DASHBOARD,
+ [
+ ExecutorType.CURRENT_USER,
+ ],
+ ModelConfig(owners=[1], creator=2, modifier=3),
+ None,
+ ExecutorNotFoundError(),
+ ),
+ (
+ ModelType.DASHBOARD,
+ [
+ ExecutorType.CREATOR_OWNER,
+ ExecutorType.MODIFIER_OWNER,
+ ExecutorType.CURRENT_USER,
+ ExecutorType.SELENIUM,
+ ],
+ ModelConfig(owners=[1], creator=2, modifier=3),
+ None,
+ (ExecutorType.SELENIUM, SELENIUM_USER_ID),
+ ),
+ (
+ ModelType.CHART,
+ [
+ ExecutorType.CURRENT_USER,
+ ],
+ ModelConfig(owners=[1], creator=2, modifier=3),
+ 4,
+ (ExecutorType.CURRENT_USER, 4),
+ ),
+ (
+ ModelType.CHART,
+ [
+ ExecutorType.SELENIUM,
+ ],
+ ModelConfig(owners=[1], creator=2, modifier=3),
+ 4,
+ (ExecutorType.SELENIUM, SELENIUM_USER_ID),
+ ),
+ (
+ ModelType.CHART,
+ [
+ ExecutorType.CURRENT_USER,
+ ],
+ ModelConfig(owners=[1], creator=2, modifier=3),
+ None,
+ ExecutorNotFoundError(),
+ ),
+ (
+ ModelType.CHART,
+ [
+ ExecutorType.CREATOR_OWNER,
+ ExecutorType.MODIFIER_OWNER,
+ ExecutorType.CURRENT_USER,
+ ExecutorType.SELENIUM,
+ ],
+ ModelConfig(owners=[1], creator=2, modifier=3),
+ None,
+ (ExecutorType.SELENIUM, SELENIUM_USER_ID),
+ ),
+ ],
+)
+def test_get_executor(
+ model_type: ModelType,
+ executor_types: List[ExecutorType],
+ model_config: ModelConfig,
+ current_user: Optional[int],
+    expected_result: Union[Tuple[ExecutorType, int], ExecutorNotFoundError],
+) -> None:
+ from superset.models.dashboard import Dashboard
+ from superset.models.slice import Slice
+ from superset.reports.models import ReportSchedule
+ from superset.tasks.utils import get_executor
+
+ model: Type[Union[Dashboard, ReportSchedule, Slice]]
+ model_kwargs: Dict[str, Any] = {}
+ if model_type == ModelType.REPORT_SCHEDULE:
+ model = ReportSchedule
+ model_kwargs = {
+ "type": "report",
+ "name": "test_report",
+ }
+ elif model_type == ModelType.DASHBOARD:
+ model = Dashboard
+ elif model_type == ModelType.CHART:
+ model = Slice
+ else:
+        raise ValueError(f"Unsupported model type: {model_type}")
+
+ obj = model(
+ id=1,
+ owners=_get_users(model_config.owners),
+ created_by=_get_users(model_config.creator),
+ changed_by=_get_users(model_config.modifier),
+ **model_kwargs,
+ )
+ if isinstance(expected_result, Exception):
+ cm = pytest.raises(type(expected_result))
+ expected_executor_type = None
+ expected_executor = None
+ else:
+ cm = nullcontext()
+ expected_executor_type = expected_result[0]
+ expected_executor = (
+ SELENIUM_USERNAME
+ if expected_executor_type == ExecutorType.SELENIUM
+ else str(expected_result[1])
+ )
+
+ with cm:
+ executor_type, executor = get_executor(
+ executor_types=executor_types,
+ model=obj,
+ current_user=str(current_user) if current_user else None,
+ )
+ assert executor_type == expected_executor_type
+ assert executor == expected_executor
diff --git a/tests/unit_tests/test_jinja_context.py b/tests/unit_tests/test_jinja_context.py
new file mode 100644
index 0000000000000..8704b1d65c211
--- /dev/null
+++ b/tests/unit_tests/test_jinja_context.py
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+from typing import Any
+
+import pytest
+from sqlalchemy.dialects.postgresql import dialect
+
+from superset import app
+from superset.exceptions import SupersetTemplateException
+from superset.jinja_context import ExtraCache, safe_proxy
+
+
+def test_filter_values_default() -> None:
+ cache = ExtraCache()
+ assert cache.filter_values("name", "foo") == ["foo"]
+ assert cache.removed_filters == []
+
+
+def test_filter_values_remove_not_present() -> None:
+ cache = ExtraCache()
+ assert cache.filter_values("name", remove_filter=True) == []
+ assert cache.removed_filters == []
+
+
+def test_get_filters_remove_not_present() -> None:
+ cache = ExtraCache()
+ assert cache.get_filters("name", remove_filter=True) == []
+ assert cache.removed_filters == []
+
+
+def test_filter_values_no_default() -> None:
+ cache = ExtraCache()
+ assert cache.filter_values("name") == []
+
+
+def test_filter_values_adhoc_filters() -> None:
+ with app.test_request_context(
+ data={
+ "form_data": json.dumps(
+ {
+ "adhoc_filters": [
+ {
+ "clause": "WHERE",
+ "comparator": "foo",
+ "expressionType": "SIMPLE",
+ "operator": "in",
+ "subject": "name",
+ }
+ ],
+ }
+ )
+ }
+ ):
+ cache = ExtraCache()
+ assert cache.filter_values("name") == ["foo"]
+ assert cache.applied_filters == ["name"]
+
+ with app.test_request_context(
+ data={
+ "form_data": json.dumps(
+ {
+ "adhoc_filters": [
+ {
+ "clause": "WHERE",
+ "comparator": ["foo", "bar"],
+ "expressionType": "SIMPLE",
+ "operator": "in",
+ "subject": "name",
+ }
+ ],
+ }
+ )
+ }
+ ):
+ cache = ExtraCache()
+ assert cache.filter_values("name") == ["foo", "bar"]
+ assert cache.applied_filters == ["name"]
+
+
+def test_get_filters_adhoc_filters() -> None:
+ with app.test_request_context(
+ data={
+ "form_data": json.dumps(
+ {
+ "adhoc_filters": [
+ {
+ "clause": "WHERE",
+ "comparator": "foo",
+ "expressionType": "SIMPLE",
+ "operator": "in",
+ "subject": "name",
+ }
+ ],
+ }
+ )
+ }
+ ):
+ cache = ExtraCache()
+ assert cache.get_filters("name") == [
+ {"op": "IN", "col": "name", "val": ["foo"]}
+ ]
+
+ assert cache.removed_filters == []
+ assert cache.applied_filters == ["name"]
+
+ with app.test_request_context(
+ data={
+ "form_data": json.dumps(
+ {
+ "adhoc_filters": [
+ {
+ "clause": "WHERE",
+ "comparator": ["foo", "bar"],
+ "expressionType": "SIMPLE",
+ "operator": "in",
+ "subject": "name",
+ }
+ ],
+ }
+ )
+ }
+ ):
+ cache = ExtraCache()
+ assert cache.get_filters("name") == [
+ {"op": "IN", "col": "name", "val": ["foo", "bar"]}
+ ]
+ assert cache.removed_filters == []
+
+ with app.test_request_context(
+ data={
+ "form_data": json.dumps(
+ {
+ "adhoc_filters": [
+ {
+ "clause": "WHERE",
+ "comparator": ["foo", "bar"],
+ "expressionType": "SIMPLE",
+ "operator": "in",
+ "subject": "name",
+ }
+ ],
+ }
+ )
+ }
+ ):
+ cache = ExtraCache()
+ assert cache.get_filters("name", remove_filter=True) == [
+ {"op": "IN", "col": "name", "val": ["foo", "bar"]}
+ ]
+ assert cache.removed_filters == ["name"]
+ assert cache.applied_filters == ["name"]
+
+
+def test_filter_values_extra_filters() -> None:
+ with app.test_request_context(
+ data={
+ "form_data": json.dumps(
+ {"extra_filters": [{"col": "name", "op": "in", "val": "foo"}]}
+ )
+ }
+ ):
+ cache = ExtraCache()
+ assert cache.filter_values("name") == ["foo"]
+ assert cache.applied_filters == ["name"]
+
+
+def test_url_param_default() -> None:
+ with app.test_request_context():
+ cache = ExtraCache()
+ assert cache.url_param("foo", "bar") == "bar"
+
+
+def test_url_param_no_default() -> None:
+ with app.test_request_context():
+ cache = ExtraCache()
+ assert cache.url_param("foo") is None
+
+
+def test_url_param_query() -> None:
+ with app.test_request_context(query_string={"foo": "bar"}):
+ cache = ExtraCache()
+ assert cache.url_param("foo") == "bar"
+
+
+def test_url_param_form_data() -> None:
+ with app.test_request_context(
+ query_string={"form_data": json.dumps({"url_params": {"foo": "bar"}})}
+ ):
+ cache = ExtraCache()
+ assert cache.url_param("foo") == "bar"
+
+
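+# Note: when ExtraCache is given a SQL dialect, url_param escapes values for
+# use in string literals; for PostgreSQL this doubles single quotes
+# (O'Brien -> O''Brien), as the next four tests demonstrate.
+
+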
+def test_url_param_escaped_form_data() -> None:
+ with app.test_request_context(
+ query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+ ):
+ cache = ExtraCache(dialect=dialect())
+ assert cache.url_param("foo") == "O''Brien"
+
+
+def test_url_param_escaped_default_form_data() -> None:
+ with app.test_request_context(
+ query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+ ):
+ cache = ExtraCache(dialect=dialect())
+ assert cache.url_param("bar", "O'Malley") == "O''Malley"
+
+
+def test_url_param_unescaped_form_data() -> None:
+ with app.test_request_context(
+ query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+ ):
+ cache = ExtraCache(dialect=dialect())
+ assert cache.url_param("foo", escape_result=False) == "O'Brien"
+
+
+def test_url_param_unescaped_default_form_data() -> None:
+ with app.test_request_context(
+ query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+ ):
+ cache = ExtraCache(dialect=dialect())
+ assert cache.url_param("bar", "O'Malley", escape_result=False) == "O'Malley"
+
+
+def test_safe_proxy_primitive() -> None:
+ def func(input_: Any) -> Any:
+ return input_
+
+ assert safe_proxy(func, "foo") == "foo"
+
+
+def test_safe_proxy_dict() -> None:
+ def func(input_: Any) -> Any:
+ return input_
+
+ assert safe_proxy(func, {"foo": "bar"}) == {"foo": "bar"}
+
+
+def test_safe_proxy_lambda() -> None:
+ def func(input_: Any) -> Any:
+ return input_
+
+ with pytest.raises(SupersetTemplateException):
+ safe_proxy(func, lambda: "bar")
+
+
+def test_safe_proxy_nested_lambda() -> None:
+ def func(input_: Any) -> Any:
+ return input_
+
+ with pytest.raises(SupersetTemplateException):
+ safe_proxy(func, {"foo": lambda: "bar"})
diff --git a/tests/unit_tests/thumbnails/__init__.py b/tests/unit_tests/thumbnails/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/thumbnails/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/thumbnails/test_digest.py b/tests/unit_tests/thumbnails/test_digest.py
new file mode 100644
index 0000000000000..04f244e629b59
--- /dev/null
+++ b/tests/unit_tests/thumbnails/test_digest.py
@@ -0,0 +1,258 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from contextlib import nullcontext
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
+from unittest.mock import patch
+
+import pytest
+from flask_appbuilder.security.sqla.models import User
+
+from superset.tasks.exceptions import ExecutorNotFoundError
+from superset.tasks.types import ExecutorType
+from superset.utils.core import override_user
+
+if TYPE_CHECKING:
+ from superset.models.dashboard import Dashboard
+ from superset.models.slice import Slice
+
+_DEFAULT_DASHBOARD_KWARGS: Dict[str, Any] = {
+ "id": 1,
+ "dashboard_title": "My Title",
+ "slices": [{"id": 1, "slice_name": "My Chart"}],
+ "position_json": '{"a": "b"}',
+ "css": "background-color: lightblue;",
+ "json_metadata": '{"c": "d"}',
+}
+
+_DEFAULT_CHART_KWARGS = {
+ "id": 2,
+ "params": {"a": "b"},
+}
+
+
+def CUSTOM_DASHBOARD_FUNC(
+ dashboard: Dashboard,
+ executor_type: ExecutorType,
+ executor: str,
+) -> str:
+ return f"{dashboard.id}.{executor_type.value}.{executor}"
+
+
+def CUSTOM_CHART_FUNC(
+ chart: Slice,
+ executor_type: ExecutorType,
+ executor: str,
+) -> str:
+ return f"{chart.id}.{executor_type.value}.{executor}"
+
+
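+# Judging by the cases below, the default dashboard digest covers the id,
+# slices, position_json, css and json_metadata (but not the title), plus the
+# resolved executor.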
+@pytest.mark.parametrize(
+ "dashboard_overrides,execute_as,has_current_user,use_custom_digest,expected_result",
+ [
+ (
+ None,
+ [ExecutorType.SELENIUM],
+ False,
+ False,
+ "71452fee8ffbd8d340193d611bcd4559",
+ ),
+ (
+ None,
+ [ExecutorType.CURRENT_USER],
+ True,
+ False,
+ "209dc060ac19271b8708731e3b8280f5",
+ ),
+ (
+ {
+ "dashboard_title": "My Other Title",
+ },
+ [ExecutorType.CURRENT_USER],
+ True,
+ False,
+ "209dc060ac19271b8708731e3b8280f5",
+ ),
+ (
+ {
+ "id": 2,
+ },
+ [ExecutorType.CURRENT_USER],
+ True,
+ False,
+ "06a4144466dbd5ffad0c3c2225e96296",
+ ),
+ (
+ {
+ "slices": [{"id": 2, "slice_name": "My Other Chart"}],
+ },
+ [ExecutorType.CURRENT_USER],
+ True,
+ False,
+ "a823ece9563895ccb14f3d9095e84f7a",
+ ),
+ (
+ {
+ "position_json": {"b": "c"},
+ },
+ [ExecutorType.CURRENT_USER],
+ True,
+ False,
+ "33c5475f92a904925ab3ef493526e5b5",
+ ),
+ (
+ {
+ "css": "background-color: darkblue;",
+ },
+ [ExecutorType.CURRENT_USER],
+ True,
+ False,
+ "cec57345e6402c0d4b3caee5cfaa0a03",
+ ),
+ (
+ {
+ "json_metadata": {"d": "e"},
+ },
+ [ExecutorType.CURRENT_USER],
+ True,
+ False,
+ "5380dcbe94621a0759b09554404f3d02",
+ ),
+ (
+ None,
+ [ExecutorType.CURRENT_USER],
+ True,
+ True,
+ "1.current_user.1",
+ ),
+ (
+ None,
+ [ExecutorType.CURRENT_USER],
+ False,
+ False,
+ ExecutorNotFoundError(),
+ ),
+ ],
+)
+def test_dashboard_digest(
+ dashboard_overrides: Optional[Dict[str, Any]],
+ execute_as: List[ExecutorType],
+ has_current_user: bool,
+ use_custom_digest: bool,
+ expected_result: Union[str, Exception],
+) -> None:
+ from superset import app
+ from superset.models.dashboard import Dashboard
+ from superset.models.slice import Slice
+ from superset.thumbnails.digest import get_dashboard_digest
+
+ kwargs = {
+ **_DEFAULT_DASHBOARD_KWARGS,
+ **(dashboard_overrides or {}),
+ }
+ slices = [Slice(**slice_kwargs) for slice_kwargs in kwargs.pop("slices")]
+ dashboard = Dashboard(**kwargs, slices=slices)
+ user: Optional[User] = None
+ if has_current_user:
+ user = User(id=1, username="1")
+ func = CUSTOM_DASHBOARD_FUNC if use_custom_digest else None
+
+ with patch.dict(
+ app.config,
+ {
+ "THUMBNAIL_EXECUTE_AS": execute_as,
+ "THUMBNAIL_DASHBOARD_DIGEST_FUNC": func,
+ },
+ ), override_user(user):
+ cm = (
+ pytest.raises(type(expected_result))
+ if isinstance(expected_result, Exception)
+ else nullcontext()
+ )
+ with cm:
+ assert get_dashboard_digest(dashboard=dashboard) == expected_result
+
+
+@pytest.mark.parametrize(
+ "chart_overrides,execute_as,has_current_user,use_custom_digest,expected_result",
+ [
+ (
+ None,
+ [ExecutorType.SELENIUM],
+ False,
+ False,
+ "47d852b5c4df211c115905617bb722c1",
+ ),
+ (
+ None,
+ [ExecutorType.CURRENT_USER],
+ True,
+ False,
+ "4f8109d3761e766e650af514bb358f10",
+ ),
+ (
+ None,
+ [ExecutorType.CURRENT_USER],
+ True,
+ True,
+ "2.current_user.1",
+ ),
+ (
+ None,
+ [ExecutorType.CURRENT_USER],
+ False,
+ False,
+ ExecutorNotFoundError(),
+ ),
+ ],
+)
+def test_chart_digest(
+ chart_overrides: Optional[Dict[str, Any]],
+ execute_as: List[ExecutorType],
+ has_current_user: bool,
+ use_custom_digest: bool,
+ expected_result: Union[str, Exception],
+) -> None:
+ from superset import app
+ from superset.models.slice import Slice
+ from superset.thumbnails.digest import get_chart_digest
+
+ kwargs = {
+ **_DEFAULT_CHART_KWARGS,
+ **(chart_overrides or {}),
+ }
+ chart = Slice(**kwargs)
+ user: Optional[User] = None
+ if has_current_user:
+ user = User(id=1, username="1")
+ func = CUSTOM_CHART_FUNC if use_custom_digest else None
+
+ with patch.dict(
+ app.config,
+ {
+ "THUMBNAIL_EXECUTE_AS": execute_as,
+ "THUMBNAIL_CHART_DIGEST_FUNC": func,
+ },
+ ), override_user(user):
+ cm = (
+ pytest.raises(type(expected_result))
+ if isinstance(expected_result, Exception)
+ else nullcontext()
+ )
+ with cm:
+ assert get_chart_digest(chart=chart) == expected_result
diff --git a/tests/unit_tests/utils/cache_test.py b/tests/unit_tests/utils/cache_test.py
new file mode 100644
index 0000000000000..53650e1d20324
--- /dev/null
+++ b/tests/unit_tests/utils/cache_test.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=import-outside-toplevel, unused-argument
+
+from pytest_mock import MockerFixture
+
+
+def test_memoized_func(mocker: MockerFixture) -> None:
+ """
+ Test the ``memoized_func`` decorator.
+ """
+ from superset.utils.cache import memoized_func
+
+ cache = mocker.MagicMock()
+
+ decorator = memoized_func("db:{self.id}:schema:{schema}:view_list", cache)
+ decorated = decorator(lambda self, schema, cache=False: 42)
+
+ self = mocker.MagicMock()
+ self.id = 1
+
+ # skip cache
+ result = decorated(self, "public", cache=False)
+ assert result == 42
+ cache.get.assert_not_called()
+
+ # check cache, no cached value
+ cache.get.return_value = None
+ result = decorated(self, "public", cache=True)
+ assert result == 42
+ cache.get.assert_called_with("db:1:schema:public:view_list")
+
+ # check cache, cached value
+ cache.get.return_value = 43
+ result = decorated(self, "public", cache=True)
+ assert result == 43
diff --git a/tests/unit_tests/utils/date_parser_tests.py b/tests/unit_tests/utils/date_parser_tests.py
new file mode 100644
index 0000000000000..f3c8b6968077b
--- /dev/null
+++ b/tests/unit_tests/utils/date_parser_tests.py
@@ -0,0 +1,358 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import re
+from datetime import date, datetime, timedelta
+from typing import Optional, Tuple
+from unittest.mock import Mock, patch
+
+import pytest
+from dateutil.relativedelta import relativedelta
+
+from superset.charts.commands.exceptions import (
+ TimeRangeAmbiguousError,
+ TimeRangeParseFailError,
+)
+from superset.utils.date_parser import (
+ DateRangeMigration,
+ datetime_eval,
+ get_past_or_future,
+ get_since_until,
+ parse_human_datetime,
+ parse_human_timedelta,
+ parse_past_timedelta,
+)
+
+
+def mock_parse_human_datetime(s: str) -> Optional[datetime]:
+ if s == "now":
+ return datetime(2016, 11, 7, 9, 30, 10)
+ elif s == "2018":
+ return datetime(2018, 1, 1)
+ elif s == "2018-9":
+ return datetime(2018, 9, 1)
+ elif s == "today":
+ return datetime(2016, 11, 7)
+ elif s == "yesterday":
+ return datetime(2016, 11, 6)
+ elif s == "tomorrow":
+ return datetime(2016, 11, 8)
+ elif s == "Last year":
+ return datetime(2015, 11, 7)
+ elif s == "Last week":
+ return datetime(2015, 10, 31)
+ elif s == "Last 5 months":
+ return datetime(2016, 6, 7)
+ elif s == "Next 5 months":
+ return datetime(2017, 4, 7)
+ elif s in ["5 days", "5 days ago"]:
+ return datetime(2016, 11, 2)
+ elif s == "2018-01-01T00:00:00":
+ return datetime(2018, 1, 1)
+ elif s == "2018-12-31T23:59:59":
+ return datetime(2018, 12, 31, 23, 59, 59)
+ else:
+ return None
+
+
+@patch("superset.utils.date_parser.parse_human_datetime", mock_parse_human_datetime)
+def test_get_since_until() -> None:
+ result: Tuple[Optional[datetime], Optional[datetime]]
+ expected: Tuple[Optional[datetime], Optional[datetime]]
+
+ result = get_since_until()
+ expected = None, datetime(2016, 11, 7)
+ assert result == expected
+
+ result = get_since_until(" : now")
+ expected = None, datetime(2016, 11, 7, 9, 30, 10)
+ assert result == expected
+
+ result = get_since_until("yesterday : tomorrow")
+ expected = datetime(2016, 11, 6), datetime(2016, 11, 8)
+ assert result == expected
+
+ result = get_since_until("2018-01-01T00:00:00 : 2018-12-31T23:59:59")
+ expected = datetime(2018, 1, 1), datetime(2018, 12, 31, 23, 59, 59)
+ assert result == expected
+
+ result = get_since_until("Last year")
+ expected = datetime(2015, 11, 7), datetime(2016, 11, 7)
+ assert result == expected
+
+ result = get_since_until("Last quarter")
+ expected = datetime(2016, 8, 7), datetime(2016, 11, 7)
+ assert result == expected
+
+ result = get_since_until("Last 5 months")
+ expected = datetime(2016, 6, 7), datetime(2016, 11, 7)
+ assert result == expected
+
+ result = get_since_until("Last 1 month")
+ expected = datetime(2016, 10, 7), datetime(2016, 11, 7)
+ assert result == expected
+
+ result = get_since_until("Next 5 months")
+ expected = datetime(2016, 11, 7), datetime(2017, 4, 7)
+ assert result == expected
+
+ result = get_since_until("Next 1 month")
+ expected = datetime(2016, 11, 7), datetime(2016, 12, 7)
+ assert result == expected
+
+ result = get_since_until(since="5 days")
+ expected = datetime(2016, 11, 2), datetime(2016, 11, 7)
+ assert result == expected
+
+ result = get_since_until(since="5 days ago", until="tomorrow")
+ expected = datetime(2016, 11, 2), datetime(2016, 11, 8)
+ assert result == expected
+
+ result = get_since_until(time_range="yesterday : tomorrow", time_shift="1 day")
+ expected = datetime(2016, 11, 5), datetime(2016, 11, 7)
+ assert result == expected
+
+ result = get_since_until(time_range="5 days : now")
+ expected = datetime(2016, 11, 2), datetime(2016, 11, 7, 9, 30, 10)
+ assert result == expected
+
+ result = get_since_until("Last week", relative_end="now")
+ expected = datetime(2016, 10, 31), datetime(2016, 11, 7, 9, 30, 10)
+ assert result == expected
+
+ result = get_since_until("Last week", relative_start="now")
+ expected = datetime(2016, 10, 31, 9, 30, 10), datetime(2016, 11, 7)
+ assert result == expected
+
+ result = get_since_until("Last week", relative_start="now", relative_end="now")
+ expected = datetime(2016, 10, 31, 9, 30, 10), datetime(2016, 11, 7, 9, 30, 10)
+ assert result == expected
+
+ result = get_since_until("previous calendar week")
+ expected = datetime(2016, 10, 31, 0, 0, 0), datetime(2016, 11, 7, 0, 0, 0)
+ assert result == expected
+
+ result = get_since_until("previous calendar month")
+ expected = datetime(2016, 10, 1, 0, 0, 0), datetime(2016, 11, 1, 0, 0, 0)
+ assert result == expected
+
+ result = get_since_until("previous calendar year")
+ expected = datetime(2015, 1, 1, 0, 0, 0), datetime(2016, 1, 1, 0, 0, 0)
+ assert result == expected
+
+ with pytest.raises(ValueError):
+ get_since_until(time_range="tomorrow : yesterday")
+
+
+@patch("superset.utils.date_parser.parse_human_datetime", mock_parse_human_datetime)
+def test_datetime_eval() -> None:
+ result = datetime_eval("datetime('now')")
+ expected = datetime(2016, 11, 7, 9, 30, 10)
+ assert result == expected
+
+ result = datetime_eval("datetime('today')")
+ expected = datetime(2016, 11, 7)
+ assert result == expected
+
+ result = datetime_eval("datetime('2018')")
+ expected = datetime(2018, 1, 1)
+ assert result == expected
+
+ result = datetime_eval("datetime('2018-9')")
+ expected = datetime(2018, 9, 1)
+ assert result == expected
+
+    # Compact argument spelling (no spaces, trailing comma) should also parse
+ result = datetime_eval("dateadd(datetime('today'),1,year,)")
+ expected = datetime(2017, 11, 7)
+ assert result == expected
+
+ result = datetime_eval("dateadd(datetime('today'), -2, year)")
+ expected = datetime(2014, 11, 7)
+ assert result == expected
+
+ result = datetime_eval("dateadd(datetime('today'), 2, quarter)")
+ expected = datetime(2017, 5, 7)
+ assert result == expected
+
+ result = datetime_eval("dateadd(datetime('today'), 3, month)")
+ expected = datetime(2017, 2, 7)
+ assert result == expected
+
+ result = datetime_eval("dateadd(datetime('today'), -3, week)")
+ expected = datetime(2016, 10, 17)
+ assert result == expected
+
+ result = datetime_eval("dateadd(datetime('today'), 3, day)")
+ expected = datetime(2016, 11, 10)
+ assert result == expected
+
+ result = datetime_eval("dateadd(datetime('now'), 3, hour)")
+ expected = datetime(2016, 11, 7, 12, 30, 10)
+ assert result == expected
+
+ result = datetime_eval("dateadd(datetime('now'), 40, minute)")
+ expected = datetime(2016, 11, 7, 10, 10, 10)
+ assert result == expected
+
+ result = datetime_eval("dateadd(datetime('now'), -11, second)")
+ expected = datetime(2016, 11, 7, 9, 29, 59)
+ assert result == expected
+
+ result = datetime_eval("datetrunc(datetime('now'), year)")
+ expected = datetime(2016, 1, 1, 0, 0, 0)
+ assert result == expected
+
+ result = datetime_eval("datetrunc(datetime('now'), quarter)")
+ expected = datetime(2016, 10, 1, 0, 0, 0)
+ assert result == expected
+
+ result = datetime_eval("datetrunc(datetime('now'), month)")
+ expected = datetime(2016, 11, 1, 0, 0, 0)
+ assert result == expected
+
+ result = datetime_eval("datetrunc(datetime('now'), day)")
+ expected = datetime(2016, 11, 7, 0, 0, 0)
+ assert result == expected
+
+ result = datetime_eval("datetrunc(datetime('now'), week)")
+ expected = datetime(2016, 11, 7, 0, 0, 0)
+ assert result == expected
+
+ result = datetime_eval("datetrunc(datetime('now'), hour)")
+ expected = datetime(2016, 11, 7, 9, 0, 0)
+ assert result == expected
+
+ result = datetime_eval("datetrunc(datetime('now'), minute)")
+ expected = datetime(2016, 11, 7, 9, 30, 0)
+ assert result == expected
+
+ result = datetime_eval("datetrunc(datetime('now'), second)")
+ expected = datetime(2016, 11, 7, 9, 30, 10)
+ assert result == expected
+
+ result = datetime_eval("lastday(datetime('now'), year)")
+ expected = datetime(2016, 12, 31, 0, 0, 0)
+ assert result == expected
+
+ result = datetime_eval("lastday(datetime('today'), month)")
+ expected = datetime(2016, 11, 30, 0, 0, 0)
+ assert result == expected
+
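+    # holiday() anchors to the (mocked) current datetime by default; an explicit
+    # anchor datetime and a country code may also be passed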
+ result = datetime_eval("holiday('Christmas')")
+ expected = datetime(2016, 12, 25, 0, 0, 0)
+ assert result == expected
+
+ result = datetime_eval("holiday('Labor day', datetime('2018-01-01T00:00:00'))")
+ expected = datetime(2018, 9, 3, 0, 0, 0)
+ assert result == expected
+
+ result = datetime_eval(
+ "holiday('Boxing day', datetime('2018-01-01T00:00:00'), 'UK')"
+ )
+ expected = datetime(2018, 12, 26, 0, 0, 0)
+ assert result == expected
+
+ result = datetime_eval(
+ "lastday(dateadd(datetime('2018-01-01T00:00:00'), 1, month), month)"
+ )
+ expected = datetime(2018, 2, 28, 0, 0, 0)
+ assert result == expected
+
+
+@patch("superset.utils.date_parser.datetime")
+def test_parse_human_timedelta(mock_datetime: Mock) -> None:
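+    # freeze datetime.now() while keeping the constructor usable via side_effect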
+ mock_datetime.now.return_value = datetime(2019, 4, 1)
+ mock_datetime.side_effect = lambda *args, **kw: datetime(*args, **kw)
+ assert parse_human_timedelta("now") == timedelta(0)
+ assert parse_human_timedelta("1 year") == timedelta(366)
+ assert parse_human_timedelta("-1 year") == timedelta(-365)
+ assert parse_human_timedelta(None) == timedelta(0)
+ assert parse_human_timedelta("1 month", datetime(2019, 4, 1)) == timedelta(30)
+ assert parse_human_timedelta("1 month", datetime(2019, 5, 1)) == timedelta(31)
+ assert parse_human_timedelta("1 month", datetime(2019, 2, 1)) == timedelta(28)
+ assert parse_human_timedelta("-1 month", datetime(2019, 2, 1)) == timedelta(-31)
+
+
+@patch("superset.utils.date_parser.datetime")
+def test_parse_past_timedelta(mock_datetime: Mock) -> None:
+ mock_datetime.now.return_value = datetime(2019, 4, 1)
+ mock_datetime.side_effect = lambda *args, **kw: datetime(*args, **kw)
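+    # parse_past_timedelta returns the magnitude of the delta, so the leading
+    # minus sign is ignored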
+ assert parse_past_timedelta("1 year") == timedelta(365)
+ assert parse_past_timedelta("-1 year") == timedelta(365)
+ assert parse_past_timedelta("52 weeks") == timedelta(364)
+ assert parse_past_timedelta("1 month") == timedelta(31)
+
+
+def test_get_past_or_future() -> None:
+ # 2020 is a leap year
+ dttm = datetime(2020, 2, 29)
+ assert get_past_or_future("1 year", dttm) == datetime(2021, 2, 28)
+ assert get_past_or_future("-1 year", dttm) == datetime(2019, 2, 28)
+ assert get_past_or_future("1 month", dttm) == datetime(2020, 3, 29)
+ assert get_past_or_future("3 month", dttm) == datetime(2020, 5, 29)
+
+
+def test_parse_human_datetime() -> None:
+ with pytest.raises(TimeRangeAmbiguousError):
+ parse_human_datetime("2 days")
+
+ with pytest.raises(TimeRangeAmbiguousError):
+ parse_human_datetime("2 day")
+
+ with pytest.raises(TimeRangeParseFailError):
+ parse_human_datetime("xxxxxxx")
+
+ assert parse_human_datetime("2015-04-03") == datetime(2015, 4, 3, 0, 0)
+ assert parse_human_datetime("2/3/1969") == datetime(1969, 2, 3, 0, 0)
+
+ assert parse_human_datetime("now") <= datetime.now()
+ assert parse_human_datetime("yesterday") < datetime.now()
+    assert parse_human_datetime("yesterday").date() == date.today() - timedelta(1)
+
+ assert (
+ parse_human_datetime("one year ago").date()
+ == (datetime.now() - relativedelta(years=1)).date()
+ )
+ assert (
+ parse_human_datetime("2 years after").date()
+ == (datetime.now() + relativedelta(years=2)).date()
+ )
+
+
+def test_date_range_migration() -> None:
+ params = '{"time_range": " 8 days : 2020-03-10T00:00:00"}'
+ assert re.search(DateRangeMigration.x_dateunit_in_since, params)
+
+ params = '{"time_range": "2020-03-10T00:00:00 : 8 days "}'
+ assert re.search(DateRangeMigration.x_dateunit_in_until, params)
+
+ params = '{"time_range": " 2 weeks : 8 days "}'
+ assert re.search(DateRangeMigration.x_dateunit_in_since, params)
+ assert re.search(DateRangeMigration.x_dateunit_in_until, params)
+
+ params = '{"time_range": "2 weeks ago : 8 days later"}'
+ assert not re.search(DateRangeMigration.x_dateunit_in_since, params)
+ assert not re.search(DateRangeMigration.x_dateunit_in_until, params)
+
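+    # x_dateunit should match a bare "<n> <unit>" value but not one qualified
+    # with "ago"/"later" or a named range like "last week"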
+ field = " 8 days "
+ assert re.search(DateRangeMigration.x_dateunit, field)
+
+ field = "last week"
+ assert not re.search(DateRangeMigration.x_dateunit, field)
+
+ field = "10 years ago"
+ assert not re.search(DateRangeMigration.x_dateunit, field)
diff --git a/tests/unit_tests/utils/db.py b/tests/unit_tests/utils/db.py
new file mode 100644
index 0000000000000..554c95bd43187
--- /dev/null
+++ b/tests/unit_tests/utils/db.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any
+
+from superset import security_manager
+
+
+def get_test_user(id_: int, username: str) -> Any:
+ """Create a sample test user"""
+ return security_manager.user_model(
+ id=id_,
+ username=username,
+ first_name=username,
+ last_name=username,
+ email=f"{username}@example.com",
+ )
diff --git a/tests/unit_tests/utils/log_tests.py b/tests/unit_tests/utils/log_tests.py
new file mode 100644
index 0000000000000..5b031b5778875
--- /dev/null
+++ b/tests/unit_tests/utils/log_tests.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from superset.utils.log import get_logger_from_status
+
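+# get_logger_from_status maps 5xx statuses to logger.exception, 4xx to
+# logger.warning, and lower statuses to logger.info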
+
+def test_log_from_status_exception() -> None:
+ (func, log_level) = get_logger_from_status(500)
+ assert func.__name__ == "exception"
+ assert log_level == "exception"
+
+
+def test_log_from_status_warning() -> None:
+ (func, log_level) = get_logger_from_status(422)
+ assert func.__name__ == "warning"
+ assert log_level == "warning"
+
+
+def test_log_from_status_info() -> None:
+ (func, log_level) = get_logger_from_status(300)
+ assert func.__name__ == "info"
+ assert log_level == "info"
diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py
new file mode 100644
index 0000000000000..6845bb2fc1545
--- /dev/null
+++ b/tests/unit_tests/utils/test_core.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict
+
+import pytest
+
+from superset.utils.core import QueryObjectFilterClause, remove_extra_adhoc_filters
+
+ADHOC_FILTER: QueryObjectFilterClause = {
+ "col": "foo",
+ "op": "==",
+ "val": "bar",
+}
+
+EXTRA_FILTER: QueryObjectFilterClause = {
+ "col": "foo",
+ "op": "==",
+ "val": "bar",
+ "isExtra": True,
+}
+
+
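+# remove_extra_adhoc_filters mutates the dict in place, dropping entries flagged
+# with isExtra from any key starting with "adhoc_filters"; other keys such as
+# "custom_adhoc_filters" are left untouched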
+@pytest.mark.parametrize(
+ "original,expected",
+ [
+ ({"foo": "bar"}, {"foo": "bar"}),
+ (
+ {"foo": "bar", "adhoc_filters": [ADHOC_FILTER]},
+ {"foo": "bar", "adhoc_filters": [ADHOC_FILTER]},
+ ),
+ (
+ {"foo": "bar", "adhoc_filters": [EXTRA_FILTER]},
+ {"foo": "bar", "adhoc_filters": []},
+ ),
+ (
+ {
+ "foo": "bar",
+ "adhoc_filters": [ADHOC_FILTER, EXTRA_FILTER],
+ },
+ {"foo": "bar", "adhoc_filters": [ADHOC_FILTER]},
+ ),
+ (
+ {
+ "foo": "bar",
+ "adhoc_filters_b": [ADHOC_FILTER, EXTRA_FILTER],
+ },
+ {"foo": "bar", "adhoc_filters_b": [ADHOC_FILTER]},
+ ),
+ (
+ {
+ "foo": "bar",
+ "custom_adhoc_filters": [
+ ADHOC_FILTER,
+ EXTRA_FILTER,
+ ],
+ },
+ {
+ "foo": "bar",
+ "custom_adhoc_filters": [
+ ADHOC_FILTER,
+ EXTRA_FILTER,
+ ],
+ },
+ ),
+ ],
+)
+def test_remove_extra_adhoc_filters(
+ original: Dict[str, Any], expected: Dict[str, Any]
+) -> None:
+ remove_extra_adhoc_filters(original)
+ assert expected == original
diff --git a/tests/unit_tests/utils/test_decorators.py b/tests/unit_tests/utils/test_decorators.py
new file mode 100644
index 0000000000000..3aafc7a91db2b
--- /dev/null
+++ b/tests/unit_tests/utils/test_decorators.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from contextlib import nullcontext
+from enum import Enum
+from typing import Any, Optional, Type
+from unittest.mock import call, Mock, patch
+
+import pytest
+
+from superset import app
+from superset.utils import decorators
+
+
+class ResponseValues(str, Enum):
+ FAIL = "fail"
+ WARN = "warn"
+ OK = "ok"
+
+
+def test_debounce() -> None:
+ mock = Mock()
+
+ @decorators.debounce()
+ def myfunc(arg1: int, arg2: int, kwarg1: str = "abc", kwarg2: int = 2) -> int:
+ mock(arg1, kwarg1)
+ return arg1 + arg2 + kwarg2
+
+ # should be called only once when arguments don't change
+ myfunc(1, 1)
+ myfunc(1, 1)
+ result = myfunc(1, 1)
+ mock.assert_called_once_with(1, "abc")
+ assert result == 4
+
+ # kwarg order shouldn't matter
+ myfunc(1, 0, kwarg2=2, kwarg1="haha")
+ result = myfunc(1, 0, kwarg1="haha", kwarg2=2)
+ mock.assert_has_calls([call(1, "abc"), call(1, "haha")])
+ assert result == 3
+
+
+@pytest.mark.parametrize(
+ "response_value, expected_exception, expected_result",
+ [
+ (ResponseValues.OK, None, "custom.prefix.ok"),
+ (ResponseValues.FAIL, ValueError, "custom.prefix.error"),
+ (ResponseValues.WARN, FileNotFoundError, "custom.prefix.warn"),
+ ],
+)
+def test_statsd_gauge(
+    response_value: ResponseValues,
+    expected_exception: Optional[Type[Exception]],
+    expected_result: str,
+) -> None:
+ @decorators.statsd_gauge("custom.prefix")
+ def my_func(response: ResponseValues, *args: Any, **kwargs: Any) -> str:
+ if response == ResponseValues.FAIL:
+ raise ValueError("Error")
+ if response == ResponseValues.WARN:
+ raise FileNotFoundError("Not found")
+ return "OK"
+
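+    # the decorator should emit a gauge of 1 with an "ok", "error", or "warn"
+    # suffix depending on the outcome of the wrapped function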
+ with patch.object(app.config["STATS_LOGGER"], "gauge") as mock:
+        cm = (
+            pytest.raises(expected_exception)
+            if expected_exception is not None
+            else nullcontext()
+        )
+
+ with cm:
+ my_func(response_value, 1, 2)
+ mock.assert_called_once_with(expected_result, 1)
diff --git a/tests/unit_tests/utils/test_file.py b/tests/unit_tests/utils/test_file.py
new file mode 100644
index 0000000000000..de20402e5c21c
--- /dev/null
+++ b/tests/unit_tests/utils/test_file.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from superset.utils.file import get_filename
+
+
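+# get_filename slugifies the model name (spaces become underscores, non-ASCII is
+# dropped) and appends the model id unless skip_id is set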
+@pytest.mark.parametrize(
+ "model_name,model_id,skip_id,expected_filename",
+ [
+ ("Energy Sankey", 132, False, "Energy_Sankey_132"),
+ ("Energy Sankey", 132, True, "Energy_Sankey"),
+ ("folder1/Energy Sankey", 132, True, "folder1_Energy_Sankey"),
+ ("D:\\Charts\\Energy Sankey", 132, True, "DChartsEnergy_Sankey"),
+ ("🥴🥴🥴", 4751, False, "4751"),
+ ("🥴🥴🥴", 4751, True, "4751"),
+ ("Energy Sankey 🥴🥴🥴", 4751, False, "Energy_Sankey_4751"),
+ ("Energy Sankey 🥴🥴🥴", 4751, True, "Energy_Sankey"),
+ ("你好", 475, False, "475"),
+ ("你好", 475, True, "475"),
+ ("Energy Sankey 你好", 475, False, "Energy_Sankey_475"),
+ ("Energy Sankey 你好", 475, True, "Energy_Sankey"),
+ ],
+)
+def test_get_filename(
+ model_name: str, model_id: int, skip_id: bool, expected_filename: str
+) -> None:
+ original_filename = get_filename(model_name, model_id, skip_id)
+ assert expected_filename == original_filename
diff --git a/tests/unit_tests/utils/urls_tests.py b/tests/unit_tests/utils/urls_tests.py
new file mode 100644
index 0000000000000..208d6caea4375
--- /dev/null
+++ b/tests/unit_tests/utils/urls_tests.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from superset.utils.urls import modify_url_query
+
+EXPLORE_CHART_LINK = "http://localhost:9000/explore/?form_data=%7B%22slice_id%22%3A+76%7D&standalone=true&force=false"
+
+EXPLORE_DASHBOARD_LINK = "http://localhost:9000/superset/dashboard/3/?standalone=3"
+
+
+def test_convert_chart_link() -> None:
+ test_url = modify_url_query(EXPLORE_CHART_LINK, standalone="0")
+ assert (
+ test_url
+ == "http://localhost:9000/explore/?form_data=%7B%22slice_id%22%3A%2076%7D&standalone=0&force=false"
+ )
+
+
+def test_convert_dashboard_link() -> None:
+ test_url = modify_url_query(EXPLORE_DASHBOARD_LINK, standalone="0")
+ assert test_url == "http://localhost:9000/superset/dashboard/3/?standalone=0"
+
+
+def test_convert_dashboard_link_with_integer() -> None:
+ test_url = modify_url_query(EXPLORE_DASHBOARD_LINK, standalone=0)
+ assert test_url == "http://localhost:9000/superset/dashboard/3/?standalone=0"
+
+
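+# under a plain HTTP test request context only same-scheme URLs pointing at the
+# request's own host should be considered safe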
+@pytest.mark.parametrize(
+ "url,is_safe",
+ [
+ ("http://localhost/", True),
+ ("http://localhost/superset/1", True),
+ ("https://localhost/", False),
+ ("https://localhost/superset/1", False),
+ ("localhost/superset/1", False),
+ ("ftp://localhost/superset/1", False),
+ ("http://external.com", False),
+ ("https://external.com", False),
+ ("external.com", False),
+ ("///localhost", False),
+ ("xpto://localhost:[3/1/", False),
+ ],
+)
+def test_is_safe_url(url: str, is_safe: bool) -> None:
+ from superset import app
+ from superset.utils.urls import is_safe_url
+
+ with app.test_request_context("/"):
+ assert is_safe_url(url) == is_safe
diff --git a/tests/unit_tests/views/__init__.py b/tests/unit_tests/views/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/views/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.