Skip to content

Commit

Permalink
Merge branch 'development' into ft-dynamodb-extended
Browse files Browse the repository at this point in the history
  • Loading branch information
MauriceBrg authored Apr 18, 2024
2 parents 0001bb3 + 95663ab commit dcc0de2
Show file tree
Hide file tree
Showing 14 changed files with 351 additions and 5 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ jobs:
image: amazon/dynamodb-local
ports:
- 8000:8000

postgresql:
image: postgres:latest
ports:
- 5433:5432
env:
POSTGRES_PASSWORD: pwd
POSTGRES_USER: root
POSTGRES_DB: dummy
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- uses: actions/checkout@v4
- uses: supercharge/[email protected]
Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## Contributors

- [giuppep](https://github.com/giuppep)
- [MauriceBrg](https://github.com/MauriceBrg)
- [eiriklid](https://github.com/eiriklid)
- [necat1](https://github.com/necat1)
Expand Down
11 changes: 11 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ services:
ports:
- "11211:11211"

postgres:
image: postgres:latest
environment:
- POSTGRES_USER=root
- POSTGRES_PASSWORD=pwd
- POSTGRES_DB=dummy
ports:
- "5433:5432"
volumes:
- postgres_data:/var/lib/postgresql/data

volumes:
postgres_data:
mongo_data:
Expand Down
3 changes: 2 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ Anything documented here is part of the public API that Flask-Session provides,
.. autoclass:: flask_session.cachelib.CacheLibSessionInterface
.. autoclass:: flask_session.mongodb.MongoDBSessionInterface
.. autoclass:: flask_session.sqlalchemy.SqlAlchemySessionInterface
.. autoclass:: flask_session.dynamodb.DynamoDBSessionInterface
.. autoclass:: flask_session.dynamodb.DynamoDBSessionInterface
.. autoclass:: flask_session.postgresql.PostgreSqlSessionInterface
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,5 @@ dev-dependencies = [
"boto3>=1.34.68",
"mypy_boto3_dynamodb>=1.34.67",
"pymemcache>=4.0.0",
"psycopg2-binary>=2",
]
1 change: 1 addition & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ Flask-SQLAlchemy
pymongo
boto3
mypy_boto3_dynamodb
psycopg2-binary

3 changes: 2 additions & 1 deletion requirements/docs.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ pymongo
flask_sqlalchemy
pymemcache
boto3
mypy_boto3_dynamodb
mypy_boto3_dynamodb
psycopg2-binary
6 changes: 6 additions & 0 deletions requirements/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#
alabaster==0.7.13
# via sphinx
async-timeout==4.0.3
# via redis
babel==2.12.1
# via sphinx
beautifulsoup4==4.12.3
Expand Down Expand Up @@ -36,6 +38,8 @@ flask-sqlalchemy==3.1.1
# via -r requirements/docs.in
furo==2024.1.29
# via -r requirements/docs.in
greenlet==3.0.3
# via sqlalchemy
idna==3.4
# via requests
imagesize==1.4.1
Expand All @@ -58,6 +62,8 @@ mypy-boto3-dynamodb==1.34.67
# via -r requirements/docs.in
packaging==23.1
# via sphinx
psycopg2-binary==2.9.9
# via -r requirements/docs.in
pygments==2.15.1
# via
# furo
Expand Down
30 changes: 27 additions & 3 deletions src/flask_session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ def _get_interface(self, app):
SESSION_SQLALCHEMY_BIND_KEY = config.get(
"SESSION_SQLALCHEMY_BIND_KEY", Defaults.SESSION_SQLALCHEMY_BIND_KEY
)
SESSION_CLEANUP_N_REQUESTS = config.get(
"SESSION_CLEANUP_N_REQUESTS", Defaults.SESSION_CLEANUP_N_REQUESTS
)

# DynamoDB settings
SESSION_DYNAMODB = config.get("SESSION_DYNAMODB", Defaults.SESSION_DYNAMODB)
Expand All @@ -113,6 +110,22 @@ def _get_interface(self, app):
"SESSION_DYNAMODB_TABLE_EXISTS", Defaults.SESSION_DYNAMODB_TABLE_EXISTS
)

# PostgreSQL settings
SESSION_POSTGRESQL = config.get(
"SESSION_POSTGRESQL", Defaults.SESSION_POSTGRESQL
)
SESSION_POSTGRESQL_TABLE = config.get(
"SESSION_POSTGRESQL_TABLE", Defaults.SESSION_POSTGRESQL_TABLE
)
SESSION_POSTGRESQL_SCHEMA = config.get(
"SESSION_POSTGRESQL_SCHEMA", Defaults.SESSION_POSTGRESQL_SCHEMA
)

# Shared settings
SESSION_CLEANUP_N_REQUESTS = config.get(
"SESSION_CLEANUP_N_REQUESTS", Defaults.SESSION_CLEANUP_N_REQUESTS
)

common_params = {
"app": app,
"key_prefix": SESSION_KEY_PREFIX,
Expand Down Expand Up @@ -184,6 +197,17 @@ def _get_interface(self, app):
table_exists=SESSION_DYNAMODB_TABLE_EXISTS,
)

elif SESSION_TYPE == "postgresql":
from .postgresql import PostgreSqlSessionInterface

session_interface = PostgreSqlSessionInterface(
**common_params,
pool=SESSION_POSTGRESQL,
table=SESSION_POSTGRESQL_TABLE,
schema=SESSION_POSTGRESQL_SCHEMA,
cleanup_n_requests=SESSION_CLEANUP_N_REQUESTS,
)

else:
raise ValueError(f"Unrecognized value for SESSION_TYPE: {SESSION_TYPE}")

Expand Down
5 changes: 5 additions & 0 deletions src/flask_session/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ class Defaults:
SESSION_DYNAMODB = None
SESSION_DYNAMODB_TABLE = "Sessions"
SESSION_DYNAMODB_TABLE_EXISTS = False

# PostgreSQL settings
SESSION_POSTGRESQL = None
SESSION_POSTGRESQL_TABLE = "flask_sessions"
SESSION_POSTGRESQL_SCHEMA = "public"
1 change: 1 addition & 0 deletions src/flask_session/postgresql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .postgresql import PostgreSqlSession, PostgreSqlSessionInterface # noqa: F401
84 changes: 84 additions & 0 deletions src/flask_session/postgresql/_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from psycopg2 import sql


class Queries:
def __init__(self, schema: str, table: str) -> None:
"""Class to hold all the queries used by the session interface.
Args:
schema (str): The name of the schema to use for the session data.
table (str): The name of the table to use for the session data.
"""
self.schema = schema
self.table = table

@property
def create_schema(self) -> str:
return sql.SQL("CREATE SCHEMA IF NOT EXISTS {schema};").format(
schema=sql.Identifier(self.schema)
)

@property
def create_table(self) -> str:
uq_idx = sql.Identifier(f"uq_{self.table}_session_id")
expiry_idx = sql.Identifier(f"{self.table}_expiry_idx")
return sql.SQL(
"""CREATE TABLE IF NOT EXISTS {schema}.{table} (
session_id VARCHAR(255) NOT NULL PRIMARY KEY,
created TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'utc'),
data BYTEA,
expiry TIMESTAMP WITHOUT TIME ZONE
);
--- Unique session_id
CREATE UNIQUE INDEX IF NOT EXISTS
{uq_idx} ON {schema}.{table} (session_id);
--- Index for expiry timestamp
CREATE INDEX IF NOT EXISTS
{expiry_idx} ON {schema}.{table} (expiry);"""
).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table),
uq_idx=uq_idx,
expiry_idx=expiry_idx,
)

@property
def retrieve_session_data(self) -> str:
return sql.SQL(
"""--- If the current sessions is expired, delete it
DELETE FROM {schema}.{table}
WHERE session_id = %(session_id)s AND expiry < NOW();
--- Else retrieve it
SELECT data FROM {schema}.{table} WHERE session_id = %(session_id)s;
"""
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table))

@property
def upsert_session(self) -> str:
return sql.SQL(
"""INSERT INTO {schema}.{table} (session_id, data, expiry)
VALUES (%(session_id)s, %(data)s, NOW() + %(ttl)s)
ON CONFLICT (session_id)
DO UPDATE SET data = %(data)s, expiry = NOW() + %(ttl)s;
"""
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table))

@property
def delete_expired_sessions(self) -> str:
return sql.SQL("DELETE FROM {schema}.{table} WHERE expiry < NOW();").format(
schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)
)

@property
def delete_session(self) -> str:
return sql.SQL(
"DELETE FROM {schema}.{table} WHERE session_id = %(session_id)s;"
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table))

@property
def drop_sessions_table(self) -> str:
return sql.SQL("DROP TABLE IF EXISTS {schema}.{table};").format(
schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)
)
145 changes: 145 additions & 0 deletions src/flask_session/postgresql/postgresql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

from contextlib import contextmanager
from datetime import timedelta as TimeDelta
from typing import Generator, Optional

from flask import Flask
from itsdangerous import want_bytes
from psycopg2.extensions import connection as PsycoPg2Connection
from psycopg2.extensions import cursor as PsycoPg2Cursor
from psycopg2.pool import ThreadedConnectionPool

from .._utils import retry_query
from ..base import ServerSideSession, ServerSideSessionInterface
from ..defaults import Defaults
from ._queries import Queries


class PostgreSqlSession(ServerSideSession):
pass


class PostgreSqlSessionInterface(ServerSideSessionInterface):
"""A Session interface that uses PostgreSQL as a session storage. (`psycopg2` required)
:param pool: A ``psycopg2.pool.ThreadedConnectionPool`` instance.
:param key_prefix: A prefix that is added to all storage keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
:param serialization_format: The serialization format to use for the session data.
:param table: The table name you want to use.
:param schema: The db schema to use.
:param cleanup_n_requests: Delete expired sessions on average every N requests.
"""

session_class = PostgreSqlSession
ttl = False

def __init__(
self,
app: Flask,
pool: Optional[ThreadedConnectionPool] = Defaults.SESSION_POSTGRESQL,
key_prefix: str = Defaults.SESSION_KEY_PREFIX,
use_signer: bool = Defaults.SESSION_USE_SIGNER,
permanent: bool = Defaults.SESSION_PERMANENT,
sid_length: int = Defaults.SESSION_ID_LENGTH,
serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT,
table: str = Defaults.SESSION_POSTGRESQL_TABLE,
schema: str = Defaults.SESSION_POSTGRESQL_SCHEMA,
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
) -> None:
if not isinstance(pool, ThreadedConnectionPool):
raise TypeError("No valid ThreadedConnectionPool instance provided.")

self.pool = pool

self._table = table
self._schema = schema

self._queries = Queries(schema=self._schema, table=self._table)

self._create_schema_and_table()

super().__init__(
app,
key_prefix,
use_signer,
permanent,
sid_length,
serialization_format,
cleanup_n_requests,
)

@contextmanager
def _get_cursor(
self, conn: Optional[PsycoPg2Connection] = None
) -> Generator[PsycoPg2Cursor, None, None]:
_conn: PsycoPg2Connection = conn or self.pool.getconn()

assert isinstance(_conn, PsycoPg2Connection)
try:
with _conn, _conn.cursor() as cur:
yield cur
except Exception:
raise
finally:
self.pool.putconn(_conn)

@retry_query(max_attempts=3)
def _create_schema_and_table(self) -> None:
with self._get_cursor() as cur:
cur.execute(self._queries.create_schema)
cur.execute(self._queries.create_table)

def _delete_expired_sessions(self) -> None:
"""Delete all expired sessions from the database."""
with self._get_cursor() as cur:
cur.execute(self._queries.delete_expired_sessions)

@retry_query(max_attempts=3)
def _delete_session(self, store_id: str) -> None:
with self._get_cursor() as cur:
cur.execute(
self._queries.delete_session,
dict(session_id=store_id),
)

@retry_query(max_attempts=3)
def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
with self._get_cursor() as cur:
cur.execute(
self._queries.retrieve_session_data,
dict(session_id=store_id),
)
session_data = cur.fetchone()

if session_data is not None:
serialized_session_data = want_bytes(session_data[0])
return self.serializer.loads(serialized_session_data)
return None

@retry_query(max_attempts=3)
def _upsert_session(
self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str
) -> None:

serialized_session_data = self.serializer.dumps(session)

if session.sid is not None:
assert session.sid == store_id.removeprefix(self.key_prefix)

with self._get_cursor() as cur:
cur.execute(
self._queries.upsert_session,
dict(
session_id=store_id,
data=serialized_session_data,
ttl=session_lifetime,
),
)

def _drop_table(self) -> None:
with self._get_cursor() as cur:
cur.execute(self._queries.drop_sessions_table)
Loading

0 comments on commit dcc0de2

Please sign in to comment.