From 50902d51f56eac7813572178527ee83b429db284 Mon Sep 17 00:00:00 2001 From: Lily Kuang Date: Fri, 18 Mar 2022 10:04:55 -0700 Subject: [PATCH] fix: allow subquery in ad-hoc SQL (WIP) (#19242) * allow adhoc subquery * add config for allow ad hoc subquery * default to true allow adhoc subquery * fix test * Update superset/errors.py Co-authored-by: Beto Dealmeida * Update superset/connectors/sqla/utils.py Co-authored-by: David Aaron Suddjian <1858430+suddjian@users.noreply.github.com> * rename and add doc string * fix for big query test * Update superset/connectors/sqla/utils.py Co-authored-by: Beto Dealmeida * Apply suggestions from code review Co-authored-by: Beto Dealmeida * add test * update validate adhoc subquery Co-authored-by: Beto Dealmeida Co-authored-by: David Aaron Suddjian <1858430+suddjian@users.noreply.github.com> --- superset/config.py | 1 + superset/connectors/sqla/models.py | 7 +++++ superset/connectors/sqla/utils.py | 28 +++++++++++++++++- superset/errors.py | 3 ++ tests/integration_tests/sqla_models_tests.py | 31 +++++++++++++++++++- tests/unit_tests/sql_parse_tests.py | 2 ++ 6 files changed, 70 insertions(+), 2 deletions(-) diff --git a/superset/config.py b/superset/config.py index 775765d08f0aa..6a84a1cf40976 100644 --- a/superset/config.py +++ b/superset/config.py @@ -443,6 +443,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: "ALLOW_FULL_CSV_EXPORT": False, "UX_BETA": False, "GENERIC_CHART_AXES": False, + "ALLOW_ADHOC_SUBQUERY": False, } # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 99cbc50997559..23d3d326cd9a5 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -78,6 +78,7 @@ from superset.connectors.sqla.utils import ( get_physical_table_metadata, get_virtual_table_metadata, + validate_adhoc_subquery, ) from superset.datasets.models import Dataset as NewDataset from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression @@ -885,6 +886,7 @@ def adhoc_metric_to_sqla( elif expression_type == utils.AdhocMetricExpressionType.SQL: tp = self.get_template_processor() expression = tp.process_template(cast(str, metric["sqlExpression"])) + validate_adhoc_subquery(expression) sqla_metric = literal_column(expression) else: raise QueryObjectValidationError("Adhoc metric expressionType is invalid") @@ -908,6 +910,8 @@ def adhoc_column_to_sqla( expression = col["sqlExpression"] if template_processor and expression: expression = template_processor.process_template(expression) + if expression: + validate_adhoc_subquery(expression) sqla_metric = literal_column(expression) return self.make_sqla_column_compatible(sqla_metric, label) @@ -1166,6 +1170,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma elif selected in columns_by_name: outer = columns_by_name[selected].get_sqla_col() else: + validate_adhoc_subquery(selected) outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) else: @@ -1178,6 +1183,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma select_exprs.append(outer) elif columns: for selected in columns: + validate_adhoc_subquery(selected) select_exprs.append( columns_by_name[selected].get_sqla_col() if selected in columns_by_name @@ -1389,6 +1395,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma and db_engine_spec.allows_hidden_cc_in_orderby and col.name in [select_col.name for select_col in select_exprs] ): + validate_adhoc_subquery(str(col.expression)) col = literal_column(col.name) direction = asc if ascending else desc qry = qry.order_by(direction(col)) diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index e5209e08dcf68..984eef78f4b76 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -17,6 +17,7 @@ from contextlib import closing from typing import Dict, List, Optional, TYPE_CHECKING +import sqlparse from flask_babel import lazy_gettext as _ from sqlalchemy.exc import NoSuchTableError from sqlalchemy.sql.type_api import TypeEngine @@ -28,7 +29,7 @@ ) from superset.models.core import Database from superset.result_set import SupersetResultSet -from superset.sql_parse import ParsedQuery +from superset.sql_parse import has_table_query, ParsedQuery if TYPE_CHECKING: from superset.connectors.sqla.models import SqlaTable @@ -119,3 +120,28 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]: except Exception as ex: raise SupersetGenericDBErrorException(message=str(ex)) from ex return cols + + +def validate_adhoc_subquery(raw_sql: str) -> None: + """ + Check if adhoc SQL contains sub-queries or nested sub-queries with table + :param raw_sql: adhoc sql expression + :raise SupersetSecurityException if sql contains sub-queries or + nested sub-queries with table + """ + # pylint: disable=import-outside-toplevel + from superset import is_feature_enabled + + if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): + return + + for statement in sqlparse.parse(raw_sql): + if has_table_query(statement): + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, + message=_("Custom SQL fields cannot contain sub-queries."), + level=ErrorLevel.ERROR, + ) + ) + return diff --git a/superset/errors.py b/superset/errors.py index 9b3414ecb57e2..9198a82d3fe61 100644 --- a/superset/errors.py +++ b/superset/errors.py @@ -80,6 +80,7 @@ class SupersetErrorType(str, Enum): SQLLAB_TIMEOUT_ERROR = "SQLLAB_TIMEOUT_ERROR" RESULTS_BACKEND_ERROR = "RESULTS_BACKEND_ERROR" ASYNC_WORKERS_ERROR = "ASYNC_WORKERS_ERROR" + ADHOC_SUBQUERY_NOT_ALLOWED_ERROR = "ADHOC_SUBQUERY_NOT_ALLOWED_ERROR" # Generic errors GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR" @@ -138,10 +139,12 @@ class SupersetErrorType(str, Enum): 1034: _("The port number is invalid."), 1035: _("Failed to start remote query on a worker."), 1036: _("The database was deleted."), + 1037: _("Custom SQL fields cannot contain sub-queries."), } ERROR_TYPES_TO_ISSUE_CODES_MAPPING = { + SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR: [1037], SupersetErrorType.BACKEND_TIMEOUT_ERROR: [1000, 1001], SupersetErrorType.GENERIC_DB_ENGINE_ERROR: [1002], SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR: [1003, 1004], diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 54779fcbbff9b..223d48a4899a7 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -34,7 +34,7 @@ from superset.constants import EMPTY_STRING, NULL_STRING from superset.db_engine_specs.bigquery import BigQueryEngineSpec from superset.db_engine_specs.druid import DruidEngineSpec -from superset.exceptions import QueryObjectValidationError +from superset.exceptions import QueryObjectValidationError, SupersetSecurityException from superset.models.core import Database from superset.utils.core import ( AdhocMetricExpressionType, @@ -239,6 +239,35 @@ def test_jinja_metrics_and_calc_columns(self, flask_g): db.session.delete(table) db.session.commit() + def test_adhoc_metrics_and_calc_columns(self): + base_query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": ["user", "expr"], + "metrics": [ + { + "expressionType": AdhocMetricExpressionType.SQL, + "sqlExpression": "(SELECT (SELECT * from birth_names) " + "from test_validate_adhoc_sql)", + "label": "adhoc_metrics", + } + ], + "is_timeseries": False, + "filter": [], + } + + table = SqlaTable( + table_name="test_validate_adhoc_sql", database=get_example_database() + ) + db.session.commit() + + with pytest.raises(SupersetSecurityException): + table.get_sqla_query(**base_query_obj) + # Cleanup + db.session.delete(table) + db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_where_operators(self): filters: Tuple[FilterTestCase, ...] = ( diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index aa811bdef757e..886eb368e4aa4 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1208,6 +1208,8 @@ def test_sqlparse_issue_652(): ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False), ("SELECT * FROM other_table", True), ("extract(HOUR from from_unixtime(hour_ts)", False), + ("(SELECT * FROM table)", True), + ("(SELECT COUNT(DISTINCT name) from birth_names)", True), ], ) def test_has_table_query(sql: str, expected: bool) -> None: