Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ishefi committed Apr 17, 2024
1 parent b10330b commit 3075b22
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 34 deletions.
2 changes: 1 addition & 1 deletion logic/game_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def get_secret(self) -> str:

@staticmethod
@lru_cache
def _get_cached_secret(session: Session, date: datetime.date) -> str | None:
def _get_cached_secret(session: Session, date: datetime.date) -> str:
# TODO: this function is accessing db but is NOT ASYNC, which might be
# problematic if we choose to do async stuff with sql in the future.
# the reason for that is `@lru_cache` does not support async.
Expand Down
26 changes: 16 additions & 10 deletions mock/mock_db.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,44 @@
from __future__ import annotations

import sqlite3
from typing import TYPE_CHECKING

from sqlalchemy import event
from sqlalchemy import Engine
from sqlalchemy import event
from sqlmodel import Session
from sqlmodel import SQLModel
from sqlmodel import StaticPool
from sqlmodel import create_engine
from sqlmodel import Session

from common import tables


if TYPE_CHECKING:
from typing import Any
from typing import TypeVar
T = TypeVar("T", bound=tables.SQLModel)

T = TypeVar("T", bound=SQLModel)


def collation(string1, string2):
def collation(string1: str, string2: str) -> int:
if string1 == string2:
return 0
elif string1 > string2:
return 1
else:
return -1


@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, dummy_connection_record):
def set_sqlite_pragma(
dbapi_connection: sqlite3.Connection, dummy_connection_record: Any
) -> None:
dbapi_connection.create_collation("Hebrew_100_CI_AI_SC_UTF8", collation)
dbapi_connection.create_collation("Hebrew_CI_AI", collation)
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()


class MockDb:
def __init__(self):
def __init__(self) -> None:
self.db_uri = "sqlite:///:memory:?cache=shared"
self.engine = create_engine(
self.db_uri,
Expand All @@ -44,7 +50,7 @@ def __init__(self):
expire_on_commit=False,
autoflush=True,
)
tables.SQLModel.metadata.create_all(self.engine)
SQLModel.metadata.create_all(self.engine)

def add(self, entity: T) -> T:
self.session.begin()
Expand Down
35 changes: 34 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ types-redis = "^4.6.0.11"
ruff = "0.3.4"
alembic = "^1.13.1"
pytest = "^8.1.1"
pytest-sugar = "^1.0.0"

[tool.ruff]
fix = true
Expand Down
47 changes: 34 additions & 13 deletions routers/admin_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from sqlmodel import Session
from sqlmodel import select

from common import tables
from common.session import hs_transaction
from logic.game_logic import SecretLogic, CacheSecretLogic
from logic.game_logic import CacheSecretLogic
from logic.game_logic import SecretLogic
from model import GensimModel
from routers.base import render
from routers.base import super_admin
from sqlmodel import select
from common import tables

TOP_SAMPLE = 10000

Expand All @@ -25,49 +26,69 @@
async def index(request: Request) -> HTMLResponse:
model = request.app.state.model
secret_logic = SecretLogic(request.app.state.session)
all_secrets = await secret_logic.get_all_secrets(with_future=True)
potential_secrets = []
all_secrets = [
secret[0] for secret in await secret_logic.get_all_secrets(with_future=True)
]
potential_secrets: list[str] = []
while len(potential_secrets) < 45:
secret = await get_random_word(model) # todo: in batches
if secret not in all_secrets:
potential_secrets.append(secret)

return render(name="set_secret.html", request=request, potential_secrets=potential_secrets)
return render(
name="set_secret.html", request=request, potential_secrets=potential_secrets
)


@admin_router.get("/model", include_in_schema=False)
async def get_word_data(request: Request, word: str) -> dict[str, list[str] | datetime.date]:
async def get_word_data(
request: Request, word: str
) -> dict[str, list[str] | datetime.date]:
session = request.app.state.session
redis = request.app.state.redis
model = request.app.state.model
logic = CacheSecretLogic(session=session, redis=redis, secret=word, dt=await get_date(session), model=model)
logic = CacheSecretLogic(
session=session,
redis=redis,
secret=word,
dt=await get_date(session),
model=model,
)
await logic.simulate_set_secret(force=False)
cache = await logic.cache
return {
"date": logic.date_,
"data": cache[::-1],
}


class SetSecretRequest(BaseModel):
secret: str
clues: list[str]


@admin_router.post("/set-secret", include_in_schema=False)
async def set_new_secret(request: Request, set_secret: SetSecretRequest):
async def set_new_secret(request: Request, set_secret: SetSecretRequest) -> str:
session = request.app.state.session
redis = request.app.state.redis
model = request.app.state.model
logic = CacheSecretLogic(session=session, redis=redis, secret=set_secret.secret, dt=await get_date(session), model=model)
logic = CacheSecretLogic(
session=session,
redis=redis,
secret=set_secret.secret,
dt=await get_date(session),
model=model,
)
await logic.simulate_set_secret(force=False)
await logic.do_populate(set_secret.clues)
return f"Set '{set_secret.secret}' with clues '{set_secret.clues}' on {logic.date_}"



# TODO: everything below here should be in a separate file, and set_secret script should be updated to use it
async def get_random_word(model: GensimModel) -> str:
rand_index = random.randint(0, TOP_SAMPLE)
return model.model.index_to_key[rand_index]
word: str = model.model.index_to_key[rand_index]
return word


async def get_date(session: Session) -> datetime.date:
Expand All @@ -77,4 +98,4 @@ async def get_date(session: Session) -> datetime.date:
latest: datetime.date = s.exec(query).first()

dt = latest + datetime.timedelta(days=1)
return dt
return dt
8 changes: 5 additions & 3 deletions routers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates

from fastapi import Request
from logic.game_logic import CacheSecretLogic
from logic.game_logic import VectorLogic
from logic.user_logic import UserLogic
Expand All @@ -32,7 +32,9 @@ async def get_logics(
delta += app.state.days_delta
date = get_date(delta)
logic = VectorLogic(app.state.session, dt=date, model=app.state.model)
secret = await logic.secret_logic.get_secret() # TODO: raise a user-friendly exception
secret = (
await logic.secret_logic.get_secret()
) # TODO: raise a user-friendly exception
cache_logic = CacheSecretLogic(
app.state.session,
app.state.redis,
Expand Down
3 changes: 1 addition & 2 deletions scripts/semantle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import sys
from datetime import datetime


base = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.extend([base])

from common.session import get_session # noqa: E402
from common.session import get_model # noqa: E402
from common.session import get_redis # noqa: E402
from common.session import get_session # noqa: E402
from logic.game_logic import CacheSecretLogic # noqa: E402
from logic.game_logic import VectorLogic # noqa: E402

Expand Down
10 changes: 6 additions & 4 deletions tests/test_secret_logic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import datetime
import unittest

import pytest
from sqlmodel import Session

Expand All @@ -10,7 +11,7 @@


class TestGameLogic(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
async def asyncSetUp(self) -> None:
self.db = MockDb()
self.date = datetime.date(2021, 1, 1)
self.testee = SecretLogic(session=self.db.session, dt=self.date)
Expand All @@ -31,12 +32,13 @@ async def test_get_secret(self) -> None:
# assert
self.assertEqual(db_secret.word, secret)

async def test_get_secret__cache(self):
async def test_get_secret__cache(self) -> None:
# arrange
cached = self.db.add(tables.SecretWord(word="cached", game_date=self.date))
await self.testee.get_secret()
with Session(self.db.engine) as session:
db_secret = session.get(tables.SecretWord, cached.id)
assert db_secret is not None
db_secret.word = "not_cached"
session.add(db_secret)
session.commit()
Expand All @@ -47,7 +49,7 @@ async def test_get_secret__cache(self):
# assert
self.assertEqual("cached", secret)

async def test_get_secret__dont_cache_if_no_secret(self):
async def test_get_secret__dont_cache_if_no_secret(self) -> None:
# arrange
try:
await self.testee.get_secret()
Expand Down

0 comments on commit 3075b22

Please sign in to comment.