Skip to content

Commit

Permalink
feat: update rls api
Browse files Browse the repository at this point in the history
- change the way expressions are passed in the policy
- set the session variables automatically from header variables

Co-authored-by: Ghaith Kdimati <[email protected]>
  • Loading branch information
Ofahmy143 and Ghaithq committed Sep 18, 2024
1 parent bf0cc9f commit ffd3ef6
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 61 deletions.
99 changes: 99 additions & 0 deletions rls/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from rls.schemas import Policy
from fastapi import Request
from fastapi.datastructures import Headers
from sqlalchemy import text, create_engine

#--------------------------------------Set RLS variables-----------------------------#

def _get_set_statements(headers: Headers) -> str:
# must be setup for the rls_policies key to be set
queries = []

Base.metadata.info.setdefault("rls_policies", dict())
for mapper in Base.registry.mappers:
if not hasattr(mapper.class_, "__rls_policies__"):
continue
table_name = mapper.tables[0].fullname
policies: list[Policy] = mapper.class_.__rls_policies__

for policy in policies:
db_var_name = policy.get_db_var_name(table_name)
comparator_name = policy.condition_args["comparator_name"]
comparator_value = headers.get(comparator_name)

if(comparator_value is None):
continue

temp_query = f"SET LOCAL {db_var_name} = {comparator_value};"

queries.append(temp_query)


if(len(queries) == 0):
return ""
combined_query = text('\n'.join(queries)) # Combine queries into a single string

return combined_query



#--------------------------------------Base Initialization-----------------------------#


# Base class for our models
Base = declarative_base()


#--------------------------------------Engines Initialization-----------------------------#


ASYNC_DATABASE_URL = "postgresql+asyncpg://my_user:secure_password@localhost/session"

SYNC_DATABASE_URL = "postgresql://my_user:secure_password@localhost/session"


# TODO: DATABASE_URL should be an environment variable be it SYNC or ASYNC
# SQLAlchemy engine
async_engine = create_async_engine(ASYNC_DATABASE_URL)

# SQLAlchemy session
AsyncSessionLocal = sessionmaker(
bind=async_engine,
class_=AsyncSession, # Specify that the session should be async
expire_on_commit=False,
)


sync_engine = create_engine(SYNC_DATABASE_URL)

SyncSessionLocal = sessionmaker(
bind=sync_engine,
class_= Session,
expire_on_commit=False,
)



#--------------------------------------Deps injection functions-----------------------------#


def get_sync_session(request: Request):
with SyncSessionLocal() as session:
stmts = _get_set_statements(request.headers)
if(stmts != ""):
session.execute(stmts)
yield session



# Dependency to get DB session in routes
async def get_async_session(request: Request):
async with AsyncSessionLocal() as session:
stmts = _get_set_statements(request.headers)
if(stmts != ""):
await session.execute(stmts)
yield session
3 changes: 2 additions & 1 deletion rls/register_rls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
from .create_policies import create_policies



def set_metadata_info(Base: Type[DeclarativeMeta]):
"""RLS policies are first added to the Metadata before applied."""
Base.metadata.info.setdefault("rls_policies", dict())
for mapper in Base.registry.mappers:
if not hasattr(mapper.class_, "__rls_policies__"):
continue

Base.metadata.info["rls_policies"][
mapper.tables[0].fullname
] = mapper.class_.__rls_policies__




def register_rls(Base: Type[DeclarativeMeta]):

# required for `alembic revision --autogenerate``
Expand Down
54 changes: 42 additions & 12 deletions rls/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Literal, Union
from typing import List, Literal, Union, TypedDict

from pydantic import BaseModel
from sqlalchemy import text
Expand All @@ -13,24 +13,54 @@ class Command(str, Enum):
update = "UPDATE"
delete = "DELETE"

class Operation(str, Enum):
equality = "EQUALITY"

class ExpressionTypes(str, Enum):
integer = "INTEGER"
uuid = "UUID"
text = "TEXT"
boolean = "BOOLEAN"


class ConditionArgs(TypedDict):
comparator_name: str
operation: Operation
type: ExpressionTypes
column_name: str


class Policy(BaseModel):
definition: str
expr: str
condition_args: ConditionArgs
cmd: Union[Command, List[Command]]

def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
commands = [self.cmd] if isinstance(self.cmd, str) else self.cmd

print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
print('definition', self.definition)
print('expr', self.expr)
print('commands', (commands[0].value))
print('table_name', table_name)
print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
def get_db_var_name(self, table_name):
return f"rls.{table_name}_{self.condition_args['column_name']}"

def _get_expr_from_params(self, table_name: str):

variable_name = f"NULLIF(current_setting('{self.get_db_var_name(table_name)}', true),'')::{self.condition_args['type'].value}"

expr = None
if self.condition_args["operation"] == "EQUALITY":
expr = f"{self.condition_args['column_name']} = {variable_name}"

if expr is None:
raise ValueError(f"Unknown operation: {self.condition_args['operation']}")

return expr


def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
commands = [self.cmd] if isinstance(self.cmd, str) else self.cmd
expr = self._get_expr_from_params(table_name)
policy_lists = []
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
print(expr)
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")

for cmd in commands:
cmd_value = cmd.value if isinstance(cmd, Command) else cmd
policy_name = (
Expand All @@ -45,7 +75,7 @@ def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
CREATE POLICY {policy_name} ON {table_name}
AS {self.definition}
FOR {cmd_value}
USING ({self.expr})
USING ({expr})
"""
)
)
Expand All @@ -56,7 +86,7 @@ def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
CREATE POLICY {policy_name} ON {table_name}
AS {self.definition}
FOR {cmd_value}
WITH CHECK ({self.expr})
WITH CHECK ({expr})
"""
)
)
Expand Down
26 changes: 0 additions & 26 deletions test/database.py

This file was deleted.

24 changes: 5 additions & 19 deletions test/main.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,19 @@
from fastapi import FastAPI, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import text

from database import get_db
from rls.database import get_async_session
from models import Item


app = FastAPI()

@app.get("/users/{user_id}/items")
async def get_users(user_id, db: AsyncSession = Depends(get_db)):
async def get_users(user_id, db: AsyncSession = Depends(get_async_session)):

items = None
async with db as session:
print('session', session)
stmt = text(f"SET LOCAL app.current_user_id = {user_id}")
print('stmt', stmt)
await db.execute(stmt)

# Retrieve the parameter value
get_param_stmt = text("SELECT current_setting('app.current_user_id')")
result = await db.execute(get_param_stmt)
current_user_id = result.scalar() # Get the single value
print('after: current_user_id', current_user_id)

stmt = select(Item)
result = await db.execute(stmt)
items = result.scalars().all()
stmt = select(Item)
result = await db.execute(stmt)
items = result.scalars().all()


return items
11 changes: 8 additions & 3 deletions test/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy.ext.declarative import declarative_base
from rls.schemas import Permissive
from rls.register_rls import register_rls
Base = declarative_base()
from rls.database import Base

# Register RLS policies
register_rls(Base)
Expand All @@ -26,5 +26,10 @@ class Item(Base):
owner = relationship("User")

__rls_policies__ = [
Permissive(expr=f"owner_id = current_setting('app.current_user_id')::integer", cmd=["SELECT","INSERT", "UPDATE", "DELETE"])
]
Permissive(condition_args={
"comparator_name": "user_id",
"operation": "EQUALITY",
"type": "INTEGER",
"column_name": "owner_id"
}, cmd=["SELECT","INSERT", "UPDATE", "DELETE"])
]

0 comments on commit ffd3ef6

Please sign in to comment.