Skip to content

Commit

Permalink
fix: mypy static analysis errors
Browse files Browse the repository at this point in the history
Co-authored-by: Ghaith Kdimati <[email protected]>
  • Loading branch information
Ofahmy143 and Ghaithq committed Sep 19, 2024
1 parent b8747c7 commit 981a88d
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 43 deletions.
49 changes: 25 additions & 24 deletions rls/database.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from rls.schemas import Policy
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy import text, create_engine
from sqlalchemy.sql.elements import TextClause

from typing import Optional

from fastapi import Request
from fastapi.datastructures import Headers
from sqlalchemy import text, create_engine

#--------------------------------------Set RLS variables-----------------------------#
from rls.schemas import Policy

def _get_set_statements(headers: Headers) -> str:
# --------------------------------------Set RLS variables-----------------------------#


def _get_set_statements(headers: Headers) -> Optional[TextClause]:
# must be setup for the rls_policies key to be set
queries = []

Expand All @@ -25,30 +30,28 @@ def _get_set_statements(headers: Headers) -> str:
comparator_name = policy.condition_args["comparator_name"]
comparator_value = headers.get(comparator_name)

if(comparator_value is None):
if comparator_value is None:
continue

temp_query = f"SET LOCAL {db_var_name} = {comparator_value};"

queries.append(temp_query)


if(len(queries) == 0):
return ""
combined_query = text('\n'.join(queries)) # Combine queries into a single string
if len(queries) == 0:
return None
combined_query = text("\n".join(queries)) # Combine queries into a single string

return combined_query



#--------------------------------------Base Initialization-----------------------------#
# --------------------------------------Base Initialization-----------------------------#


# Base class for our models
Base = declarative_base()


#--------------------------------------Engines Initialization-----------------------------#
# --------------------------------------Engines Initialization-----------------------------#


ASYNC_DATABASE_URL = "postgresql+asyncpg://my_user:secure_password@localhost/session"
Expand All @@ -61,9 +64,9 @@ def _get_set_statements(headers: Headers) -> str:
async_engine = create_async_engine(ASYNC_DATABASE_URL)

# SQLAlchemy session
AsyncSessionLocal = sessionmaker(
AsyncSessionLocal = async_sessionmaker(
bind=async_engine,
class_=AsyncSession, # Specify that the session should be async
class_=AsyncSession, # Async session class
expire_on_commit=False,
)

Expand All @@ -72,28 +75,26 @@ def _get_set_statements(headers: Headers) -> str:

SyncSessionLocal = sessionmaker(
bind=sync_engine,
class_= Session,
class_=Session,
expire_on_commit=False,
)



#--------------------------------------Deps injection functions-----------------------------#
# --------------------------------------Deps injection functions-----------------------------#


def get_sync_session(request: Request):
with SyncSessionLocal() as session:
stmts = _get_set_statements(request.headers)
if(stmts != ""):
session.execute(stmts)
if stmts is not None:
session.execute(stmts)
yield session




# Dependency to get DB session in routes
async def get_async_session(request: Request):
async with AsyncSessionLocal() as session:
stmts = _get_set_statements(request.headers)
if(stmts != ""):
if stmts is not None:
await session.execute(stmts)
yield session
7 changes: 3 additions & 4 deletions test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
from sqlalchemy.future import select

from rls.database import get_async_session
from models import Item
from test.models import Item


app = FastAPI()


@app.get("/users/{user_id}/items")
async def get_users(user_id, db: AsyncSession = Depends(get_async_session)):

stmt = select(Item)
result = await db.execute(stmt)
items = result.scalars().all()


return items
return items
25 changes: 14 additions & 11 deletions test/models.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,38 @@
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from rls.schemas import Permissive
from rls.schemas import Permissive, Operation, ExpressionTypes, Command
from rls.register_rls import register_rls
from rls.database import Base

# Register RLS policies
register_rls(Base)


class User(Base):
__tablename__ = 'users'
__tablename__ = "users"

id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True)


class Item(Base):
__tablename__ = 'items'
__tablename__ = "items"

id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True)
description = Column(String)
owner_id = Column(Integer, ForeignKey('users.id'))
owner_id = Column(Integer, ForeignKey("users.id"))

owner = relationship("User")

__rls_policies__ = [
Permissive(condition_args={
Permissive(
condition_args={
"comparator_name": "user_id",
"operation": "EQUALITY",
"type": "INTEGER",
"column_name": "owner_id"
}, cmd=["SELECT","INSERT", "UPDATE", "DELETE"])
"operation": Operation.equality,
"type": ExpressionTypes.integer,
"column_name": "owner_id",
},
cmd=[Command.all],
)
]
8 changes: 4 additions & 4 deletions test/setup.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# sync_setup.py
from sqlalchemy import create_engine
from models import Base
from models import Base
from .models import Base

DATABASE_URL = "postgresql://user:password@localhost/session"
# Use a synchronous engine for schema creation
engine = create_engine(DATABASE_URL)


def setup_database():
# Create tables
Base.metadata.create_all(engine)


if __name__ == "__main__":
setup_database()
setup_database()

0 comments on commit 981a88d

Please sign in to comment.