diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index b88acadb72f47..5c646d57d9c8b 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -900,7 +900,11 @@ def adhoc_metric_to_sqla( elif expression_type == utils.AdhocMetricExpressionType.SQL: tp = self.get_template_processor() expression = tp.process_template(cast(str, metric["sqlExpression"])) - validate_adhoc_subquery(expression) + expression = validate_adhoc_subquery( + expression, + self.database_id, + self.schema, + ) try: expression = sanitize_clause(expression) except QueryClauseValidationException as ex: @@ -929,7 +933,11 @@ def adhoc_column_to_sqla( if template_processor and expression: expression = template_processor.process_template(expression) if expression: - validate_adhoc_subquery(expression) + expression = validate_adhoc_subquery( + expression, + self.database_id, + self.schema, + ) try: expression = sanitize_clause(expression) except QueryClauseValidationException as ex: @@ -983,9 +991,9 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool: if is_alias_used_in_orderby(col): col.name = f"{col.name}__" - def _get_sqla_row_level_filters( + def get_sqla_row_level_filters( self, template_processor: BaseTemplateProcessor - ) -> List[str]: + ) -> List[TextClause]: """ Return the appropriate row level security filters for this table and the current user. @@ -993,7 +1001,6 @@ def _get_sqla_row_level_filters( :param BaseTemplateProcessor template_processor: The template processor to apply to the filters. :returns: A list of SQL clauses to be ANDed together. - :rtype: List[str] """ all_filters: List[TextClause] = [] filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) @@ -1146,6 +1153,12 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col: Union[AdhocMetric, ColumnElement] = orig_col if isinstance(col, dict): col = cast(AdhocMetric, col) + if col.get("sqlExpression"): + col["sqlExpression"] = validate_adhoc_subquery( + cast(str, col["sqlExpression"]), + self.database_id, + self.schema, + ) if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists col = self.adhoc_metric_to_sqla(col, columns_by_name) @@ -1195,7 +1208,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma elif selected in columns_by_name: outer = columns_by_name[selected].get_sqla_col() else: - validate_adhoc_subquery(selected) + selected = validate_adhoc_subquery( + selected, + self.database_id, + self.schema, + ) outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) else: @@ -1208,7 +1225,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma select_exprs.append(outer) elif columns: for selected in columns: - validate_adhoc_subquery(selected) + selected = validate_adhoc_subquery( + selected, + self.database_id, + self.schema, + ) select_exprs.append( columns_by_name[selected].get_sqla_col() if selected in columns_by_name @@ -1374,7 +1395,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma _("Invalid filter operation type: %(op)s", op=op) ) if is_feature_enabled("ROW_LEVEL_SECURITY"): - where_clause_and += self._get_sqla_row_level_filters(template_processor) + where_clause_and += self.get_sqla_row_level_filters(template_processor) if extras: where = extras.get("where") if where: @@ -1421,7 +1442,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma and db_engine_spec.allows_hidden_cc_in_orderby and col.name in [select_col.name for select_col in select_exprs] ): - validate_adhoc_subquery(str(col.expression)) col = literal_column(col.name) direction = asc if ascending else desc qry = qry.order_by(direction(col)) diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 4fc11a4d1d16b..766b74e57c004 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -33,7 +33,7 @@ ) from superset.models.core import Database from superset.result_set import SupersetResultSet -from superset.sql_parse import has_table_query, ParsedQuery, Table +from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table from superset.tables.models import Table as NewTable if TYPE_CHECKING: @@ -136,29 +136,39 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]: return cols -def validate_adhoc_subquery(raw_sql: str) -> None: +def validate_adhoc_subquery( + sql: str, + database_id: int, + default_schema: str, +) -> str: """ - Check if adhoc SQL contains sub-queries or nested sub-queries with table - :param raw_sql: adhoc sql expression + Check if adhoc SQL contains sub-queries or nested sub-queries with table. + + If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS + predicates to it. + + :param sql: adhoc sql expression :raise SupersetSecurityException if sql contains sub-queries or nested sub-queries with table """ # pylint: disable=import-outside-toplevel from superset import is_feature_enabled - if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): - return - - for statement in sqlparse.parse(raw_sql): + statements = [] + for statement in sqlparse.parse(sql): if has_table_query(statement): - raise SupersetSecurityException( - SupersetError( - error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, - message=_("Custom SQL fields cannot contain sub-queries."), - level=ErrorLevel.ERROR, + if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, + message=_("Custom SQL fields cannot contain sub-queries."), + level=ErrorLevel.ERROR, + ) ) - ) - return + statement = insert_rls(statement, database_id, default_schema) + statements.append(statement) + + return ";\n".join(str(statement) for statement in statements) def load_or_create_tables( # pylint: disable=too-many-arguments diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 95361b39a6a27..6bfb63c425c48 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -18,10 +18,11 @@ import re from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Set, Tuple +from typing import cast, List, Optional, Set, Tuple from urllib import parse import sqlparse +from sqlalchemy import and_ from sqlparse.sql import ( Identifier, IdentifierList, @@ -283,7 +284,7 @@ def get_statements(self) -> List[str]: return statements @staticmethod - def _get_table(tlist: TokenList) -> Optional[Table]: + def get_table(tlist: TokenList) -> Optional[Table]: """ Return the table if valid, i.e., conforms to the [[catalog.]schema.]table construct. @@ -324,7 +325,7 @@ def _process_tokenlist(self, token_list: TokenList) -> None: """ # exclude subselects if "(" not in str(token_list): - table = self._get_table(token_list) + table = self.get_table(token_list) if table and not table.table.startswith(CTE_PREFIX): self._tables.add(table) return @@ -500,7 +501,7 @@ def has_table_query(token_list: TokenList) -> bool: state = InsertRLSState.SCANNING for token in token_list.tokens: - # # Recurse into child token list + # Recurse into child token list if isinstance(token, TokenList) and has_table_query(token): return True @@ -523,7 +524,7 @@ def has_table_query(token_list: TokenList) -> bool: def add_table_name(rls: TokenList, table: str) -> None: """ - Modify a RLS expression ensuring columns are fully qualified. + Modify a RLS expression inplace ensuring columns are fully qualified. """ tokens = rls.tokens[:] while tokens: @@ -539,45 +540,70 @@ def add_table_name(rls: TokenList, table: str) -> None: tokens.extend(token.tokens) -def matches_table_name(candidate: Token, table: str) -> bool: +def get_rls_for_table( + candidate: Token, + database_id: int, + default_schema: Optional[str], +) -> Optional[TokenList]: """ - Returns if the token represents a reference to the table. - - Tables can be fully qualified with periods. - - Note that in theory a table should be represented as an identifier, but due to - sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets - classified as a keyword. + Given a table name, return any associated RLS predicates. """ + # pylint: disable=import-outside-toplevel + from superset import db + from superset.connectors.sqla.models import SqlaTable + if not isinstance(candidate, Identifier): candidate = Identifier([Token(Name, candidate.value)]) - target = sqlparse.parse(table)[0].tokens[0] - if not isinstance(target, Identifier): - target = Identifier([Token(Name, target.value)]) + table = ParsedQuery.get_table(candidate) + if not table: + return None - # match from right to left, splitting on the period, eg, schema.table == table - for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]): - if left.value != right.value: - return False + dataset = ( + db.session.query(SqlaTable) + .filter( + and_( + SqlaTable.database_id == database_id, + SqlaTable.schema == (table.schema or default_schema), + SqlaTable.table_name == table.table, + ) + ) + .one_or_none() + ) + if not dataset: + return None + + template_processor = dataset.get_template_processor() + # pylint: disable=protected-access + predicate = " AND ".join( + str(filter_) + for filter_ in dataset.get_sqla_row_level_filters(template_processor) + ) + if not predicate: + return None + + rls = sqlparse.parse(predicate)[0] + add_table_name(rls, str(dataset)) - return True + return rls -def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: +def insert_rls( + token_list: TokenList, + database_id: int, + default_schema: Optional[str], +) -> TokenList: """ - Update a statement inplace applying an RLS associated with a given table. + Update a statement inplace applying any associated RLS predicates. """ - # make sure the identifier has the table name - add_table_name(rls, table) - + rls: Optional[TokenList] = None state = InsertRLSState.SCANNING for token in token_list.tokens: # Recurse into child token list if isinstance(token, TokenList): i = token_list.tokens.index(token) - token_list.tokens[i] = insert_rls(token, table, rls) + token_list.tokens[i] = insert_rls(token, database_id, default_schema) # Found a source keyword (FROM/JOIN) if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]): @@ -587,12 +613,14 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: elif state == InsertRLSState.SEEN_SOURCE and ( isinstance(token, Identifier) or token.ttype == Keyword ): - if matches_table_name(token, table): + rls = get_rls_for_table(token, database_id, default_schema) + if rls: state = InsertRLSState.FOUND_TABLE # Found WHERE clause, insert RLS. Note that we insert it even it already exists, # to be on the safe side: it could be present in a clause like `1=1 OR RLS`. elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where): + rls = cast(TokenList, rls) token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")] token.tokens.extend( [ diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 75f099e52b6e1..4a1ff89d74cc6 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -14,21 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# pylint: disable=invalid-name, too-many-lines +# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines import unittest -from typing import Set +from typing import Optional, Set import pytest import sqlparse +from pytest_mock import MockerFixture +from sqlalchemy import text +from sqlparse.sql import Identifier, Token, TokenList +from sqlparse.tokens import Name from superset.exceptions import QueryClauseValidationException from superset.sql_parse import ( add_table_name, + get_rls_for_table, has_table_query, insert_rls, - matches_table_name, ParsedQuery, sanitize_clause, strip_comments_from_sql, @@ -1391,13 +1394,37 @@ def test_has_table_query(sql: str, expected: bool) -> None: ), ], ) -def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None: +def test_insert_rls( + mocker: MockerFixture, sql: str, table: str, rls: str, expected: str +) -> None: """ Insert into a statement a given RLS condition associated with a table. """ - statement = sqlparse.parse(sql)[0] condition = sqlparse.parse(rls)[0] - assert str(insert_rls(statement, table, condition)).strip() == expected.strip() + add_table_name(condition, table) + + # pylint: disable=unused-argument + def get_rls_for_table( + candidate: Token, database_id: int, default_schema: str + ) -> Optional[TokenList]: + """ + Return the RLS ``condition`` if ``candidate`` matches ``table``. + """ + # compare ignoring schema + for left, right in zip(str(candidate).split(".")[::-1], table.split(".")[::-1]): + if left != right: + return None + return condition + + mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table) + + statement = sqlparse.parse(sql)[0] + assert ( + str( + insert_rls(token_list=statement, database_id=1, default_schema="my_schema") + ).strip() + == expected.strip() + ) @pytest.mark.parametrize( @@ -1415,16 +1442,29 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None: assert str(condition) == expected -@pytest.mark.parametrize( - "candidate,table,expected", - [ - ("table", "table", True), - ("schema.table", "table", True), - ("table", "schema.table", True), - ('schema."my table"', '"my table"', True), - ('schema."my.table"', '"my.table"', True), - ], -) -def test_matches_table_name(candidate: str, table: str, expected: bool) -> None: - token = sqlparse.parse(candidate)[0].tokens[0] - assert matches_table_name(token, table) == expected +def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None: + """ + Tests for ``get_rls_for_table``. + """ + candidate = Identifier([Token(Name, "some_table")]) + db = mocker.patch("superset.db") + dataset = db.session.query().filter().one_or_none() + dataset.__str__.return_value = "some_table" + + dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")] + assert ( + str(get_rls_for_table(candidate, 1, "public")) + == "some_table.organization_id = 1" + ) + + dataset.get_sqla_row_level_filters.return_value = [ + text("organization_id = 1"), + text("foo = 'bar'"), + ] + assert ( + str(get_rls_for_table(candidate, 1, "public")) + == "some_table.organization_id = 1 AND some_table.foo = 'bar'" + ) + + dataset.get_sqla_row_level_filters.return_value = [] + assert get_rls_for_table(candidate, 1, "public") is None