Skip to content

Commit

Permalink
Initial Commit: package structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Ofahmy143 committed Sep 17, 2024
0 parents commit 108750b
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/venv/
*pycache*
12 changes: 12 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
annotated-types==0.7.0
anyio==4.4.0
fastapi==0.114.2
greenlet==3.1.0
idna==3.10
psycopg2==2.9.9
pydantic==2.9.1
pydantic_core==2.23.3
sniffio==1.3.1
SQLAlchemy==2.0.34
starlette==0.38.5
typing_extensions==4.12.2
Empty file added rls/__init__.py
Empty file.
30 changes: 30 additions & 0 deletions rls/create_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Type

from sqlalchemy import text
from sqlalchemy.engine import Connection
from sqlalchemy.ext.declarative import DeclarativeMeta


def create_policies(Base: Type[DeclarativeMeta], connection: Connection):
"""Create policies for `Base.metadata.create_all()`."""
for table, settings in Base.metadata.info["rls_policies"].items():
# enable
stmt = text(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY;")
connection.execute(stmt)
# force by default
stmt = text(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY;")
connection.execute(stmt)
# policies
print("SETTINGS", settings)
for ix, policy in enumerate(settings):
print("*************************")
print("POLICY", policy)
print("*************************")
for pol_stmt in policy.get_sql_policies(
table_name=table, name_suffix=str(ix)
):
print("*************************")
print(pol_stmt)
print("*************************")
connection.execute(pol_stmt)
connection.commit()
9 changes: 9 additions & 0 deletions rls/principals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Always = "true"
Never = "false"

# ID principals
Authenticated = "current_setting('app.current_user_id')::INTEGER IS NOT NULL"
UserOwner = "owner_id = current_setting('app.current_user_id')::integer"

# UUID principals
UserOwnerUuid = "owner_id = current_setting('app.current_user_id')::UUID"
32 changes: 32 additions & 0 deletions rls/register_rls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Type

from sqlalchemy import event, text
from sqlalchemy.ext.declarative import DeclarativeMeta

from .create_policies import create_policies



def set_metadata_info(Base: Type[DeclarativeMeta]):
"""RLS policies are first added to the Metadata before applied."""
Base.metadata.info.setdefault("rls_policies", dict())
for mapper in Base.registry.mappers:
if not hasattr(mapper.class_, "__rls_policies__"):
continue
Base.metadata.info["rls_policies"][
mapper.tables[0].fullname
] = mapper.class_.__rls_policies__



def register_rls(Base: Type[DeclarativeMeta]):

# required for `alembic revision --autogenerate``
set_metadata_info(Base)

@event.listens_for(Base.metadata, "after_create")
def receive_after_create(target, connection, tables, **kw):

# required for `Base.metadata.create_all()`
set_metadata_info(Base)
create_policies(Base, connection)
73 changes: 73 additions & 0 deletions rls/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from enum import Enum
from typing import List, Literal, Union

from pydantic import BaseModel
from sqlalchemy import text


class Command(str, Enum):
# policies: https://www.postgresql.org/docs/current/sql-createpolicy.html
all = "ALL"
select = "SELECT"
insert = "INSERT"
update = "UPDATE"
delete = "DELETE"


class Policy(BaseModel):
definition: str
expr: str
cmd: Union[Command, List[Command]]

def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
commands = [self.cmd] if isinstance(self.cmd, str) else self.cmd

print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
print('definition', self.definition)
print('expr', self.expr)
print('commands', (commands[0].value))
print('table_name', table_name)
print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")


policy_lists = []
for cmd in commands:
cmd_value = cmd.value if isinstance(cmd, Command) else cmd
policy_name = (
f"{table_name}_{self.definition}" f"_{cmd_value}_policy_{name_suffix}".lower()
)


if cmd_value in ["ALL", "SELECT", "UPDATE", "DELETE"]:
policy_lists.append(
text(
f"""
CREATE POLICY {policy_name} ON {table_name}
AS {self.definition}
FOR {cmd_value}
USING ({self.expr})
"""
)
)
elif cmd in ["INSERT"]:
policy_lists.append(
text(
f"""
CREATE POLICY {policy_name} ON {table_name}
AS {self.definition}
FOR {cmd_value}
WITH CHECK ({self.expr})
"""
)
)
else:
raise ValueError(f'Unknown policy command"{cmd_value}"')
return policy_lists


class Permissive(Policy):
definition: Literal["PERMISSIVE"] = "PERMISSIVE"


class Restrictive(Policy):
definition: Literal["RESTRICTIVE"] = "RESTRICTIVE"

0 comments on commit 108750b

Please sign in to comment.