Skip to content

Commit

Permalink
feat: add bypassing rls
Browse files Browse the repository at this point in the history
Co-authored-by: Omar Fahmy <[email protected]>
  • Loading branch information
2 people authored and baraka95 committed Sep 27, 2024
1 parent e8240b6 commit 5bf4ed3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
23 changes: 21 additions & 2 deletions rls/database.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, AsyncEngine
from sqlalchemy import text
from sqlalchemy import text,Result
from sqlalchemy.sql import Executable
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.engine import Engine

import once
import logging
import base64
import json
from typing import Optional
from typing import Optional, List, Union

from fastapi import Request

Expand Down Expand Up @@ -192,3 +193,21 @@ async def get_async_session(request: Request):
if stmts is not None:
await session.execute(stmts)
yield session


async def bypass_rls_async(session:AsyncSession,stmts:List[Executable])->List[Result]:
results=[]
await session.execute(text("SET row_security=off;"))
for stmt in stmts:
results.append(await session.execute(stmt))
await session.execute(text("SET row_security=on;"))
return results


def bypass_rls_sync(session:Session,stmts:List[Executable])->List[Result]:
results=[]
session.execute(text("SET row_security=off;"))
for stmt in stmts:
results.append(session.execute(stmt))
session.execute(text("SET row_security=on;"))
return results
14 changes: 11 additions & 3 deletions test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from sqlalchemy.future import select


from rls.database import get_session
from test.models import Item
from rls.database import get_session, bypass_rls_async
from models import Item

from .engines import async_engine as db_engine
from engines import async_engine as db_engine


app = FastAPI()
Expand All @@ -21,3 +21,11 @@ async def get_users(db: AsyncSession = Session):
items = result.scalars().all()

return items


@app.get("/admin/items")
async def get_items(db: AsyncSession = Session):
stmt = select(Item)
results=await bypass_rls_async(db,[stmt])
items = results[0].scalars().all()
return items

0 comments on commit 5bf4ed3

Please sign in to comment.