Skip to content

Commit

Permalink
feat: helper functions for RLS (#19055)
Browse files Browse the repository at this point in the history
* feat: helper functions for RLS

* Add function to inject RLS

* Add UNION tests

* Add tests for schema

* Add more tests; cleanup

* has_table_query via tree traversal

* Wrap existing predicate in parenthesis

* Clean up logic

* Improve table matching
  • Loading branch information
betodealmeida authored Mar 11, 2022
1 parent c337491 commit 8234395
Show file tree
Hide file tree
Showing 2 changed files with 447 additions and 2 deletions.
202 changes: 202 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
remove_quotes,
Token,
TokenList,
Where,
)
from sqlparse.tokens import (
CTE,
Expand Down Expand Up @@ -458,3 +459,204 @@ def validate_filter_clause(clause: str) -> None:
)
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")


class InsertRLSState(str, Enum):
"""
State machine that scans for WHERE and ON clauses referencing tables.
"""

SCANNING = "SCANNING"
SEEN_SOURCE = "SEEN_SOURCE"
FOUND_TABLE = "FOUND_TABLE"


def has_table_query(token_list: TokenList) -> bool:
"""
Return if a stament has a query reading from a table.
>>> has_table_query(sqlparse.parse("COUNT(*)")[0])
False
>>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
True
Note that queries reading from constant values return false:
>>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
False
"""
state = InsertRLSState.SCANNING
for token in token_list.tokens:

# # Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE

# Found identifier/keyword after FROM/JOIN
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
):
return True

# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING

return False


def add_table_name(rls: TokenList, table: str) -> None:
"""
Modify a RLS expression ensuring columns are fully qualified.
"""
tokens = rls.tokens[:]
while tokens:
token = tokens.pop(0)

if isinstance(token, Identifier) and token.get_parent_name() is None:
token.tokens = [
Token(Name, table),
Token(Punctuation, "."),
Token(Name, token.get_name()),
]
elif isinstance(token, TokenList):
tokens.extend(token.tokens)


def matches_table_name(candidate: Token, table: str) -> bool:
"""
Returns if the token represents a reference to the table.
Tables can be fully qualified with periods.
Note that in theory a table should be represented as an identifier, but due to
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
classified as a keyword.
"""
if not isinstance(candidate, Identifier):
candidate = Identifier([Token(Name, candidate.value)])

target = sqlparse.parse(table)[0].tokens[0]
if not isinstance(target, Identifier):
target = Identifier([Token(Name, target.value)])

# match from right to left, splitting on the period, eg, schema.table == table
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
if left.value != right.value:
return False

return True


def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
"""
Update a statement inplace applying an RLS associated with a given table.
"""
# make sure the identifier has the table name
add_table_name(rls, table)

state = InsertRLSState.SCANNING
for token in token_list.tokens:

# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls(token, table, rls)

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE

# Found identifier/keyword after FROM/JOIN, test for table
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
if matches_table_name(token, table):
state = InsertRLSState.FOUND_TABLE

# Found WHERE clause, insert RLS. Note that we insert it even it already exists,
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
token.tokens.extend(
[
Token(Punctuation, ")"),
Token(Whitespace, " "),
Token(Keyword, "AND"),
Token(Whitespace, " "),
]
+ rls.tokens
)
state = InsertRLSState.SCANNING

# Found ON clause, insert RLS. The logic for ON is more complicated than the logic
# for WHERE because in the former the comparisons are siblings, while on the
# latter they are children.
elif (
state == InsertRLSState.FOUND_TABLE
and token.ttype == Keyword
and token.value.upper() == "ON"
):
tokens = [
Token(Whitespace, " "),
rls,
Token(Whitespace, " "),
Token(Keyword, "AND"),
Token(Whitespace, " "),
Token(Punctuation, "("),
]
i = token_list.tokens.index(token)
token.parent.tokens[i + 1 : i + 1] = tokens
i += len(tokens) + 2

# close parenthesis after last existing comparison
j = 0
for j, sibling in enumerate(token_list.tokens[i:]):
# scan until we hit a non-comparison keyword (like ORDER BY) or a WHERE
if (
sibling.ttype == Keyword
and not imt(
sibling, m=[(Keyword, "AND"), (Keyword, "OR"), (Keyword, "NOT")]
)
or isinstance(sibling, Where)
):
j -= 1
break
token.parent.tokens[i + j + 1 : i + j + 1] = [
Token(Whitespace, " "),
Token(Punctuation, ")"),
Token(Whitespace, " "),
]

state = InsertRLSState.SCANNING

# Found table but no WHERE clause found, insert one
elif state == InsertRLSState.FOUND_TABLE and token.ttype != Whitespace:
i = token_list.tokens.index(token)
token_list.tokens[i:i] = [
Token(Whitespace, " "),
Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]),
Token(Whitespace, " "),
]

state = InsertRLSState.SCANNING

# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING

# found table at the end of the statement; append a WHERE clause
if state == InsertRLSState.FOUND_TABLE:
token_list.tokens.extend(
[
Token(Whitespace, " "),
Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]),
]
)

return token_list
Loading

0 comments on commit 8234395

Please sign in to comment.