Skip to content

Commit

Permalink
feat: improve adhoc SQL validation (apache#19454)
Browse files Browse the repository at this point in the history
* feat: improve adhoc SQL validation

* Small changes

* Add more unit tests
  • Loading branch information
betodealmeida authored and philipher29 committed Jun 9, 2022
1 parent 7857ecf commit 521ee6d
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 72 deletions.
38 changes: 29 additions & 9 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,11 @@ def adhoc_metric_to_sqla(
elif expression_type == utils.AdhocMetricExpressionType.SQL:
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
validate_adhoc_subquery(expression)
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
Expand Down Expand Up @@ -929,7 +933,11 @@ def adhoc_column_to_sqla(
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
validate_adhoc_subquery(expression)
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
Expand Down Expand Up @@ -983,17 +991,16 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool:
if is_alias_used_in_orderby(col):
col.name = f"{col.name}__"

def _get_sqla_row_level_filters(
def get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
) -> List[str]:
) -> List[TextClause]:
"""
Return the appropriate row level security filters for
this table and the current user.
:param BaseTemplateProcessor template_processor: The template
processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
:rtype: List[str]
"""
all_filters: List[TextClause] = []
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
Expand Down Expand Up @@ -1146,6 +1153,12 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
col: Union[AdhocMetric, ColumnElement] = orig_col
if isinstance(col, dict):
col = cast(AdhocMetric, col)
if col.get("sqlExpression"):
col["sqlExpression"] = validate_adhoc_subquery(
cast(str, col["sqlExpression"]),
self.database_id,
self.schema,
)
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
col = self.adhoc_metric_to_sqla(col, columns_by_name)
Expand Down Expand Up @@ -1195,7 +1208,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
elif selected in columns_by_name:
outer = columns_by_name[selected].get_sqla_col()
else:
validate_adhoc_subquery(selected)
selected = validate_adhoc_subquery(
selected,
self.database_id,
self.schema,
)
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
else:
Expand All @@ -1208,7 +1225,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
select_exprs.append(outer)
elif columns:
for selected in columns:
validate_adhoc_subquery(selected)
selected = validate_adhoc_subquery(
selected,
self.database_id,
self.schema,
)
select_exprs.append(
columns_by_name[selected].get_sqla_col()
if selected in columns_by_name
Expand Down Expand Up @@ -1374,7 +1395,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
_("Invalid filter operation type: %(op)s", op=op)
)
if is_feature_enabled("ROW_LEVEL_SECURITY"):
where_clause_and += self._get_sqla_row_level_filters(template_processor)
where_clause_and += self.get_sqla_row_level_filters(template_processor)
if extras:
where = extras.get("where")
if where:
Expand Down Expand Up @@ -1421,7 +1442,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
and db_engine_spec.allows_hidden_cc_in_orderby
and col.name in [select_col.name for select_col in select_exprs]
):
validate_adhoc_subquery(str(col.expression))
col = literal_column(col.name)
direction = asc if ascending else desc
qry = qry.order_by(direction(col))
Expand Down
40 changes: 25 additions & 15 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import has_table_query, ParsedQuery, Table
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
from superset.tables.models import Table as NewTable

if TYPE_CHECKING:
Expand Down Expand Up @@ -136,29 +136,39 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
return cols


def validate_adhoc_subquery(raw_sql: str) -> None:
def validate_adhoc_subquery(
sql: str,
database_id: int,
default_schema: str,
) -> str:
"""
Check if adhoc SQL contains sub-queries or nested sub-queries with table
:param raw_sql: adhoc sql expression
Check if adhoc SQL contains sub-queries or nested sub-queries with table.
If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS
predicates to it.
:param sql: adhoc sql expression
:raise SupersetSecurityException if sql contains sub-queries or
nested sub-queries with table
"""
# pylint: disable=import-outside-toplevel
from superset import is_feature_enabled

if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
return

for statement in sqlparse.parse(raw_sql):
statements = []
for statement in sqlparse.parse(sql):
if has_table_query(statement):
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
message=_("Custom SQL fields cannot contain sub-queries."),
level=ErrorLevel.ERROR,
if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
message=_("Custom SQL fields cannot contain sub-queries."),
level=ErrorLevel.ERROR,
)
)
)
return
statement = insert_rls(statement, database_id, default_schema)
statements.append(statement)

return ";\n".join(str(statement) for statement in statements)


def load_or_create_tables( # pylint: disable=too-many-arguments
Expand Down
84 changes: 56 additions & 28 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Tuple
from typing import cast, List, Optional, Set, Tuple
from urllib import parse

import sqlparse
from sqlalchemy import and_
from sqlparse.sql import (
Identifier,
IdentifierList,
Expand Down Expand Up @@ -283,7 +284,7 @@ def get_statements(self) -> List[str]:
return statements

@staticmethod
def _get_table(tlist: TokenList) -> Optional[Table]:
def get_table(tlist: TokenList) -> Optional[Table]:
"""
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
Expand Down Expand Up @@ -324,7 +325,7 @@ def _process_tokenlist(self, token_list: TokenList) -> None:
"""
# exclude subselects
if "(" not in str(token_list):
table = self._get_table(token_list)
table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return
Expand Down Expand Up @@ -500,7 +501,7 @@ def has_table_query(token_list: TokenList) -> bool:
state = InsertRLSState.SCANNING
for token in token_list.tokens:

# # Recurse into child token list
# Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True

Expand All @@ -523,7 +524,7 @@ def has_table_query(token_list: TokenList) -> bool:

def add_table_name(rls: TokenList, table: str) -> None:
"""
Modify a RLS expression ensuring columns are fully qualified.
Modify a RLS expression inplace ensuring columns are fully qualified.
"""
tokens = rls.tokens[:]
while tokens:
Expand All @@ -539,45 +540,70 @@ def add_table_name(rls: TokenList, table: str) -> None:
tokens.extend(token.tokens)


def matches_table_name(candidate: Token, table: str) -> bool:
def get_rls_for_table(
candidate: Token,
database_id: int,
default_schema: Optional[str],
) -> Optional[TokenList]:
"""
Returns if the token represents a reference to the table.
Tables can be fully qualified with periods.
Note that in theory a table should be represented as an identifier, but due to
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
classified as a keyword.
Given a table name, return any associated RLS predicates.
"""
# pylint: disable=import-outside-toplevel
from superset import db
from superset.connectors.sqla.models import SqlaTable

if not isinstance(candidate, Identifier):
candidate = Identifier([Token(Name, candidate.value)])

target = sqlparse.parse(table)[0].tokens[0]
if not isinstance(target, Identifier):
target = Identifier([Token(Name, target.value)])
table = ParsedQuery.get_table(candidate)
if not table:
return None

# match from right to left, splitting on the period, eg, schema.table == table
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
if left.value != right.value:
return False
dataset = (
db.session.query(SqlaTable)
.filter(
and_(
SqlaTable.database_id == database_id,
SqlaTable.schema == (table.schema or default_schema),
SqlaTable.table_name == table.table,
)
)
.one_or_none()
)
if not dataset:
return None

template_processor = dataset.get_template_processor()
# pylint: disable=protected-access
predicate = " AND ".join(
str(filter_)
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
)
if not predicate:
return None

rls = sqlparse.parse(predicate)[0]
add_table_name(rls, str(dataset))

return True
return rls


def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
def insert_rls(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
) -> TokenList:
"""
Update a statement inplace applying an RLS associated with a given table.
Update a statement inplace applying any associated RLS predicates.
"""
# make sure the identifier has the table name
add_table_name(rls, table)

rls: Optional[TokenList] = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:

# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls(token, table, rls)
token_list.tokens[i] = insert_rls(token, database_id, default_schema)

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
Expand All @@ -587,12 +613,14 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
if matches_table_name(token, table):
rls = get_rls_for_table(token, database_id, default_schema)
if rls:
state = InsertRLSState.FOUND_TABLE

# Found WHERE clause, insert RLS. Note that we insert it even it already exists,
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
rls = cast(TokenList, rls)
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
token.tokens.extend(
[
Expand Down
Loading

0 comments on commit 521ee6d

Please sign in to comment.