Skip to content

Commit

Permalink
test: add tests setup
Browse files Browse the repository at this point in the history
Co-authored-by: Ghaith Kdimati <[email protected]>
  • Loading branch information
Ofahmy143 and Ghaithq committed Sep 17, 2024
1 parent 108750b commit bf0cc9f
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 0 deletions.
Empty file added test/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions test/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

DATABASE_URL = "postgresql+asyncpg://my_user:secure_password@localhost/session"




# SQLAlchemy engine
engine = create_async_engine(DATABASE_URL)

# SQLAlchemy session
AsyncSessionLocal = sessionmaker(
bind=engine,
class_=AsyncSession, # Specify that the session should be async
expire_on_commit=False,
)
# Base class for our models
Base = declarative_base()

# Dependency to get DB session in routes
async def get_db():
async with AsyncSessionLocal() as session:
yield session
33 changes: 33 additions & 0 deletions test/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from fastapi import FastAPI, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import text

from database import get_db
from models import Item


app = FastAPI()

@app.get("/users/{user_id}/items")
async def get_users(user_id, db: AsyncSession = Depends(get_db)):

items = None
async with db as session:
print('session', session)
stmt = text(f"SET LOCAL app.current_user_id = {user_id}")
print('stmt', stmt)
await db.execute(stmt)

# Retrieve the parameter value
get_param_stmt = text("SELECT current_setting('app.current_user_id')")
result = await db.execute(get_param_stmt)
current_user_id = result.scalar() # Get the single value
print('after: current_user_id', current_user_id)

stmt = select(Item)
result = await db.execute(stmt)
items = result.scalars().all()


return items
30 changes: 30 additions & 0 deletions test/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from rls.schemas import Permissive
from rls.register_rls import register_rls
Base = declarative_base()

# Register RLS policies
register_rls(Base)


class User(Base):
__tablename__ = 'users'

id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True)

class Item(Base):
__tablename__ = 'items'

id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True)
description = Column(String)
owner_id = Column(Integer, ForeignKey('users.id'))

owner = relationship("User")

__rls_policies__ = [
Permissive(expr=f"owner_id = current_setting('app.current_user_id')::integer", cmd=["SELECT","INSERT", "UPDATE", "DELETE"])
]
19 changes: 19 additions & 0 deletions test/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
annotated-types==0.7.0
anyio==4.4.0
asyncpg==0.29.0
certifi==2024.8.30
click==8.1.7
fastapi==0.114.2
greenlet==3.1.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.2
idna==3.10
psycopg2==2.9.9
pydantic==2.9.1
pydantic_core==2.23.3
sniffio==1.3.1
SQLAlchemy==2.0.34
starlette==0.38.5
typing_extensions==4.12.2
uvicorn==0.30.6
16 changes: 16 additions & 0 deletions test/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# sync_setup.py
from sqlalchemy import create_engine
from models import Base
from models import Base

DATABASE_URL = "postgresql://user:password@localhost/session"
# Use a synchronous engine for schema creation
engine = create_engine(DATABASE_URL)

def setup_database():
# Create tables
Base.metadata.create_all(engine)


if __name__ == "__main__":
setup_database()
15 changes: 15 additions & 0 deletions test/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import asyncio
import httpx

async def get_items(user_id):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://localhost:8000/users/{user_id}/items")
print(f"User {user_id}: {response.json()}")

async def main():
await asyncio.gather(
get_items(1),
get_items(2),
)

asyncio.run(main())

0 comments on commit bf0cc9f

Please sign in to comment.