-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Ghaith Kdimati <[email protected]>
- Loading branch information
Showing
7 changed files
with
139 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from sqlalchemy.orm import sessionmaker | ||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine | ||
from sqlalchemy.orm import sessionmaker | ||
|
||
DATABASE_URL = "postgresql+asyncpg://my_user:secure_password@localhost/session" | ||
|
||
|
||
|
||
|
||
# SQLAlchemy engine | ||
engine = create_async_engine(DATABASE_URL) | ||
|
||
# SQLAlchemy session | ||
AsyncSessionLocal = sessionmaker( | ||
bind=engine, | ||
class_=AsyncSession, # Specify that the session should be async | ||
expire_on_commit=False, | ||
) | ||
# Base class for our models | ||
Base = declarative_base() | ||
|
||
# Dependency to get DB session in routes | ||
async def get_db(): | ||
async with AsyncSessionLocal() as session: | ||
yield session |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from fastapi import FastAPI, Depends | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
from sqlalchemy.future import select | ||
from sqlalchemy import text | ||
|
||
from database import get_db | ||
from models import Item | ||
|
||
|
||
app = FastAPI() | ||
|
||
@app.get("/users/{user_id}/items") | ||
async def get_users(user_id, db: AsyncSession = Depends(get_db)): | ||
|
||
items = None | ||
async with db as session: | ||
print('session', session) | ||
stmt = text(f"SET LOCAL app.current_user_id = {user_id}") | ||
print('stmt', stmt) | ||
await db.execute(stmt) | ||
|
||
# Retrieve the parameter value | ||
get_param_stmt = text("SELECT current_setting('app.current_user_id')") | ||
result = await db.execute(get_param_stmt) | ||
current_user_id = result.scalar() # Get the single value | ||
print('after: current_user_id', current_user_id) | ||
|
||
stmt = select(Item) | ||
result = await db.execute(stmt) | ||
items = result.scalars().all() | ||
|
||
|
||
return items |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from sqlalchemy import Column, Integer, String, ForeignKey | ||
from sqlalchemy.orm import relationship | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from rls.schemas import Permissive | ||
from rls.register_rls import register_rls | ||
Base = declarative_base() | ||
|
||
# Register RLS policies | ||
register_rls(Base) | ||
|
||
|
||
class User(Base): | ||
__tablename__ = 'users' | ||
|
||
id = Column(Integer, primary_key=True, index=True) | ||
username = Column(String, unique=True, index=True) | ||
|
||
class Item(Base): | ||
__tablename__ = 'items' | ||
|
||
id = Column(Integer, primary_key=True, index=True) | ||
title = Column(String, index=True) | ||
description = Column(String) | ||
owner_id = Column(Integer, ForeignKey('users.id')) | ||
|
||
owner = relationship("User") | ||
|
||
__rls_policies__ = [ | ||
Permissive(expr=f"owner_id = current_setting('app.current_user_id')::integer", cmd=["SELECT","INSERT", "UPDATE", "DELETE"]) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
annotated-types==0.7.0 | ||
anyio==4.4.0 | ||
asyncpg==0.29.0 | ||
certifi==2024.8.30 | ||
click==8.1.7 | ||
fastapi==0.114.2 | ||
greenlet==3.1.0 | ||
h11==0.14.0 | ||
httpcore==1.0.5 | ||
httpx==0.27.2 | ||
idna==3.10 | ||
psycopg2==2.9.9 | ||
pydantic==2.9.1 | ||
pydantic_core==2.23.3 | ||
sniffio==1.3.1 | ||
SQLAlchemy==2.0.34 | ||
starlette==0.38.5 | ||
typing_extensions==4.12.2 | ||
uvicorn==0.30.6 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# sync_setup.py | ||
from sqlalchemy import create_engine | ||
from models import Base | ||
from models import Base | ||
|
||
DATABASE_URL = "postgresql://user:password@localhost/session" | ||
# Use a synchronous engine for schema creation | ||
engine = create_engine(DATABASE_URL) | ||
|
||
def setup_database(): | ||
# Create tables | ||
Base.metadata.create_all(engine) | ||
|
||
|
||
if __name__ == "__main__": | ||
setup_database() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import asyncio | ||
import httpx | ||
|
||
async def get_items(user_id): | ||
async with httpx.AsyncClient() as client: | ||
response = await client.get(f"http://localhost:8000/users/{user_id}/items") | ||
print(f"User {user_id}: {response.json()}") | ||
|
||
async def main(): | ||
await asyncio.gather( | ||
get_items(1), | ||
get_items(2), | ||
) | ||
|
||
asyncio.run(main()) |