diff --git a/README.md b/README.md index 4781d58..fcbe8c8 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,14 @@ which takes two arguments: for which you have ```python -# Mock context model +from sqlalchemy.orm import sessionmaker +from rls.rls_session import RlsSession +from rls.rls_sessioner import RlsSessioner, ContextGetter +from pydantic import BaseModel +from test.engines import sync_engine as engine +from sqlalchemy import text + + class ExampleContext(BaseModel): account_id: int provider_id: int @@ -201,9 +208,8 @@ class ExampleContext(BaseModel): # Concrete implementation of ContextGetter class ExampleContextGetter(ContextGetter): def get_context(self, *args, **kwargs) -> ExampleContext: - account_id = kwargs.get('acount_id') - provider_id = kwargs.get('provider_id') - + account_id = kwargs.get("account_id", 1) + provider_id = kwargs.get("provider_id", 2) return ExampleContext(account_id=account_id, provider_id=provider_id) @@ -213,10 +219,18 @@ session_maker = sessionmaker( class_=RlsSession, autoflush=False, autocommit=False, bind=engine ) +my_sessioner = RlsSessioner(sessionmaker=session_maker, context_getter=my_context) + -with RlsSessioner(sessionmaker=session_maker, context_getter=my_context)() as session: + +with my_sessioner(account_id=22, provider_id=99) as session: res = session.execute(text("SELECT * FROM users")).fetchall() + print(res) # output: List of users with account_id = 22 and provider_id = 99 + +with my_sessioner(account_id=11, provider_id=44) as session: + res = session.execute(text("SELECT * FROM users")).fetchall() + print(res) # output: List of users with account_id = 11 and provider_id = 44 ``` --- @@ -230,9 +244,34 @@ if you are trying to use the `RlsSessioner` with fastapi you may face some diffi ```python from rls.rls_sessioner import fastapi_dependency_function +from fastapi import Request app = FastAPI() +class ExampleContext(BaseModel): + account_id: int + provider_id: int + + +# Concrete implementation of ContextGetter +class ExampleContextGetter(ContextGetter): + def get_context(self, *args, **kwargs) -> ExampleContext: + request: Request = kwargs.get('request') + + account_id = request.headers.get('account_id') + provider_id = request.headers.get('provider_id') + + return ExampleContext(account_id=account_id, provider_id=provider_id) + + +my_context = ExampleContextGetter() + +session_maker = sessionmaker( + class_=RlsSession, autoflush=False, autocommit=False, bind=engine +) + + + rls_sessioner = RlsSessioner(sessionmaker=session_maker, context_getter=my_context) my_session = Depends(fastapi_dependency_function(rls_sessioner)) diff --git a/delete-policies.sql b/delete-policies.sql deleted file mode 100644 index 9b2c0f3..0000000 --- a/delete-policies.sql +++ /dev/null @@ -1,12 +0,0 @@ -DO $$ -DECLARE - r RECORD; -BEGIN - FOR r IN - SELECT policyname - FROM pg_policies - WHERE tablename = 'items' - LOOP - EXECUTE 'DROP POLICY ' || quote_ident(r.policyname) || ' ON items'; - END LOOP; -END $$; diff --git a/test/test_sessioner.py b/test/test_sessioner.py index 30e581c..4c0bcf7 100644 --- a/test/test_sessioner.py +++ b/test/test_sessioner.py @@ -6,7 +6,6 @@ from sqlalchemy import text -# Mock context model class ExampleContext(BaseModel): account_id: int provider_id: int @@ -15,7 +14,9 @@ class ExampleContext(BaseModel): # Concrete implementation of ContextGetter class ExampleContextGetter(ContextGetter): def get_context(self, *args, **kwargs) -> ExampleContext: - return ExampleContext(account_id=1, provider_id=2) + account_id = kwargs.get("account_id", 1) + provider_id = kwargs.get("provider_id", 2) + return ExampleContext(account_id=account_id, provider_id=provider_id) my_context = ExampleContextGetter() @@ -24,7 +25,14 @@ def get_context(self, *args, **kwargs) -> ExampleContext: class_=RlsSession, autoflush=False, autocommit=False, bind=engine ) +my_sessioner = RlsSessioner(sessionmaker=session_maker, context_getter=my_context) -with RlsSessioner(sessionmaker=session_maker, context_getter=my_context)() as session: + +with my_sessioner(account_id=22, provider_id=99) as session: + res = session.execute(text("SELECT * FROM users")).fetchall() + print(res) # output: List of users with account_id = 22 and provider_id = 99 + + +with my_sessioner(account_id=11, provider_id=44) as session: res = session.execute(text("SELECT * FROM users")).fetchall() - print(res) + print(res) # output: List of users with account_id = 11 and provider_id = 44