Skip to content

Commit

Permalink
refactor: removed unnecessary code
Browse files Browse the repository at this point in the history
Co-authored-by: Omar Fahmy <[email protected]>
  • Loading branch information
Ghaithq and Ofahmy143 committed Oct 6, 2024
1 parent 46c24c2 commit d4a3306
Show file tree
Hide file tree
Showing 8 changed files with 2 additions and 108 deletions.
9 changes: 0 additions & 9 deletions rls/principals.py

This file was deleted.

15 changes: 2 additions & 13 deletions rls/rls_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ def bypass_rls(self):
"""
return self.BypassRLSContext(self)

def setContext(self, context: BaseModel):
self.context = context

def _get_set_statements(self):
"""
Generates SQL SET statements based on the context model.
Expand All @@ -39,11 +36,8 @@ def _execute_set_statements(self):
Executes the RLS SET statements unless bypassing RLS.
"""
if self._rls_bypass: # Skip setting RLS when bypassing
print("Bypassing RLS")
return
print("Setting RLS")
stmts = self._get_set_statements()
print("stmts:", stmts)
if stmts is not None:
for stmt in stmts:
super().execute(stmt)
Expand Down Expand Up @@ -76,8 +70,7 @@ def __enter__(self):
try:
# Disable row-level security
self.session.execute(text("SET LOCAL rls.bypass_rls = true;"))
except Exception as e:
print(f"Failed to disable row-level security: {e}")
except Exception:
self.session._rls_bypass = False # Disable bypass flag to avoid issues

# Rollback transaction to avoid failed state
Expand All @@ -93,16 +86,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):

# If the transaction failed, skip re-enabling RLS
if exc_type is not None:
print(f"Skipping re-enabling RLS due to prior error: {exc_val}")
self.session.rollback()
return

try:
# Re-enable row-level security
self.session.execute(text("SET LOCAL rls.bypass_rls = false;"))
except Exception as e:
print(f"Failed to re-enable row-level security: {e}")

except Exception:
# Optionally rollback if there's a failure
self.session.rollback()

Expand Down
2 changes: 0 additions & 2 deletions rls/rls_sessioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def __init__(self, sessionmaker: SessionMaker, context_getter: ContextGetter):
def __call__(
self, *args: Optional[Any], **kwargs: Optional[Any]
): # Get context from the context getter
print("Kwargs:", kwargs)
print("Args:", args)
context = self.context_getter.get_context(*args, **kwargs)
session = self.session_maker(context=context)
try:
Expand Down
5 changes: 0 additions & 5 deletions rls/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ def generate_rls_policy(cmd, definition, policy_name, table_name, expr) -> TextC
"CAST(NULLIF(current_setting('rls.bypass_rls', true), '') AS BOOLEAN) = true"
)
expr = f"(({expr}) OR {bypass_rls_expr})"

print("%%%%%%%%%%%%%%%%%%%%%%%%%")
print("expr", expr)
print("%%%%%%%%%%%%%%%%%%%%%%%%%")

if cmd in ["ALL", "SELECT", "UPDATE", "DELETE"]:
return text(f"""
CREATE POLICY {policy_name} ON {table_name}
Expand Down
15 changes: 0 additions & 15 deletions test/...py

This file was deleted.

40 changes: 0 additions & 40 deletions test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,3 @@
async def get_users(db=my_session):
result = db.execute(text("SELECT * FROM users")).all()
return dict(result)


# @app.get("/users/items/1")
# async def get_items_1(db: AsyncSession = Session):
# stmt = select(Item)
# result = await db.execute(stmt)
# items = result.scalars().all()

# return items


# @app.get("/users/items/2")
# async def get_items_2(db: AsyncSession = Session):
# stmt = select(Item1)
# result = db.execute(stmt)
# items = result.scalars().all()

# print("**************************************")
# print(items)
# print("**************************************")


# return items


# @app.get("/users/items/3")
# async def get_items_3(db: AsyncSession = Session):
# stmt = select(Item2)
# result = db.execute(stmt)
# items = result.scalars().all()

# return items


# @app.get("/admin/items")
# async def get_items(db: AsyncSession = Session):
# stmt = select(Item)
# results = bypass_rls_async(db, [stmt])
# items = results[0].scalars().all()
# return items
21 changes: 0 additions & 21 deletions test/setup.py

This file was deleted.

3 changes: 0 additions & 3 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,3 @@ class MyContext(BaseModel):
with session.bypass_rls() as session:
res2 = session.execute(text("SELECT * FROM items")).fetchall()
print("res2:", res2)


# TODO: in init must create a bypass_rls_role that is super or has bypass rls privilege amongst with most others

0 comments on commit d4a3306

Please sign in to comment.