From bf0cc9f084fa5bfafe3a5d1d3c026f8aa6ba798b Mon Sep 17 00:00:00 2001 From: Omar <102324541+Ofahmy143@users.noreply.github.com> Date: Tue, 17 Sep 2024 09:47:55 +0300 Subject: [PATCH] test: add tests setup Co-authored-by: Ghaith Kdimati --- test/__init__.py | 0 test/database.py | 26 ++++++++++++++++++++++++++ test/main.py | 33 +++++++++++++++++++++++++++++++++ test/models.py | 30 ++++++++++++++++++++++++++++++ test/requirements.txt | 19 +++++++++++++++++++ test/setup.py | 16 ++++++++++++++++ test/test_main.py | 15 +++++++++++++++ 7 files changed, 139 insertions(+) create mode 100644 test/__init__.py create mode 100644 test/database.py create mode 100644 test/main.py create mode 100644 test/models.py create mode 100644 test/requirements.txt create mode 100644 test/setup.py create mode 100644 test/test_main.py diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/database.py b/test/database.py new file mode 100644 index 0000000..e98f018 --- /dev/null +++ b/test/database.py @@ -0,0 +1,26 @@ +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker + +DATABASE_URL = "postgresql+asyncpg://my_user:secure_password@localhost/session" + + + + +# SQLAlchemy engine +engine = create_async_engine(DATABASE_URL) + +# SQLAlchemy session +AsyncSessionLocal = sessionmaker( + bind=engine, + class_=AsyncSession, # Specify that the session should be async + expire_on_commit=False, +) +# Base class for our models +Base = declarative_base() + +# Dependency to get DB session in routes +async def get_db(): + async with AsyncSessionLocal() as session: + yield session diff --git a/test/main.py b/test/main.py new file mode 100644 index 0000000..a7e80e7 --- /dev/null +++ b/test/main.py @@ -0,0 +1,33 @@ +from fastapi import FastAPI, Depends +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import text + +from database import get_db +from models import Item + + +app = FastAPI() + +@app.get("/users/{user_id}/items") +async def get_users(user_id, db: AsyncSession = Depends(get_db)): + + items = None + async with db as session: + print('session', session) + stmt = text(f"SET LOCAL app.current_user_id = {user_id}") + print('stmt', stmt) + await db.execute(stmt) + + # Retrieve the parameter value + get_param_stmt = text("SELECT current_setting('app.current_user_id')") + result = await db.execute(get_param_stmt) + current_user_id = result.scalar() # Get the single value + print('after: current_user_id', current_user_id) + + stmt = select(Item) + result = await db.execute(stmt) + items = result.scalars().all() + + + return items \ No newline at end of file diff --git a/test/models.py b/test/models.py new file mode 100644 index 0000000..6895217 --- /dev/null +++ b/test/models.py @@ -0,0 +1,30 @@ +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship +from sqlalchemy.ext.declarative import declarative_base +from rls.schemas import Permissive +from rls.register_rls import register_rls +Base = declarative_base() + +# Register RLS policies +register_rls(Base) + + +class User(Base): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True, index=True) + username = Column(String, unique=True, index=True) + +class Item(Base): + __tablename__ = 'items' + + id = Column(Integer, primary_key=True, index=True) + title = Column(String, index=True) + description = Column(String) + owner_id = Column(Integer, ForeignKey('users.id')) + + owner = relationship("User") + + __rls_policies__ = [ + Permissive(expr=f"owner_id = current_setting('app.current_user_id')::integer", cmd=["SELECT","INSERT", "UPDATE", "DELETE"]) +] diff --git a/test/requirements.txt b/test/requirements.txt new file mode 100644 index 0000000..32d3c9f --- /dev/null +++ b/test/requirements.txt @@ -0,0 +1,19 @@ +annotated-types==0.7.0 +anyio==4.4.0 +asyncpg==0.29.0 +certifi==2024.8.30 +click==8.1.7 +fastapi==0.114.2 +greenlet==3.1.0 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.2 +idna==3.10 +psycopg2==2.9.9 +pydantic==2.9.1 +pydantic_core==2.23.3 +sniffio==1.3.1 +SQLAlchemy==2.0.34 +starlette==0.38.5 +typing_extensions==4.12.2 +uvicorn==0.30.6 diff --git a/test/setup.py b/test/setup.py new file mode 100644 index 0000000..791b207 --- /dev/null +++ b/test/setup.py @@ -0,0 +1,16 @@ +# sync_setup.py +from sqlalchemy import create_engine +from models import Base +from models import Base + +DATABASE_URL = "postgresql://user:password@localhost/session" +# Use a synchronous engine for schema creation +engine = create_engine(DATABASE_URL) + +def setup_database(): + # Create tables + Base.metadata.create_all(engine) + + +if __name__ == "__main__": + setup_database() \ No newline at end of file diff --git a/test/test_main.py b/test/test_main.py new file mode 100644 index 0000000..3cf23a9 --- /dev/null +++ b/test/test_main.py @@ -0,0 +1,15 @@ +import asyncio +import httpx + +async def get_items(user_id): + async with httpx.AsyncClient() as client: + response = await client.get(f"http://localhost:8000/users/{user_id}/items") + print(f"User {user_id}: {response.json()}") + +async def main(): + await asyncio.gather( + get_items(1), + get_items(2), + ) + +asyncio.run(main())