Skip to content

Commit

Permalink
feat: add custom expressions
Browse files Browse the repository at this point in the history
Co-authored-by: Ghaith Kdimati <[email protected]>
  • Loading branch information
2 people authored and baraka95 committed Sep 27, 2024
1 parent 5bf4ed3 commit 94b076e
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 74 deletions.
45 changes: 43 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ class Item(Base):

__rls_policies__ = [
Permissive(
condition_args={
condition_args=[{
"comparator_name": "user_id",
"comparator_source": ComparatorSource.bearerTokenPayload,
"operation": Operation.equality,
"type": ExpressionTypes.integer,
"column_name": "owner_id",
},
}],
cmd=[Command.all],
)
]
Expand Down Expand Up @@ -92,6 +92,47 @@ async def get_users(db: AsyncSession = Session):

```

#### Custom expression

The user gives us a parametrized expression and array of conidition_args

```python
__rls_policies__ = [
Permissive(
condition_args=[
{
"comparator_name": "sub",
"comparator_source": ComparatorSource.bearerTokenPayload,
"operation": Operation.equality,
"type": ExpressionTypes.integer,
"column_name": "owner_id",
},
{
"comparator_name": "title",
"comparator_source": ComparatorSource.bearerTokenPayload,
"operation": Operation.equality,
"type": ExpressionTypes.text,
"column_name": "title",
},
{
"comparator_name": "description",
"comparator_source": ComparatorSource.bearerTokenPayload,
"operation": Operation.equality,
"type": ExpressionTypes.text,
"column_name": "description",
},
],
cmd=[Command.all],
expr= "{0} AND ({1} OR {2})",
)
]
```

you can pass multiple expressions and in the `expr` field specify their joining conditions.
#### `Note`:
- the valid logical joining operators as of now are `AND` and `OR`
- if no `expr` is given only the first policy in the array of `condition_args` is taken

#### Run Tests

No e2e tests done yet
Expand Down
106 changes: 60 additions & 46 deletions rls/database.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, AsyncEngine
from sqlalchemy import text,Result
from sqlalchemy import text, Result
from sqlalchemy.sql import Executable
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.engine import Engine

import once
import logging
import base64
import json
from typing import Optional, List, Union
from typing import Optional, List

from fastapi import Request

from rls.schemas import Policy, ComparatorSource
from rls.schemas import Policy, ComparatorSource, ExpressionTypes

# --------------------------------------Set RLS variables-----------------------------#

Expand Down Expand Up @@ -78,7 +77,7 @@ def _get_nested_value(dictionary: dict, keys: list[str]):
return value


def _get_set_statements(req: Request) -> Optional[TextClause]:
def _get_set_statements(req: Request) -> Optional[List[Executable]]:
# must be setup for the rls_policies key to be set
queries = []

Expand All @@ -91,42 +90,53 @@ def _get_set_statements(req: Request) -> Optional[TextClause]:
policies: list[Policy] = mapper.class_.__rls_policies__

for policy in policies:
db_var_name = policy.get_db_var_name(table_name)

comparator_name = policy.condition_args["comparator_name"]
comparator_name_split = comparator_name.split(".")

comparator_value = None
if (
policy.condition_args["comparator_source"]
== ComparatorSource.bearerTokenPayload
):
comparator_value = _get_nested_value(
_parse_bearer_token(req.headers.get("Authorization")),
comparator_name_split,
)
elif (
policy.condition_args["comparator_source"]
== ComparatorSource.requestUser
):
comparator_value = _get_nested_value(req.user, comparator_name_split)
elif policy.condition_args["comparator_source"] == ComparatorSource.header:
comparator_value = _get_nested_value(req.headers, comparator_name_split)
else:
raise ValueError("Invalid Comparator Source")

if comparator_value is None:
continue

temp_query = f"SET LOCAL {db_var_name} = {comparator_value};"

queries.append(temp_query)
for idx in range(len(policy.condition_args)):
db_var_name = policy.get_db_var_name(table_name, idx)

comparator_name = policy.condition_args[idx]["comparator_name"]
comparator_name_split = comparator_name.split(".")

comparator_value = None
if (
policy.condition_args[idx]["comparator_source"]
== ComparatorSource.bearerTokenPayload
):
comparator_value = _get_nested_value(
_parse_bearer_token(req.headers.get("Authorization")),
comparator_name_split,
)
elif (
policy.condition_args[idx]["comparator_source"]
== ComparatorSource.requestUser
):
comparator_value = _get_nested_value(
req.user, comparator_name_split
)
elif (
policy.condition_args["comparator_source"]
== ComparatorSource.header
):
comparator_value = _get_nested_value(
req.headers, comparator_name_split
)
else:
raise ValueError("Invalid Comparator Source")

if comparator_value is None:
continue

type = policy.condition_args[idx]["type"]
if type != ExpressionTypes.integer:
comparator_value = f"'{comparator_value}'"

temp_query = text(f"SET LOCAL {db_var_name} = {comparator_value};")

queries.append(temp_query)

if len(queries) == 0:
return None
combined_query = text("\n".join(queries)) # Combine queries into a single string

return combined_query
return queries


# --------------------------------------RlsBase Initialization-----------------------------#
Expand Down Expand Up @@ -182,7 +192,8 @@ def get_sync_session(request: Request):
with SyncSessionLocal() as session:
stmts = _get_set_statements(request)
if stmts is not None:
session.execute(stmts)
for stmt in stmts:
session.execute(stmt)
yield session


Expand All @@ -191,23 +202,26 @@ async def get_async_session(request: Request):
async with AsyncSessionLocal() as session:
stmts = _get_set_statements(request)
if stmts is not None:
await session.execute(stmts)
for stmt in stmts:
await session.execute(stmt)
yield session


async def bypass_rls_async(session:AsyncSession,stmts:List[Executable])->List[Result]:
results=[]
async def bypass_rls_async(
session: AsyncSession, stmts: List[Executable]
) -> List[Result]:
results = []
await session.execute(text("SET row_security=off;"))
for stmt in stmts:
results.append(await session.execute(stmt))
await session.execute(text("SET row_security=on;"))
return results
def bypass_rls_sync(session:Session,stmts:List[Executable])->List[Result]:
results=[]


def bypass_rls_sync(session: Session, stmts: List[Executable]) -> List[Result]:
results = []
session.execute(text("SET row_security=off;"))
for stmt in stmts:
results.append(session.execute(stmt))
session.execute(text("SET row_security=on;"))
return results
return results
61 changes: 48 additions & 13 deletions rls/schemas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from enum import Enum
from typing import List, Literal, Union, TypedDict
from typing import List, Literal, Union, TypedDict, Optional

from pydantic import BaseModel
from sqlalchemy import text

import re


class Command(str, Enum):
# policies: https://www.postgresql.org/docs/current/sql-createpolicy.html
Expand All @@ -14,6 +16,11 @@ class Command(str, Enum):
delete = "DELETE"


class LogicalOperator(str, Enum):
AND = "AND"
OR = "OR"


class Operation(str, Enum):
equality = "EQUALITY"

Expand Down Expand Up @@ -41,30 +48,58 @@ class ConditionArgs(TypedDict):

class Policy(BaseModel):
definition: str
condition_args: ConditionArgs
condition_args: List[ConditionArgs]
cmd: Union[Command, List[Command]]
expr: Optional[str]

def get_db_var_name(self, table_name):
return f"rls.{table_name}_{self.condition_args['column_name']}"
def get_db_var_name(self, table_name: str, idx: int = 0):
return f"rls.{table_name}_{self.condition_args[idx]['column_name']}"

def _get_expr_from_params(self, table_name: str):
variable_name = f"NULLIF(current_setting('{self.get_db_var_name(table_name)}', true),'')::{self.condition_args['type'].value}"
def _get_expr_from_params(self, table_name: str, idx: int = 0):
variable_name = f"NULLIF(current_setting('{self.get_db_var_name(table_name=table_name, idx=idx)}', true),'')::{self.condition_args[idx]['type'].value}"

expr = None
if self.condition_args["operation"] == "EQUALITY":
expr = f"{self.condition_args['column_name']} = {variable_name}"
if self.condition_args[idx]["operation"] == "EQUALITY":
expr = f"{self.condition_args[idx]['column_name']} = {variable_name}"

if expr is None:
raise ValueError(f"Unknown operation: {self.condition_args['operation']}")
raise ValueError(
f"Unknown operation: {self.condition_args[idx]['operation']}"
)

return expr

def _get_expr_from_custom_expr(self, table_name: str):
for idx in range(len(self.condition_args)):
pattern = rf"\{{{idx}\}}" # Escaped curly braces
parsed_expr = self._get_expr_from_params(table_name, idx)
self.expr = re.sub(pattern, parsed_expr, self.expr)
return self.expr

def _validate_joining_operations_in_expr(self):
# Pattern to match a number in curly braces followed by "AND" or "OR"
whole_pattern = r"\{(\d+)\}\s*(AND|OR)"

# Find all matches of the pattern in the expression
matches = re.findall(whole_pattern, self.expr)

# Extract the second group (AND/OR) from each match
operators = [match[1] for match in matches]

for operator in operators:
if operator not in LogicalOperator.__members__.values():
raise ValueError(f"Invalid logical operator: {operator}")

def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
commands = [self.cmd] if isinstance(self.cmd, str) else self.cmd
expr = self._get_expr_from_params(table_name)

if self.expr is not None:
self.expr = self._get_expr_from_custom_expr(table_name)
else:
self.expr = self._get_expr_from_params(table_name)
policy_lists = []
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
print(expr)
print(self.expr)
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")

for cmd in commands:
Expand All @@ -81,7 +116,7 @@ def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
CREATE POLICY {policy_name} ON {table_name}
AS {self.definition}
FOR {cmd_value}
USING ({expr})
USING ({self.expr})
"""
)
)
Expand All @@ -92,7 +127,7 @@ def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
CREATE POLICY {policy_name} ON {table_name}
AS {self.definition}
FOR {cmd_value}
WITH CHECK ({expr})
WITH CHECK ({self.expr})
"""
)
)
Expand Down
6 changes: 3 additions & 3 deletions test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@


from rls.database import get_session, bypass_rls_async
from models import Item
from .models import Item

from engines import async_engine as db_engine
from .engines import async_engine as db_engine


app = FastAPI()
Expand All @@ -26,6 +26,6 @@ async def get_users(db: AsyncSession = Session):
@app.get("/admin/items")
async def get_items(db: AsyncSession = Session):
stmt = select(Item)
results=await bypass_rls_async(db,[stmt])
results = await bypass_rls_async(db, [stmt])
items = results[0].scalars().all()
return items
Loading

0 comments on commit 94b076e

Please sign in to comment.