-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- change the way expressions are passed in the policy - set the session variables automatically from header variables Co-authored-by: Ghaith Kdimati <[email protected]>
- Loading branch information
Showing
6 changed files
with
156 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from sqlalchemy.orm import sessionmaker, Session | ||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine | ||
from sqlalchemy.orm import sessionmaker | ||
from rls.schemas import Policy | ||
from fastapi import Request | ||
from fastapi.datastructures import Headers | ||
from sqlalchemy import text, create_engine | ||
|
||
#--------------------------------------Set RLS variables-----------------------------# | ||
|
||
def _get_set_statements(headers: Headers) -> str: | ||
# must be setup for the rls_policies key to be set | ||
queries = [] | ||
|
||
Base.metadata.info.setdefault("rls_policies", dict()) | ||
for mapper in Base.registry.mappers: | ||
if not hasattr(mapper.class_, "__rls_policies__"): | ||
continue | ||
table_name = mapper.tables[0].fullname | ||
policies: list[Policy] = mapper.class_.__rls_policies__ | ||
|
||
for policy in policies: | ||
db_var_name = policy.get_db_var_name(table_name) | ||
comparator_name = policy.condition_args["comparator_name"] | ||
comparator_value = headers.get(comparator_name) | ||
|
||
if(comparator_value is None): | ||
continue | ||
|
||
temp_query = f"SET LOCAL {db_var_name} = {comparator_value};" | ||
|
||
queries.append(temp_query) | ||
|
||
|
||
if(len(queries) == 0): | ||
return "" | ||
combined_query = text('\n'.join(queries)) # Combine queries into a single string | ||
|
||
return combined_query | ||
|
||
|
||
|
||
#--------------------------------------Base Initialization-----------------------------# | ||
|
||
|
||
# Base class for our models | ||
Base = declarative_base() | ||
|
||
|
||
#--------------------------------------Engines Initialization-----------------------------# | ||
|
||
|
||
ASYNC_DATABASE_URL = "postgresql+asyncpg://my_user:secure_password@localhost/session" | ||
|
||
SYNC_DATABASE_URL = "postgresql://my_user:secure_password@localhost/session" | ||
|
||
|
||
# TODO: DATABASE_URL should be an environment variable be it SYNC or ASYNC | ||
# SQLAlchemy engine | ||
async_engine = create_async_engine(ASYNC_DATABASE_URL) | ||
|
||
# SQLAlchemy session | ||
AsyncSessionLocal = sessionmaker( | ||
bind=async_engine, | ||
class_=AsyncSession, # Specify that the session should be async | ||
expire_on_commit=False, | ||
) | ||
|
||
|
||
sync_engine = create_engine(SYNC_DATABASE_URL) | ||
|
||
SyncSessionLocal = sessionmaker( | ||
bind=sync_engine, | ||
class_= Session, | ||
expire_on_commit=False, | ||
) | ||
|
||
|
||
|
||
#--------------------------------------Deps injection functions-----------------------------# | ||
|
||
|
||
def get_sync_session(request: Request): | ||
with SyncSessionLocal() as session: | ||
stmts = _get_set_statements(request.headers) | ||
if(stmts != ""): | ||
session.execute(stmts) | ||
yield session | ||
|
||
|
||
|
||
# Dependency to get DB session in routes | ||
async def get_async_session(request: Request): | ||
async with AsyncSessionLocal() as session: | ||
stmts = _get_set_statements(request.headers) | ||
if(stmts != ""): | ||
await session.execute(stmts) | ||
yield session |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,19 @@ | ||
from fastapi import FastAPI, Depends | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
from sqlalchemy.future import select | ||
from sqlalchemy import text | ||
|
||
from database import get_db | ||
from rls.database import get_async_session | ||
from models import Item | ||
|
||
|
||
app = FastAPI() | ||
|
||
@app.get("/users/{user_id}/items") | ||
async def get_users(user_id, db: AsyncSession = Depends(get_db)): | ||
async def get_users(user_id, db: AsyncSession = Depends(get_async_session)): | ||
|
||
items = None | ||
async with db as session: | ||
print('session', session) | ||
stmt = text(f"SET LOCAL app.current_user_id = {user_id}") | ||
print('stmt', stmt) | ||
await db.execute(stmt) | ||
|
||
# Retrieve the parameter value | ||
get_param_stmt = text("SELECT current_setting('app.current_user_id')") | ||
result = await db.execute(get_param_stmt) | ||
current_user_id = result.scalar() # Get the single value | ||
print('after: current_user_id', current_user_id) | ||
|
||
stmt = select(Item) | ||
result = await db.execute(stmt) | ||
items = result.scalars().all() | ||
stmt = select(Item) | ||
result = await db.execute(stmt) | ||
items = result.scalars().all() | ||
|
||
|
||
return items |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters