Skip to content

Commit

Permalink
docs: sessioner example update
Browse files Browse the repository at this point in the history
Co-authored-by: Ghaith Kdimati <[email protected]>"
  • Loading branch information
Ofahmy143 committed Sep 29, 2024
1 parent 2cde372 commit 223ec46
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 21 deletions.
49 changes: 44 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,14 @@ which takes two arguments:
for which you have

```python
# Mock context model
from sqlalchemy.orm import sessionmaker
from rls.rls_session import RlsSession
from rls.rls_sessioner import RlsSessioner, ContextGetter
from pydantic import BaseModel
from test.engines import sync_engine as engine
from sqlalchemy import text


class ExampleContext(BaseModel):
account_id: int
provider_id: int
Expand All @@ -201,9 +208,8 @@ class ExampleContext(BaseModel):
# Concrete implementation of ContextGetter
class ExampleContextGetter(ContextGetter):
def get_context(self, *args, **kwargs) -> ExampleContext:
account_id = kwargs.get('acount_id')
provider_id = kwargs.get('provider_id')

account_id = kwargs.get("account_id", 1)
provider_id = kwargs.get("provider_id", 2)
return ExampleContext(account_id=account_id, provider_id=provider_id)


Expand All @@ -213,10 +219,18 @@ session_maker = sessionmaker(
class_=RlsSession, autoflush=False, autocommit=False, bind=engine
)

my_sessioner = RlsSessioner(sessionmaker=session_maker, context_getter=my_context)


with RlsSessioner(sessionmaker=session_maker, context_getter=my_context)() as session:

with my_sessioner(account_id=22, provider_id=99) as session:
res = session.execute(text("SELECT * FROM users")).fetchall()
print(res) # output: List of users with account_id = 22 and provider_id = 99


with my_sessioner(account_id=11, provider_id=44) as session:
res = session.execute(text("SELECT * FROM users")).fetchall()
print(res) # output: List of users with account_id = 11 and provider_id = 44
```

---
Expand All @@ -230,9 +244,34 @@ if you are trying to use the `RlsSessioner` with fastapi you may face some diffi
```python

from rls.rls_sessioner import fastapi_dependency_function
from fastapi import Request

app = FastAPI()

class ExampleContext(BaseModel):
account_id: int
provider_id: int


# Concrete implementation of ContextGetter
class ExampleContextGetter(ContextGetter):
def get_context(self, *args, **kwargs) -> ExampleContext:
request: Request = kwargs.get('request')

account_id = request.headers.get('account_id')
provider_id = request.headers.get('provider_id')

return ExampleContext(account_id=account_id, provider_id=provider_id)


my_context = ExampleContextGetter()

session_maker = sessionmaker(
class_=RlsSession, autoflush=False, autocommit=False, bind=engine
)



rls_sessioner = RlsSessioner(sessionmaker=session_maker, context_getter=my_context)
my_session = Depends(fastapi_dependency_function(rls_sessioner))

Expand Down
12 changes: 0 additions & 12 deletions delete-policies.sql

This file was deleted.

16 changes: 12 additions & 4 deletions test/test_sessioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from sqlalchemy import text


# Mock context model
class ExampleContext(BaseModel):
account_id: int
provider_id: int
Expand All @@ -15,7 +14,9 @@ class ExampleContext(BaseModel):
# Concrete implementation of ContextGetter
class ExampleContextGetter(ContextGetter):
def get_context(self, *args, **kwargs) -> ExampleContext:
return ExampleContext(account_id=1, provider_id=2)
account_id = kwargs.get("account_id", 1)
provider_id = kwargs.get("provider_id", 2)
return ExampleContext(account_id=account_id, provider_id=provider_id)


my_context = ExampleContextGetter()
Expand All @@ -24,7 +25,14 @@ def get_context(self, *args, **kwargs) -> ExampleContext:
class_=RlsSession, autoflush=False, autocommit=False, bind=engine
)

my_sessioner = RlsSessioner(sessionmaker=session_maker, context_getter=my_context)

with RlsSessioner(sessionmaker=session_maker, context_getter=my_context)() as session:

with my_sessioner(account_id=22, provider_id=99) as session:
res = session.execute(text("SELECT * FROM users")).fetchall()
print(res) # output: List of users with account_id = 22 and provider_id = 99


with my_sessioner(account_id=11, provider_id=44) as session:
res = session.execute(text("SELECT * FROM users")).fetchall()
print(res)
print(res) # output: List of users with account_id = 11 and provider_id = 44

0 comments on commit 223ec46

Please sign in to comment.