-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 108750b
Showing
7 changed files
with
158 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
/venv/ | ||
*pycache* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
annotated-types==0.7.0 | ||
anyio==4.4.0 | ||
fastapi==0.114.2 | ||
greenlet==3.1.0 | ||
idna==3.10 | ||
psycopg2==2.9.9 | ||
pydantic==2.9.1 | ||
pydantic_core==2.23.3 | ||
sniffio==1.3.1 | ||
SQLAlchemy==2.0.34 | ||
starlette==0.38.5 | ||
typing_extensions==4.12.2 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Type | ||
|
||
from sqlalchemy import text | ||
from sqlalchemy.engine import Connection | ||
from sqlalchemy.ext.declarative import DeclarativeMeta | ||
|
||
|
||
def create_policies(Base: Type[DeclarativeMeta], connection: Connection): | ||
"""Create policies for `Base.metadata.create_all()`.""" | ||
for table, settings in Base.metadata.info["rls_policies"].items(): | ||
# enable | ||
stmt = text(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY;") | ||
connection.execute(stmt) | ||
# force by default | ||
stmt = text(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY;") | ||
connection.execute(stmt) | ||
# policies | ||
print("SETTINGS", settings) | ||
for ix, policy in enumerate(settings): | ||
print("*************************") | ||
print("POLICY", policy) | ||
print("*************************") | ||
for pol_stmt in policy.get_sql_policies( | ||
table_name=table, name_suffix=str(ix) | ||
): | ||
print("*************************") | ||
print(pol_stmt) | ||
print("*************************") | ||
connection.execute(pol_stmt) | ||
connection.commit() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
Always = "true" | ||
Never = "false" | ||
|
||
# ID principals | ||
Authenticated = "current_setting('app.current_user_id')::INTEGER IS NOT NULL" | ||
UserOwner = "owner_id = current_setting('app.current_user_id')::integer" | ||
|
||
# UUID principals | ||
UserOwnerUuid = "owner_id = current_setting('app.current_user_id')::UUID" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import Type | ||
|
||
from sqlalchemy import event, text | ||
from sqlalchemy.ext.declarative import DeclarativeMeta | ||
|
||
from .create_policies import create_policies | ||
|
||
|
||
|
||
def set_metadata_info(Base: Type[DeclarativeMeta]): | ||
"""RLS policies are first added to the Metadata before applied.""" | ||
Base.metadata.info.setdefault("rls_policies", dict()) | ||
for mapper in Base.registry.mappers: | ||
if not hasattr(mapper.class_, "__rls_policies__"): | ||
continue | ||
Base.metadata.info["rls_policies"][ | ||
mapper.tables[0].fullname | ||
] = mapper.class_.__rls_policies__ | ||
|
||
|
||
|
||
def register_rls(Base: Type[DeclarativeMeta]): | ||
|
||
# required for `alembic revision --autogenerate`` | ||
set_metadata_info(Base) | ||
|
||
@event.listens_for(Base.metadata, "after_create") | ||
def receive_after_create(target, connection, tables, **kw): | ||
|
||
# required for `Base.metadata.create_all()` | ||
set_metadata_info(Base) | ||
create_policies(Base, connection) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from enum import Enum | ||
from typing import List, Literal, Union | ||
|
||
from pydantic import BaseModel | ||
from sqlalchemy import text | ||
|
||
|
||
class Command(str, Enum): | ||
# policies: https://www.postgresql.org/docs/current/sql-createpolicy.html | ||
all = "ALL" | ||
select = "SELECT" | ||
insert = "INSERT" | ||
update = "UPDATE" | ||
delete = "DELETE" | ||
|
||
|
||
class Policy(BaseModel): | ||
definition: str | ||
expr: str | ||
cmd: Union[Command, List[Command]] | ||
|
||
def get_sql_policies(self, table_name: str, name_suffix: str = "0"): | ||
commands = [self.cmd] if isinstance(self.cmd, str) else self.cmd | ||
|
||
print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") | ||
print('definition', self.definition) | ||
print('expr', self.expr) | ||
print('commands', (commands[0].value)) | ||
print('table_name', table_name) | ||
print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") | ||
|
||
|
||
policy_lists = [] | ||
for cmd in commands: | ||
cmd_value = cmd.value if isinstance(cmd, Command) else cmd | ||
policy_name = ( | ||
f"{table_name}_{self.definition}" f"_{cmd_value}_policy_{name_suffix}".lower() | ||
) | ||
|
||
|
||
if cmd_value in ["ALL", "SELECT", "UPDATE", "DELETE"]: | ||
policy_lists.append( | ||
text( | ||
f""" | ||
CREATE POLICY {policy_name} ON {table_name} | ||
AS {self.definition} | ||
FOR {cmd_value} | ||
USING ({self.expr}) | ||
""" | ||
) | ||
) | ||
elif cmd in ["INSERT"]: | ||
policy_lists.append( | ||
text( | ||
f""" | ||
CREATE POLICY {policy_name} ON {table_name} | ||
AS {self.definition} | ||
FOR {cmd_value} | ||
WITH CHECK ({self.expr}) | ||
""" | ||
) | ||
) | ||
else: | ||
raise ValueError(f'Unknown policy command"{cmd_value}"') | ||
return policy_lists | ||
|
||
|
||
class Permissive(Policy): | ||
definition: Literal["PERMISSIVE"] = "PERMISSIVE" | ||
|
||
|
||
class Restrictive(Policy): | ||
definition: Literal["RESTRICTIVE"] = "RESTRICTIVE" |