Skip to content

Commit

Permalink
feat: update bypass rls execution and move alembic files to test
Browse files Browse the repository at this point in the history
Co-authored-by: Ghaith Kdimati <[email protected]>"
  • Loading branch information
Ofahmy143 committed Oct 1, 2024
1 parent f7e02b2 commit 7f3b711
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 104 deletions.
36 changes: 0 additions & 36 deletions alembic/versions/22db913c0bf7_add_items_table.py

This file was deleted.

30 changes: 8 additions & 22 deletions rls/alembic_rls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from alembic.operations import MigrateOperation, Operations
from sqlalchemy import text
from sqlalchemy.ext.declarative import DeclarativeMeta
from .utils import generate_rls_policy


############################
Expand Down Expand Up @@ -272,30 +273,15 @@ def create_policy(operations, operation):
cmd = operation.cmd
expr = operation.expr

sql = ""
# Generate the SQL to create the policy
if cmd in ["ALL", "SELECT", "UPDATE", "DELETE"]:
sql = text(
f"""
CREATE POLICY {policy_name} ON {table_name}
AS {definition}
FOR {cmd}
USING ({expr})
"""
)

elif cmd in ["INSERT"]:
sql = text(
f"""
CREATE POLICY {policy_name} ON {table_name}
AS {definition}
FOR {cmd}
WITH CHECK ({expr})
"""
)

else:
raise ValueError(f'Unknown policy command"{cmd}"')
sql = generate_rls_policy(
cmd=cmd,
definition=definition,
policy_name=policy_name,
table_name=table_name,
expr=expr,
)

operations.execute(sql)

Expand Down
77 changes: 73 additions & 4 deletions rls/rls_session.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,49 @@
from sqlalchemy.orm.session import Session
from sqlalchemy import text
from sqlalchemy.orm import Session
from pydantic import BaseModel
from sqlalchemy import text
from typing import Optional


class RlsSession(Session):
def __init__(self, context: Optional[BaseModel], *args, **kwargs):
super().__init__(*args, **kwargs)
self._rls_bypass = False # Track RLS bypass state
if context is not None:
self.context = context

def bypass_rls(self):
"""
Context manager to bypass RLS.
Usage: with session.bypass_rls() as session:
"""
return self.BypassRLSContext(self)

def setContext(self, context: BaseModel):
self.context = context

def _get_set_statements(self):
"""
Generates SQL SET statements based on the context model.
"""
stmts = []
if self.context is None:
if self.context is None or self._rls_bypass: # Skip RLS statements if bypassed
return None

for key, value in self.context.model_dump().items():
# TODO: ask whether we set the rls. or not
stmt = text(f"SET rls.{key} = {value};")
stmts.append(stmt)
return stmts

def _execute_set_statements(self):
"""
Executes the RLS SET statements unless bypassing RLS.
"""
if self._rls_bypass: # Skip setting RLS when bypassing
print("Bypassing RLS")
return
print("Setting RLS")
stmts = self._get_set_statements()
print("stmts:", stmts)
if stmts is not None:
for stmt in stmts:
super().execute(stmt)
Expand All @@ -37,5 +55,56 @@ def set_context(self, context):
self.context = context

def execute(self, *args, **kwargs):
"""
Executes SQL queries, applying RLS unless bypassing.
"""
self._execute_set_statements()
return super().execute(*args, **kwargs)

# Inner class for the context manager
# Updated BypassRLSContext to handle errors
class BypassRLSContext:
def __init__(self, session: "RlsSession"):
self.session = session

def __enter__(self):
"""
When entering the context, attempt to bypass RLS.
If the command fails, rollback the transaction.
"""
self.session._rls_bypass = True
try:
# Disable row-level security
self.session.execute(text("SET LOCAL rls.bypass_rls = true;"))
except Exception as e:
print(f"Failed to disable row-level security: {e}")
self.session._rls_bypass = False # Disable bypass flag to avoid issues

# Rollback transaction to avoid failed state
self.session.rollback()
raise # Re-raise the exception to stop further execution
return self.session

def __exit__(self, exc_type, exc_val, exc_tb):
"""
When exiting the context, restore RLS if it was successfully disabled.
"""
self.session._rls_bypass = False

# If the transaction failed, skip re-enabling RLS
if exc_type is not None:
print(f"Skipping re-enabling RLS due to prior error: {exc_val}")
self.session.rollback()
return

try:
# Re-enable row-level security
self.session.execute(text("SET LOCAL rls.bypass_rls = false;"))
except Exception as e:
print(f"Failed to re-enable row-level security: {e}")

# Optionally rollback if there's a failure
self.session.rollback()

def execute(self, *args, **kwargs):
return self.session.execute(*args, **kwargs)
54 changes: 16 additions & 38 deletions rls/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import List, Literal, Union, TypedDict, Optional, NotRequired

from pydantic import BaseModel
from sqlalchemy import text
from .utils import generate_rls_policy
from sqlalchemy.sql import sqltypes

import re

Expand Down Expand Up @@ -31,16 +32,9 @@ class Operation(str, Enum):
like = "LIKE"


class ExpressionTypes(str, Enum):
integer = "INTEGER"
uuid = "UUID"
text = "TEXT"
boolean = "BOOLEAN"


class ConditionArgs(TypedDict):
comparator_name: str
type: ExpressionTypes
type: sqltypes
operation: NotRequired[Optional[Operation]] = None
column_name: NotRequired[Optional[str]] = None

Expand All @@ -56,8 +50,12 @@ class Policy(BaseModel):
__expr: str = None
__policy_suffix: int = 0

class Config:
arbitrary_types_allowed = True

def _get_safe_variable_name(self, idx: int = 0):
return f"NULLIF(current_setting('rls.{self.condition_args[idx]["comparator_name"]}', true),'')::{self.condition_args[idx]['type'].value}"
type_str = self.condition_args[idx]["type"].__visit_name__.upper()
return f"NULLIF(current_setting('rls.{self.condition_args[idx]["comparator_name"]}', true),'')::{type_str}"

def _get_expr_from_params(self, table_name: str, idx: int = 0):
safe_variable_name = self._get_safe_variable_name(idx=idx)
Expand Down Expand Up @@ -146,10 +144,6 @@ def get_sql_policies(self, table_name: str, name_suffix: str = "0"):

policy_lists = []

print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
print(self.__expr)
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")

for cmd in commands:
cmd_value = cmd.value if isinstance(cmd, Command) else cmd
policy_name = (
Expand All @@ -158,30 +152,14 @@ def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
)
self.__policy_names.append(policy_name)

if cmd_value in ["ALL", "SELECT", "UPDATE", "DELETE"]:
policy_lists.append(
text(
f"""
CREATE POLICY {policy_name} ON {table_name}
AS {self.definition}
FOR {cmd_value}
USING ({self.__expr})
"""
)
)
elif cmd in ["INSERT"]:
policy_lists.append(
text(
f"""
CREATE POLICY {policy_name} ON {table_name}
AS {self.definition}
FOR {cmd_value}
WITH CHECK ({self.__expr})
"""
)
)
else:
raise ValueError(f'Unknown policy command"{cmd_value}"')
generated_policy = generate_rls_policy(
cmd=cmd_value,
definition=self.definition,
policy_name=policy_name,
table_name=table_name,
expr=self.__expr,
)
policy_lists.append(generated_policy)
return policy_lists


Expand Down
31 changes: 31 additions & 0 deletions rls/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from sqlalchemy import text, TextClause


def generate_rls_policy(cmd, definition, policy_name, table_name, expr) -> TextClause:
bypass_rls_expr = (
"CAST(NULLIF(current_setting('rls.bypass_rls', true), '') AS BOOLEAN) = true"
)
expr = f"(({expr}) OR {bypass_rls_expr})"

print("%%%%%%%%%%%%%%%%%%%%%%%%%")
print("expr", expr)
print("%%%%%%%%%%%%%%%%%%%%%%%%%")

if cmd in ["ALL", "SELECT", "UPDATE", "DELETE"]:
return text(f"""
CREATE POLICY {policy_name} ON {table_name}
AS {definition}
FOR {cmd}
USING ({expr})
""")

elif cmd in ["INSERT"]:
return text(f"""
CREATE POLICY {policy_name} ON {table_name}
AS {definition}
FOR {cmd}
WITH CHECK ({expr})
""")

else:
raise ValueError(f'Unknown policy command"{cmd}"')
File renamed without changes.
File renamed without changes.
File renamed without changes.
67 changes: 67 additions & 0 deletions test/alembic/versions/4e47974b6aaf_initial_updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""initial updates
Revision ID: 4e47974b6aaf
Revises:
Create Date: 2024-09-30 12:55:59.618503
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "4e47974b6aaf"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"users",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("username", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False)
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
op.create_table(
"items",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("title", sa.String(), nullable=True),
sa.Column("description", sa.String(), nullable=True),
sa.Column("owner_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["owner_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_items_id"), "items", ["id"], unique=False)
op.create_index(op.f("ix_items_title"), "items", ["title"], unique=False)
op.enable_rls("items")
op.create_policy(
table_name="items",
policy_name="items_permissive_all_policy_0",
cmd="ALL",
definition="PERMISSIVE",
expr="owner_id > NULLIF(current_setting('rls.account_id', true),'')::INTEGER",
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_policy(tablename="items", policyname="items_permissive_all_policy_0")
op.disable_rls("items")
op.drop_index(op.f("ix_items_title"), table_name="items")
op.drop_index(op.f("ix_items_id"), table_name="items")
op.drop_table("items")
op.drop_index(op.f("ix_users_username"), table_name="users")
op.drop_index(op.f("ix_users_id"), table_name="users")
op.drop_table("users")
# ### end Alembic commands ###
Loading

0 comments on commit 7f3b711

Please sign in to comment.