Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Ghaithq committed Oct 6, 2024
2 parents 54229cb + fbffce6 commit 29a4e1d
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 147 deletions.
6 changes: 6 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: github-actions
directory: "/"
schedule:
interval: monthly
6 changes: 3 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ jobs:
pull-requests: write
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.11

Expand All @@ -32,7 +32,7 @@ jobs:
poetry install
- name: Python Semantic Release
uses: python-semantic-release/python-semantic-release@v9.8.8
uses: python-semantic-release/python-semantic-release@v9.9.0
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
push: true
Expand Down
50 changes: 27 additions & 23 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,39 @@ authors = [
]
license = "MIT"
readme = "README.md"
packages = [
{include = "rls"},
{include = "rls/py.typed"},
]

[tool.poetry.dependencies]
python = ">=3.11"
annotated-types = "0.7.0"
anyio = "4.4.0"
greenlet = "3.1.0"
idna = "3.10"
psycopg2 = "2.9.9"
pydantic = "2.9.1"
pydantic-core = "2.23.3"
sniffio = "1.3.1"
sqlalchemy = "2.0.34"
typing-extensions = "4.12.2"
asyncpg = "^0.29.0"
alembic = "^1.13.3"
annotated-types = ">=0.7.0"
anyio = ">=4.4.0"
greenlet = ">=3.1.0"
idna = ">=3.10"
psycopg2 = ">=2.9.9"
pydantic = ">=2.9.1"
pydantic-core = ">=2.23.3"
sniffio = ">=1.3.1"
sqlalchemy = ">=2.0.34"
typing-extensions = ">=4.12.2"
asyncpg = ">=0.29.0"
alembic = ">=1.13.3"


[tool.poetry.group.dev.dependencies]
mypy = "^1.11.2"
pre-commit = "^3.8.0"
uvicorn = "^0.30.6"
httpx = "^0.27.2"
pytest = "^8.3.3"
deptry = "^0.20.0"
pytest-asyncio = "^0.24.0"
pytest-xdist = "^3.6.1"
requests = "^2.32.3"
alembic = "^1.13.3"
fastapi = "^0.115.0"
mypy = ">=1.11.2"
pre-commit = ">=3.8.0"
uvicorn = ">=0.30.6"
httpx = ">=0.27.2"
pytest = ">=8.3.3"
deptry = ">=0.20.0"
pytest-asyncio = ">=0.24.0"
pytest-xdist = ">=3.6.1"
requests = ">=2.32.3"
alembic = ">=1.13.3"
fastapi = ">=0.115.0"



Expand Down
9 changes: 0 additions & 9 deletions rls/principals.py

This file was deleted.

Empty file added rls/py.typed
Empty file.
29 changes: 3 additions & 26 deletions rls/rls_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ def bypass_rls(self):
"""
return self.BypassRLSContext(self)

def setContext(self, context: BaseModel):
self.context = context

def _get_set_statements(self):
"""
Generates SQL SET statements based on the context model.
Expand All @@ -39,11 +36,8 @@ def _execute_set_statements(self):
Executes the RLS SET statements unless bypassing RLS.
"""
if self._rls_bypass: # Skip setting RLS when bypassing
print("Bypassing RLS")
return
print("Setting RLS")
stmts = self._get_set_statements()
print("stmts:", stmts)
if stmts is not None:
for stmt in stmts:
super().execute(stmt)
Expand Down Expand Up @@ -73,16 +67,8 @@ def __enter__(self):
If the command fails, rollback the transaction.
"""
self.session._rls_bypass = True
try:
# Disable row-level security
self.session.execute(text("SET LOCAL rls.bypass_rls = true;"))
except Exception as e:
print(f"Failed to disable row-level security: {e}")
self.session._rls_bypass = False # Disable bypass flag to avoid issues

# Rollback transaction to avoid failed state
self.session.rollback()
raise # Re-raise the exception to stop further execution
# Disable row-level security
self.session.execute(text("SET LOCAL rls.bypass_rls = true;"))
return self.session

def __exit__(self, exc_type, exc_val, exc_tb):
Expand All @@ -93,18 +79,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):

# If the transaction failed, skip re-enabling RLS
if exc_type is not None:
print(f"Skipping re-enabling RLS due to prior error: {exc_val}")
self.session.rollback()
return

try:
# Re-enable row-level security
self.session.execute(text("SET LOCAL rls.bypass_rls = false;"))
except Exception as e:
print(f"Failed to re-enable row-level security: {e}")

# Optionally rollback if there's a failure
self.session.rollback()
self.session.execute(text("SET LOCAL rls.bypass_rls = false;"))

def execute(self, *args, **kwargs):
return self.session.execute(*args, **kwargs)
2 changes: 0 additions & 2 deletions rls/rls_sessioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def __init__(self, sessionmaker: SessionMaker, context_getter: ContextGetter):
def __call__(
self, *args: Optional[Any], **kwargs: Optional[Any]
): # Get context from the context getter
print("Kwargs:", kwargs)
print("Args:", args)
context = self.context_getter.get_context(*args, **kwargs)
session = self.session_maker(context=context)
try:
Expand Down
5 changes: 0 additions & 5 deletions rls/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ def generate_rls_policy(cmd, definition, policy_name, table_name, expr) -> TextC
"CAST(NULLIF(current_setting('rls.bypass_rls', true), '') AS BOOLEAN) = true"
)
expr = f"(({expr}) OR {bypass_rls_expr})"

print("%%%%%%%%%%%%%%%%%%%%%%%%%")
print("expr", expr)
print("%%%%%%%%%%%%%%%%%%%%%%%%%")

if cmd in ["ALL", "SELECT", "UPDATE", "DELETE"]:
return text(f"""
CREATE POLICY {policy_name} ON {table_name}
Expand Down
15 changes: 0 additions & 15 deletions test/...py

This file was deleted.

40 changes: 0 additions & 40 deletions test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,3 @@
async def get_users(db=my_session):
result = db.execute(text("SELECT * FROM users")).all()
return dict(result)


# @app.get("/users/items/1")
# async def get_items_1(db: AsyncSession = Session):
# stmt = select(Item)
# result = await db.execute(stmt)
# items = result.scalars().all()

# return items


# @app.get("/users/items/2")
# async def get_items_2(db: AsyncSession = Session):
# stmt = select(Item1)
# result = db.execute(stmt)
# items = result.scalars().all()

# print("**************************************")
# print(items)
# print("**************************************")


# return items


# @app.get("/users/items/3")
# async def get_items_3(db: AsyncSession = Session):
# stmt = select(Item2)
# result = db.execute(stmt)
# items = result.scalars().all()

# return items


# @app.get("/admin/items")
# async def get_items(db: AsyncSession = Session):
# stmt = select(Item)
# results = bypass_rls_async(db, [stmt])
# items = results[0].scalars().all()
# return items
21 changes: 0 additions & 21 deletions test/setup.py

This file was deleted.

3 changes: 0 additions & 3 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,3 @@ class MyContext(BaseModel):
with session.bypass_rls() as session:
res2 = session.execute(text("SELECT * FROM items")).fetchall()
print("res2:", res2)


# TODO: in init must create a bypass_rls_role that is super or has bypass rls privilege amongst with most others

0 comments on commit 29a4e1d

Please sign in to comment.