Skip to content

Commit

Permalink
feat: bind engine to session and make sure the bind is called only on…
Browse files Browse the repository at this point in the history
…ce and support multiple sources of comparison with nested fields

Co-authored-by: Ghaith Kdimati <[email protected]>
  • Loading branch information
2 people authored and baraka95 committed Sep 27, 2024
1 parent 4a6e633 commit d47f5d8
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 47 deletions.
146 changes: 120 additions & 26 deletions rls/database.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,118 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy import text, create_engine
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, AsyncEngine
from sqlalchemy import text
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.engine import Engine

import once
import logging
import base64
import json
from typing import Optional

from fastapi import Request
from fastapi.datastructures import Headers

from rls.schemas import Policy
from rls.schemas import Policy, ComparatorSource

# --------------------------------------Set RLS variables-----------------------------#


def _get_set_statements(headers: Headers) -> Optional[TextClause]:
def _base64url_decode(input_str: str) -> bytes:
    """Decode a Base64 URL-safe string whose ``=`` padding may be missing.

    Args:
        input_str: Base64url text, with or without trailing ``=`` padding.

    Returns:
        The decoded bytes.

    Raises:
        ValueError: If the input is not valid Base64 (``binascii.Error`` is a
            ``ValueError`` subclass).
    """
    # -len % 4 is 0..3: a length that is already a multiple of four needs no
    # padding at all, otherwise pad up to the next multiple of four.
    padding = "=" * (-len(input_str) % 4)
    return base64.urlsafe_b64decode(input_str + padding)


def _decode_jwt(token):
    """Return the payload (claims) of a JWT without verifying its signature.

    SECURITY: the signature segment is ignored, so the returned claims are
    unauthenticated. Only rely on them when the token has already been
    verified elsewhere.

    Args:
        token: Compact JWT string of the form ``header.payload.signature``.

    Returns:
        The JSON-decoded payload, or ``None`` when the token does not consist
        of exactly three dot-separated segments or its payload is not valid
        Base64url-encoded JSON.
    """
    try:
        # Unpacking raises ValueError unless there are exactly three segments;
        # only the payload segment is used.
        _header_b64, payload_b64, _signature_b64 = token.split(".")

        # Decode the payload (base64url) and parse it as JSON.
        return json.loads(_base64url_decode(payload_b64))
    except ValueError as e:
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
        # ValueError subclasses, so this single handler covers every failure.
        # The %s placeholder is required: passing an argument without one
        # makes the logging module fail while formatting the record.
        logging.error("Invalid token or payload: %s", e)
        return None


def _parse_bearer_token(AuthorizationHeader: str):
    """Extract and decode the JWT carried by a ``Bearer`` Authorization header.

    Args:
        AuthorizationHeader: Raw header value such as ``"Bearer <token>"``.
            May be ``None`` or empty when the header is absent.

    Returns:
        The payload returned by ``_decode_jwt`` for the token, or ``None``
        when the header is missing or blank, uses a scheme other than
        ``Bearer``, or does not contain exactly one token.
    """
    # split() on a whitespace-only value yields [], so check the parts rather
    # than the raw string; otherwise parts[0] below would raise IndexError.
    parts = AuthorizationHeader.split() if AuthorizationHeader else []
    if not parts:
        logging.warning("No Authorization header")
        return None

    if parts[0].lower() != "bearer":
        logging.warning("Authorization header does not contain Bearer token")
        return None

    if len(parts) == 1:
        logging.warning("No Bearer token found")
        return None

    if len(parts) > 2:
        logging.warning("Authorization header contains multiple Bearer tokens")
        return None

    return _decode_jwt(parts[1])


def _get_nested_value(dictionary: dict, keys: list[str]):
    """Walk *keys* through nested mappings and return the value found.

    Any ``collections.abc.Mapping`` is accepted at every level, not only
    ``dict``, so mapping-like containers (for example an HTTP headers object)
    can be traversed as well.

    Args:
        dictionary: Root mapping to start from. A non-mapping root (including
            ``None``) yields ``None`` as soon as a key has to be looked up.
        keys: Keys to follow, outermost first. An empty list returns
            *dictionary* unchanged.

    Returns:
        The value at the end of the path, or ``None`` if a key is missing, an
        intermediate value is not a mapping, or a value on the path is
        ``None``.
    """
    # Imported locally so this helper does not depend on a module-level import.
    from collections.abc import Mapping

    value = dictionary
    for key in keys:
        if not isinstance(value, Mapping):
            return None  # Cannot descend into a non-mapping value
        value = value.get(key)
        if value is None:
            return None  # Key missing (or explicitly None) at this level
    return value


def _get_set_statements(req: Request) -> Optional[TextClause]:
# must be setup for the rls_policies key to be set
queries = []

Base.metadata.info.setdefault("rls_policies", dict())
for mapper in Base.registry.mappers:
RlsBase.metadata.info.setdefault("rls_policies", dict())
for mapper in RlsBase.registry.mappers:
if not hasattr(mapper.class_, "__rls_policies__"):
continue

table_name = mapper.tables[0].fullname
policies: list[Policy] = mapper.class_.__rls_policies__

for policy in policies:
db_var_name = policy.get_db_var_name(table_name)

comparator_name = policy.condition_args["comparator_name"]
comparator_value = headers.get(comparator_name)
comparator_name_split = comparator_name.split(".")

comparator_value = None
if (
policy.condition_args["comparator_source"]
== ComparatorSource.bearerTokenPayload
):
comparator_value = _get_nested_value(
_parse_bearer_token(req.headers.get("Authorization")),
comparator_name_split,
)
elif (
policy.condition_args["comparator_source"]
== ComparatorSource.requestUser
):
comparator_value = _get_nested_value(req.user, comparator_name_split)
elif policy.condition_args["comparator_source"] == ComparatorSource.header:
comparator_value = _get_nested_value(req.headers, comparator_name_split)
else:
raise ValueError("Invalid Comparator Source")

if comparator_value is None:
continue
Expand All @@ -44,37 +128,24 @@ def _get_set_statements(headers: Headers) -> Optional[TextClause]:
return combined_query


# --------------------------------------Base Initialization-----------------------------#
# --------------------------------------RlsBase Initialization-----------------------------#


# Base class for our models
Base = declarative_base()
# RlsBase class for our models
RlsBase = declarative_base()


# --------------------------------------Engines Initialization-----------------------------#


ASYNC_DATABASE_URL = "postgresql+asyncpg://my_user:secure_password@localhost/session"

SYNC_DATABASE_URL = "postgresql://my_user:secure_password@localhost/session"


# TODO: DATABASE_URL should be an environment variable be it SYNC or ASYNC
# SQLAlchemy engine
async_engine = create_async_engine(ASYNC_DATABASE_URL)

# SQLAlchemy session
AsyncSessionLocal = async_sessionmaker(
bind=async_engine,
class_=AsyncSession, # Async session class
expire_on_commit=False,
)


sync_engine = create_engine(SYNC_DATABASE_URL)

SyncSessionLocal = sessionmaker(
bind=sync_engine,
class_=Session,
expire_on_commit=False,
)
Expand All @@ -83,9 +154,32 @@ def _get_set_statements(headers: Headers) -> Optional[TextClause]:
# --------------------------------------Deps injection functions-----------------------------#


@once.once
def bind_engine(db_engine: Engine):
    """Attach *db_engine* to the session factory that matches its kind.

    An ``AsyncEngine`` configures ``AsyncSessionLocal``; a regular ``Engine``
    configures ``SyncSessionLocal``. The function is wrapped in ``once.once``.

    Args:
        db_engine: The SQLAlchemy engine to bind.

    Raises:
        ValueError: If *db_engine* is neither an ``AsyncEngine`` nor an
            ``Engine``.
    """
    # Pick the notice and the factory first, then act on them once.
    if isinstance(db_engine, AsyncEngine):
        notice, session_factory = "Config Async Session", AsyncSessionLocal
    elif isinstance(db_engine, Engine):
        notice, session_factory = "Config Sync Session", SyncSessionLocal
    else:
        raise ValueError("Invalid Engine type")

    print(notice)
    session_factory.configure(bind=db_engine)


def get_session(db_engine: Engine):
    """Bind *db_engine* and return the session dependency that matches it.

    Args:
        db_engine: The SQLAlchemy engine sessions should use.

    Returns:
        ``get_async_session`` for an ``AsyncEngine``, ``get_sync_session`` for
        a regular ``Engine``.

    Raises:
        ValueError: If *db_engine* is neither engine type.
    """
    bind_engine(db_engine)

    # Checked in order: the async engine type first, then the sync one.
    dispatch = (
        (AsyncEngine, get_async_session),
        (Engine, get_sync_session),
    )
    for engine_type, dependency in dispatch:
        if isinstance(db_engine, engine_type):
            return dependency

    raise ValueError("Invalid Engine type")


def get_sync_session(request: Request):
    """Yield a synchronous session with the request's RLS settings applied.

    Generator dependency: a session is opened from ``SyncSessionLocal``, the
    statements built by ``_get_set_statements`` for this request (if any) are
    executed on it, and the session is then yielded to the caller. The
    ``with`` block closes the session when the generator is finalized.

    Args:
        request: The incoming request the RLS values are derived from.

    Yields:
        The configured session.
    """
    with SyncSessionLocal() as session:
        # Pass the whole request, not just its headers: the statement builder
        # takes a Request and reads more than the headers from it.
        stmts = _get_set_statements(request)
        if stmts is not None:
            session.execute(stmts)
        yield session
Expand All @@ -94,7 +188,7 @@ def get_sync_session(request: Request):
# Dependency to get DB session in routes
async def get_async_session(request: Request):
    """Yield an async session with the request's RLS settings applied.

    Async generator dependency: a session is opened from
    ``AsyncSessionLocal``, the statements built by ``_get_set_statements`` for
    this request (if any) are awaited on it, and the session is then yielded.
    The ``async with`` block closes the session when the generator is
    finalized.

    Args:
        request: The incoming request the RLS values are derived from.

    Yields:
        The configured session.
    """
    async with AsyncSessionLocal() as session:
        # Pass the whole request, not just its headers: the statement builder
        # takes a Request and reads more than the headers from it.
        stmts = _get_set_statements(request)
        if stmts is not None:
            await session.execute(stmts)
        yield session
20 changes: 13 additions & 7 deletions rls/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,27 @@ class Command(str, Enum):
update = "UPDATE"
delete = "DELETE"


class Operation(str, Enum):
    """Comparison operations a policy condition can use."""

    # Currently the only supported operation.
    equality = "EQUALITY"


class ExpressionTypes(str, Enum):
    """Type names a policy's comparator value can be declared as.

    Each member's value is the upper-case type name as a string.
    """

    integer = "INTEGER"
    uuid = "UUID"
    text = "TEXT"
    boolean = "BOOLEAN"


class ComparatorSource(str, Enum):
    """Places on the incoming request a policy's comparator value is read from."""

    # Looked up in the request headers.
    header = "header"
    # Looked up in the payload of the Bearer token from the Authorization header.
    bearerTokenPayload = "bearer_token_payload"
    # Looked up on the request's user object.
    requestUser = "request_user"


class ConditionArgs(TypedDict):
comparator_name: str
comparator_source: ComparatorSource
operation: Operation
type: ExpressionTypes
column_name: str
Expand All @@ -35,12 +44,10 @@ class Policy(BaseModel):
condition_args: ConditionArgs
cmd: Union[Command, List[Command]]


def get_db_var_name(self, table_name):
return f"rls.{table_name}_{self.condition_args['column_name']}"

def _get_expr_from_params(self, table_name: str):

def _get_expr_from_params(self, table_name: str):
variable_name = f"NULLIF(current_setting('{self.get_db_var_name(table_name)}', true),'')::{self.condition_args['type'].value}"

expr = None
Expand All @@ -49,9 +56,8 @@ def _get_expr_from_params(self, table_name: str):

if expr is None:
raise ValueError(f"Unknown operation: {self.condition_args['operation']}")

return expr

return expr

def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
commands = [self.cmd] if isinstance(self.cmd, str) else self.cmd
Expand All @@ -64,9 +70,9 @@ def get_sql_policies(self, table_name: str, name_suffix: str = "0"):
for cmd in commands:
cmd_value = cmd.value if isinstance(cmd, Command) else cmd
policy_name = (
f"{table_name}_{self.definition}" f"_{cmd_value}_policy_{name_suffix}".lower()
f"{table_name}_{self.definition}"
f"_{cmd_value}_policy_{name_suffix}".lower()
)


if cmd_value in ["ALL", "SELECT", "UPDATE", "DELETE"]:
policy_lists.append(
Expand Down
10 changes: 10 additions & 0 deletions test/engines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine

# Connection URLs for the test setup. Each one can be overridden through the
# environment variable of the same name; the defaults keep the previous
# hard-coded local-database values so existing setups continue to work.
ASYNC_DATABASE_URL = os.environ.get(
    "ASYNC_DATABASE_URL",
    "postgresql+asyncpg://my_user:secure_password@localhost/session",
)
SYNC_DATABASE_URL = os.environ.get(
    "SYNC_DATABASE_URL",
    "postgresql://my_user:secure_password@localhost/session",
)
ADMIN_DATABASE_URL = os.environ.get(
    "ADMIN_DATABASE_URL",
    "postgresql://user:password@localhost/session",
)

# Engines shared by the test application and the setup script.
async_engine = create_async_engine(ASYNC_DATABASE_URL)
sync_engine = create_engine(SYNC_DATABASE_URL)
admin_engine = create_engine(ADMIN_DATABASE_URL)
9 changes: 7 additions & 2 deletions test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from rls.database import get_async_session

from rls.database import get_session
from test.models import Item

from .engines import async_engine as db_engine


app = FastAPI()

Session = Depends(get_session(db_engine))


@app.get("/users/items")
async def get_users(db: AsyncSession = Depends(get_async_session)):
async def get_users(db: AsyncSession = Session):
stmt = select(Item)
result = await db.execute(stmt)
items = result.scalars().all()
Expand Down
22 changes: 16 additions & 6 deletions test/models.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from rls.schemas import Permissive, Operation, ExpressionTypes, Command
from rls.schemas import (
Permissive,
Operation,
ExpressionTypes,
Command,
ComparatorSource,
)
from rls.register_rls import register_rls
from rls.database import Base
from rls.database import RlsBase

# Register RLS policies
register_rls(Base)
register_rls(RlsBase)


class User(RlsBase):
    """ORM model for the ``users`` table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, index=True)


class Item(Base):
class Item(RlsBase):
__tablename__ = "items"

id = Column(Integer, primary_key=True, index=True)
Expand All @@ -28,10 +34,14 @@ class Item(Base):
__rls_policies__ = [
Permissive(
condition_args={
"comparator_name": "user_id",
"comparator_name": "sub",
"comparator_source": ComparatorSource.bearerTokenPayload,
"operation": Operation.equality,
"type": ExpressionTypes.integer,
"column_name": "owner_id",
# "expr": "&column_name& = current_setting('&comparator_name&')::UUID"
# TODO: how to parse the comparator name from expr
# IDEA: may give us a custom function to parse from the comparator he customized
},
cmd=[Command.all],
)
Expand Down
10 changes: 4 additions & 6 deletions test/setup.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# sync_setup.py
from sqlalchemy import create_engine
from .models import Base
from engines import admin_engine as engine
from models import RlsBase

DATABASE_URL = "postgresql://user:password@localhost/session"
# Use a synchronous engine for schema creation
engine = create_engine(DATABASE_URL)
print("Setting up database")


def setup_database():
    """Create every table registered on ``RlsBase.metadata``.

    Uses the module-level ``engine`` for schema creation.
    """
    # Create tables
    RlsBase.metadata.create_all(engine)


if __name__ == "__main__":
Expand Down

0 comments on commit d47f5d8

Please sign in to comment.