diff --git a/app.py b/app.py index 97093c3..f99bd05 100644 --- a/app.py +++ b/app.py @@ -118,7 +118,7 @@ async def get_user( if session is None: request.state.user = None else: - user_logic = UserLogic(mongo, request.app.state.session) + user_logic = UserLogic(request.app.state.session) user = await user_logic.get_user(session["user_email"]) if user is not None: request.state.user = user diff --git a/logic/auth_logic.py b/logic/auth_logic.py index bc8d687..c888651 100644 --- a/logic/auth_logic.py +++ b/logic/auth_logic.py @@ -24,7 +24,7 @@ def __init__( session: Session, auth_client_id: str, ) -> None: - self.user_logic = UserLogic(mongo, session) + self.user_logic = UserLogic(session) self.sessions = mongo.sessions self.auth_client_id = auth_client_id diff --git a/logic/user_logic.py b/logic/user_logic.py index 900d111..07fb390 100644 --- a/logic/user_logic.py +++ b/logic/user_logic.py @@ -3,12 +3,12 @@ import datetime import hashlib -import sys from typing import TYPE_CHECKING from dateutil.relativedelta import relativedelta from sqlalchemy import func from sqlmodel import asc +from sqlmodel import col from sqlmodel import select from common import config @@ -17,11 +17,9 @@ from common.session import hs_transaction if TYPE_CHECKING: - from typing import Any from typing import Awaitable from typing import Callable - import motor.core from sqlmodel import Session from sqlmodel.sql.expression import SelectOfScalar @@ -35,10 +33,7 @@ class UserLogic: SUPER_ADMIN, ) - def __init__( - self, mongo: motor.core.AgnosticDatabase[Any], session: Session - ) -> None: - self.mongo = mongo + def __init__(self, session: Session) -> None: self.session = session async def create_user(self, user_info: dict[str, str]) -> tables.User: @@ -53,7 +48,6 @@ async def create_user(self, user_info: dict[str, str]) -> tables.User: "family_name": user_info.get("family_name", ""), "first_login": datetime.datetime.utcnow(), } - await self.mongo.users.insert_one(user) with hs_transaction(self.session, expire_on_commit=False) as session: db_user = tables.User(**user) session.add(db_user) @@ -124,12 +118,10 @@ def _get_subscription_duration( class UserHistoryLogic: def __init__( self, - mongo: motor.core.AgnosticDatabase[Any], session: Session, user: tables.User, date: datetime.date, ): - self.mongo = mongo self.session = session self.user = user self.dt = date # TODO: use this @@ -161,101 +153,98 @@ async def update_and_get_history( solver_count=guess.solver_count, ) ) - return await self._fix_history(history, update_db=True) + return history else: return [guess] + history async def get_history(self) -> list[schemas.DistanceResponse]: - user_data = await self.mongo.users.find_one( - self.user_filter, projection=self.projection - ) - if user_data is None: - raise ValueError("User not found") # TODO: use our own error - raw_history = user_data.get("history", []) - history = [] - guesses = set() - for document in raw_history: - historia = schemas.DistanceResponse(**document) - if historia.guess not in guesses: - guesses.add(historia.guess) - history.append(historia) - return await self._fix_history( - history, update_db=len(history) != len(raw_history) - ) - - async def _fix_history( - self, history: list[schemas.DistanceResponse], update_db: bool - ) -> list[schemas.DistanceResponse]: - for i, historia in enumerate(history, start=1): - historia.guess_number = i - if update_db: - # fix duplicates - await self.mongo.users.update_one( - self.user_filter, - { - "$set": { - f"history.{self.date}": [ - historia.model_dump() for historia in history - ] - } - }, + with hs_transaction(self.session, expire_on_commit=False) as session: + history_query = select(tables.UserHistory) + history_query = history_query.where( + tables.UserHistory.user_id == self.user.id + ) + history_query = history_query.where( + tables.UserHistory.game_date == self.date ) - return history + history_query = history_query.order_by(col(tables.UserHistory.id)) + history = session.exec(history_query).all() + return [ + schemas.DistanceResponse( + guess=historia.guess, + similarity=historia.similarity, + distance=historia.distance, + egg=historia.egg, + solver_count=historia.solver_count, + guess_number=i, + ) + for i, historia in enumerate(history, start=1) + ] class UserStatisticsLogic: - def __init__(self, mongo: motor.core.AgnosticDatabase[Any], user: tables.User): - self.mongo = mongo + def __init__(self, session: Session, user: tables.User): + self.session = session self.user = user async def get_statistics(self) -> schemas.UserStatistics: - # TODO: for now this is good enough, but we can do it with aggregation. - # we should probably change the way we save the data - instead of having - # saving history as an object with dates as its members, it should consist of - # list with dates as its members. - user = await self.mongo.users.find_one( - {"email": self.user.email}, projection={"history": 1} + stats_subquery = select( + tables.UserHistory.similarity, + tables.UserHistory.solver_count, + func.row_number() + .over( + partition_by=[col(tables.UserHistory.game_date)], + order_by=col(tables.UserHistory.id), + ) + .label("guess_number"), ) - if user is None: - raise ValueError("User not found") # TODO: use our own error - user_history = user.get("history", {}) - user_history = { - date: [schemas.DistanceResponse(**guess) for guess in history] - for date, history in user_history.items() - if history - } + stats_subquery = stats_subquery.select_from(tables.UserHistory) + stats_subquery = stats_subquery.where( + tables.UserHistory.user_id == self.user.id + ) + stats_sub = stats_subquery.subquery() + stats_query = select( + func.count(), + func.min(stats_sub.c.solver_count), + func.avg(stats_sub.c.guess_number), + ) + stats_query = stats_query.select_from(stats_sub) + stats_query = stats_query.where(stats_sub.c.similarity == 100) - game_streak = self._get_game_streak(user_history.keys()) + with hs_transaction(self.session, expire_on_commit=False) as session: + stats = session.exec(stats_query).one_or_none() - highest_rank = None - total_games_won = 0 - total_guesses = 0 - for historia in user_history.values(): - for guess in historia: - if guess.similarity == 100: - total_games_won += 1 - highest_rank = min(highest_rank or sys.maxsize, guess.solver_count) - total_guesses += guess.guess_number + if stats is None: + total_games_won, highest_rank, avg_guesses = 0, None, 0 + else: + total_games_won, highest_rank, avg_guesses = stats + + game_streak, total_games_played = self._get_game_streak_and_total() return schemas.UserStatistics( game_streak=game_streak, highest_rank=highest_rank, - total_games_played=len(user_history), + total_games_played=total_games_played, total_games_won=total_games_won, - average_guesses=total_guesses / total_games_won if total_games_won else 0, + average_guesses=avg_guesses, ) - def _get_game_streak(self, game_dates: list[str]) -> int: - game_dates = sorted(game_dates, reverse=True) + def _get_game_streak_and_total(self) -> tuple[int, int]: + dates_query = select(col(tables.UserHistory.game_date)) + dates_query = dates_query.where(tables.UserHistory.user_id == self.user.id) + dates_query = dates_query.group_by(col(tables.UserHistory.game_date)) + dates_query = dates_query.order_by(col(tables.UserHistory.game_date).desc()) + with hs_transaction(session=self.session, expire_on_commit=False) as session: + game_dates = session.exec(dates_query).all() + date = datetime.datetime.utcnow().date() game_streak = 0 for game_date in game_dates: - if str(date) == game_date: + if date == game_date: game_streak += 1 date -= datetime.timedelta(days=1) else: break - return game_streak + return game_streak, len(game_dates) class UserClueLogic: diff --git a/routers/game_routes.py b/routers/game_routes.py index 5794343..ff2dc78 100644 --- a/routers/game_routes.py +++ b/routers/game_routes.py @@ -44,7 +44,6 @@ async def distance( if request.headers.get("x-sh-version", "2022-02-20") >= "2023-09-10": if request.state.user: history_logic = UserHistoryLogic( - request.app.state.mongo, request.app.state.session, request.state.user, get_date(request.app.state.days_delta), diff --git a/routers/pages_routes.py b/routers/pages_routes.py index 241cfa5..7d402d1 100644 --- a/routers/pages_routes.py +++ b/routers/pages_routes.py @@ -45,7 +45,6 @@ async def index(request: Request) -> Response: if request.state.user: history_logic = UserHistoryLogic( - request.app.state.mongo, request.app.state.session, request.state.user, get_date(request.app.state.days_delta), @@ -147,7 +146,7 @@ async def menu(request: Request) -> Response: @pages_router.get("/statistics", response_class=HTMLResponse, include_in_schema=False) async def get_statistics(request: Request) -> Response: if request.state.user is not None: - logic = UserStatisticsLogic(request.app.state.mongo, request.state.user) + logic = UserStatisticsLogic(request.app.state.session, request.state.user) statistics = await logic.get_statistics() else: statistics = None diff --git a/routers/subscription_routes.py b/routers/subscription_routes.py index cea5543..955b777 100644 --- a/routers/subscription_routes.py +++ b/routers/subscription_routes.py @@ -30,7 +30,7 @@ async def subscribe(request: Request, data: Annotated[str, Form()]) -> dict[str, ) if not is_valid_token or not is_new_message: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - logic = UserLogic(mongo=request.app.state.mongo, session=request.app.state.session) + logic = UserLogic(session=request.app.state.session) success = await logic.subscribe(subscription) success_message = "Success :smile:" if success else "Failed :rage:" requests.post(