diff --git a/.env.test b/.env.test index 5947b513d..b5ab2cc3b 100644 --- a/.env.test +++ b/.env.test @@ -7,7 +7,7 @@ SQLITE_DB = "test.db" # If set, the application use a SQLite database instead of # Authorization using JWT # ACCESS_TOKEN_SECRET_KEY="YWZOHliiI53lJMJc5BI_WbGbA4GF2T7Wbt1airIhOXEa3c021c4-1c55-4182-b141-7778bcc8fac4" # Note: modifing this token requires to update the common `test_check_settings_mocking` test RSA_PRIVATE_PEM_STRING = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpQIBAAKCAQEA1tpj3TZDkJakp2RygsM392pQbcmNBOGFT8FlETcRG/JVFT7k\niClJu+CVOJSVD0epfpYp93cYepfw74SezYnBCyuoLJ2yg5Qh4KlCrWmvwM7vhFIN\nx0xddIQi+Gm0T3dxGtv4Ga50TYX4SV4FE3ctJG9m3pyNF6POODp5tMJvShQWYTto\nW9qNhltZ8Z+14bq2INV/efpT47WuMT+VD/fa9/WwopAtgBcQOvq57fv5+DaPOIVR\n9BiP7F+pv+v6wQ373hI22QzCMsA4Whl+BmWFKcFoBDOBRjlW5VqhJWJkWZIRP0q+\nVAZHk2xJK+0YFc9jmaC+ExMtuyHYK0RnQK/8LQIDAQABAoIBABxJ8v4sZ+cAvrs/\nkYhAFf1gpShfck7jNr9SknEa1Aje9m7usf5vmULAhkVF4v55DAsb0HjB2JpDqTiQ\nOKyNZ7qFzAXb2aZTecZv4tScZsS3OngsqZ3FI0T1JPmaSWBxNJY5wkf3XV7btd5L\nH9X5ShtTA7Np33XuXneu01mGhEq3boLro+vfXMHV5QHyle1F4LUFWEqtP0UmZ5wA\nrro0Y7pA8R88tu5X4iWEjQPnAsbRixwFQ9LNMD8+40e1UIguobRySnP5umErHaIh\nKui7ZijLjbZh/dPS0IfpgahL1K6s9XhT3mD9WMvAvMkNtLewHIZZukG45mOQBrjF\nvvyYxoECgYEA+EY6YimGw0IKnUuf+5uZRXST7kDMENz1Flkcj8oZvo47hdX8/lDN\ni0y7gm3VNfHAK2R2KZPmSbtXA0DvS7kmx1/CFcmwkaakhuU5dyCHldWwSaTME3IE\nxjSZfTvlAiq9i6nUflgfkKo3Bdsiq8TYOUAv25S2SwYDH9Tx0fQwwGECgYEA3Ynt\nCHc8e4YRlGT65UQmEZ8cptmqVRyY4ClMU1xht7Pn0G1JwKRraiEL5/LndwscWf3h\nDygQuArJ28pp4d22FEW1LeXozXYUjJoz3anIA45IZ1OihS7Cx7tJB51/QNJeFdF4\nEX/XHaVukHyYSsAxkwCUYOw3cSgZOSEddL5Wf00CgYEA7JlIlDmMwtFR+jqSmJ3c\n//Kr8zZvAnb/Xa/IZ0MrK4yyLsYR1m48o06Ztx9iO4lKIFAZx1+563QL5P7hzOEC\nkqev90GA8hzD2AXksKEgdOrymAvjq3hSEm0YBN+qS1ldzxYmec0TL7L2wq7lqJnr\nkQuZUAG1g2OUYKZ3WSUDvKECgYEAv24NSkFuG/avfiD7w9xtYNCye2KekskROLG2\n6FltfsWQTEQDdNkekChaF2WHqRAKwaBlNymRuNZpsuhnMerZCQ9rDWwbDF86RnyA\n0MuCr7/kxJQ6XQcY/GnTIydu7F5bOlM0gzqKcW2f6m4fUohczf+0N0QmbDsQAJOi\n1lwadgkCgYEA3tkCBJIPTQecfjWiLqSocS6SrwXU+r3Jw6kI3/IB6ban/nsFdHSb\nnADST7f2zZatN6XALwsLU7f2R09R39ub0AJPyfToxo7MngR1rvaUYooF3rLlaU32\n8DqGvGpLkZkwbtcDmcX1zQoHjUo7RvoShZoapr59ihfrkiiEsXOkuGw=\n-----END RSA PRIVATE KEY-----\n" -AUTH_CLIENTS=[["AppAuthClientWithPKCE", null, ["http://127.0.0.1:8000/docs"], "AppAuthClient"], ["AppAuthClientWithClientSecret", "secret", ["http://127.0.0.1:8000/docs"], "AppAuthClient"], ["BaseAuthClient", "secret", ["http://127.0.0.1:8000/docs"], "BaseAuthClient"], ["RalllyAuthClient", "secret", ["http://127.0.0.1:8000/docs"], "RalllyAuthClient"], ["SynapseAuthClient", "secret", ["http://127.0.0.1:8000/docs"], "SynapseAuthClient"], ["AcceptingOnlyECLUsersAuthClient", "secret", ["http://127.0.0.1:8000/docs"], "NextcloudAuthClient"]] +AUTH_CLIENTS=[["AppAuthClientWithPKCE", null, ["http://127.0.0.1:8000/docs"], "AppAuthClient"], ["AppAuthClientWithClientSecret", "secret", ["http://127.0.0.1:8000/docs"], "AppAuthClient"], ["BaseAuthClient", "secret", ["http://127.0.0.1:8000/docs"], "BaseAuthClient"], ["RalllyAuthClient", "secret", ["http://127.0.0.1:8000/docs"], "RalllyAuthClient"], ["SynapseAuthClient", "secret", ["http://127.0.0.1:8000/docs"], "SynapseAuthClient"], ["AcceptingOnlyECLUsersAuthClient", "secret", ["http://127.0.0.1:8000/docs"], "NextcloudAuthClient"], ["RestrictingUsersGroupsAuthClient", "secret", ["http://127.0.0.1:8000/docs"], "DocumensoAuthClient"]] # OIDC # # Host or url of the API, used for Openid connect discovery endpoint diff --git a/app/api.py b/app/api.py index 7ea2bdd32..46170ff86 100644 --- a/app/api.py +++ b/app/api.py @@ -10,6 +10,7 @@ from app.core.groups import endpoints_groups from app.core.notification import endpoints_notification from app.core.payment import endpoints_payment +from app.core.schools import endpoints_schools from app.core.users import endpoints_users from app.modules.module_list import module_list @@ -24,6 +25,7 @@ api_router.include_router(endpoints_notification.router) api_router.include_router(endpoints_payment.router) api_router.include_router(endpoints_users.router) +api_router.include_router(endpoints_schools.router) for module in module_list: api_router.include_router(module.router) diff --git a/app/app.py b/app/app.py index 86bd00fc2..74f38b621 100644 --- a/app/app.py +++ b/app/app.py @@ -27,6 +27,7 @@ from app.core.google_api.google_api import GoogleAPI from app.core.groups.groups_type import GroupType from app.core.log import LogConfig +from app.core.schools.schools_type import SchoolType from app.dependencies import ( get_db, get_redis_client, @@ -185,40 +186,104 @@ def initialize_groups( ) -def initialize_module_visibility( +def initialize_schools( sync_engine: Engine, hyperion_error_logger: logging.Logger, ) -> None: - """Add the default module visibilities for Titan""" + """Add the necessary shools""" + hyperion_error_logger.info("Startup: Adding new groups to the database") with Session(sync_engine) as db: - # Is run to create default module visibilities or when the table is empty - haveBeenInitialized = ( - len(initialization.get_all_module_visibility_membership_sync(db)) > 0 - ) - if haveBeenInitialized: - hyperion_error_logger.info( - "Startup: Modules visibility settings have already been initialized", - ) - return - - hyperion_error_logger.info( - "Startup: Modules visibility settings are empty, initializing them", - ) - for module in module_list: - for default_group_id in module.default_allowed_groups_ids: - module_visibility = models_core.ModuleVisibility( - root=module.root, - allowed_group_id=default_group_id.value, + for school in SchoolType: + exists = initialization.get_school_by_id_sync(school_id=school, db=db) + # We don't want to recreate the groups if they already exist + if not exists: + db_school = models_core.CoreSchool( + id=school, + name=school.name, + email_regex=".*", ) + try: - initialization.create_module_visibility_sync(module_visibility, db) + initialization.create_school_sync(school=db_school, db=db) except IntegrityError as error: hyperion_error_logger.fatal( - f"Startup: Could not add module visibility {module.root}<{default_group_id}> in the database: {error}", + f"Startup: Could not add group {db_school.name}<{db_school.id}> in the database: {error}", ) +def initialize_module_visibility( + sync_engine: Engine, + hyperion_error_logger: logging.Logger, +) -> None: + """Add the default module visibilities for Titan""" + + with Session(sync_engine) as db: + module_awareness = initialization.get_core_data_sync( + "module_awareness", + db, + ) + awareness_list = [] + if module_awareness is not None: + awareness_list = module_awareness.data.split(",") + + # Is run to create default module visibilities or when the table is empty + if awareness_list != [module.root for module in module_list]: + hyperion_error_logger.info( + "Startup: Some modules visibility settings are empty, initializing them", + ) + to_add = [ + module for module in module_list if module.root not in awareness_list + ] + for module in to_add: + if module.default_allowed_groups_ids is not None: + for group_id in module.default_allowed_groups_ids: + module_group_visibility = models_core.ModuleGroupVisibility( + root=module.root, + allowed_group_id=group_id, + ) + try: + initialization.create_module_group_visibility_sync( + module_visibility=module_group_visibility, + db=db, + ) + except ValueError as error: + hyperion_error_logger.fatal( + f"Startup: Could not add module visibility {module.root} in the database: {error}", + ) + if module.default_allowed_account_types is not None: + for account_type in module.default_allowed_account_types: + module_account_type_visibility = ( + models_core.ModuleAccountTypeVisibility( + root=module.root, + allowed_account_type=account_type, + ) + ) + try: + initialization.create_module_account_type_visibility_sync( + module_visibility=module_account_type_visibility, + db=db, + ) + except ValueError as error: + hyperion_error_logger.fatal( + f"Startup: Could not add module visibility {module.root} in the database: {error}", + ) + initialization.set_core_data_sync( + models_core.CoreData( + "module_awareness", + ",".join([module.root for module in module_list]), + ), + db, + ) + hyperion_error_logger.info( + f"Startup: Modules visibility settings initialized for {[module.root for module in to_add]}", + ) + else: + hyperion_error_logger.info( + "Startup: Modules visibility settings already initialized", + ) + + def use_route_path_as_operation_ids(app: FastAPI) -> None: """ Simplify operation IDs so that generated API clients have simpler function names. @@ -261,6 +326,10 @@ def init_db( sync_engine=sync_engine, hyperion_error_logger=hyperion_error_logger, ) + initialize_schools( + sync_engine=sync_engine, + hyperion_error_logger=hyperion_error_logger, + ) initialize_module_visibility( sync_engine=sync_engine, hyperion_error_logger=hyperion_error_logger, diff --git a/app/core/auth/endpoints_auth.py b/app/core/auth/endpoints_auth.py index f8d628eaa..e2d0ee571 100644 --- a/app/core/auth/endpoints_auth.py +++ b/app/core/auth/endpoints_auth.py @@ -41,7 +41,7 @@ from app.types.exceptions import AuthHTTPException from app.types.scopes_type import ScopeType from app.utils.auth.providers import BaseAuthClient -from app.utils.tools import is_user_external, is_user_member_of_an_allowed_group +from app.utils.tools import is_user_member_of_an_allowed_group router = APIRouter(tags=["Auth"]) @@ -310,7 +310,7 @@ async def authorize_validation( if auth_client.allowed_groups is not None: if not is_user_member_of_an_allowed_group( user=user, - allowed_groups=auth_client.allowed_groups, + allowed_groups=auth_client.allowed_groups, # type: ignore ): hyperion_access_logger.warning( f"Authorize-validation: user is not member of an allowed group {authorizereq.email} ({request_id})", @@ -322,19 +322,19 @@ async def authorize_validation( ), status_code=status.HTTP_302_FOUND, ) - if not auth_client.allow_external_users: - if is_user_external(user): + if auth_client.allowed_account_types is not None: + if user.account_type not in auth_client.allowed_account_types: + # TODO We should show an HTML page explaining the issue hyperion_access_logger.warning( - f"Authorize-validation: external users are disabled for this auth provider {auth_client.client_id} ({request_id})", + f"Authorize-validation: user account type is not allowed {authorizereq.email} ({request_id})", ) return RedirectResponse( settings.CLIENT_URL + calypsso.get_error_relative_url( - message="External users are not allowed", + message="User account type is not allowed", ), status_code=status.HTTP_302_FOUND, ) - # We generate a new authorization_code # The authorization code MUST expire # shortly after it is issued to mitigate the risk of leaks. A diff --git a/app/core/cruds_core.py b/app/core/cruds_core.py index ebe325ca6..5dfb769d3 100644 --- a/app/core/cruds_core.py +++ b/app/core/cruds_core.py @@ -5,31 +5,48 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core - - -async def get_all_module_visibility_membership( - db: AsyncSession, -): - """Return the every module with their visibility""" - result = await db.execute(select(models_core.ModuleVisibility)) - return result.unique().scalars().all() +from app.core.groups.groups_type import AccountType async def get_modules_by_user( user: models_core.CoreUser, db: AsyncSession, -) -> Sequence[str]: +) -> list[str]: """Return the modules a user has access to""" userGroupIds = [group.id for group in user.groups] - result = await db.execute( - select(models_core.ModuleVisibility.root) - .where(models_core.ModuleVisibility.allowed_group_id.in_(userGroupIds)) - .group_by(models_core.ModuleVisibility.root), + result_group = list( + ( + await db.execute( + select(models_core.ModuleGroupVisibility.root) + .where( + models_core.ModuleGroupVisibility.allowed_group_id.in_( + userGroupIds, + ), + ) + .group_by(models_core.ModuleGroupVisibility.root), + ) + ) + .unique() + .scalars() + .all(), + ) + result_account_type = list( + ( + await db.execute( + select(models_core.ModuleAccountTypeVisibility.root).where( + models_core.ModuleAccountTypeVisibility.allowed_account_type + == user.account_type, + ), + ) + ) + .unique() + .scalars() + .all(), ) - return result.unique().scalars().all() + return result_group + result_account_type async def get_allowed_groups_by_root( @@ -40,8 +57,8 @@ async def get_allowed_groups_by_root( result = await db.execute( select( - models_core.ModuleVisibility.allowed_group_id, - ).where(models_core.ModuleVisibility.root == root), + models_core.ModuleGroupVisibility.allowed_group_id, + ).where(models_core.ModuleGroupVisibility.root == root), ) resultList = result.unique().scalars().all() @@ -49,26 +66,27 @@ async def get_allowed_groups_by_root( return resultList -async def get_module_visibility( +async def get_allowed_account_types_by_root( root: str, - group_id: str, db: AsyncSession, -) -> models_core.ModuleVisibility | None: - """Return module visibility by root and group id""" +) -> Sequence[str]: + """Return the groups allowed to access to a specific root""" result = await db.execute( - select(models_core.ModuleVisibility).where( - models_core.ModuleVisibility.allowed_group_id == group_id, - models_core.ModuleVisibility.root == root, - ), + select( + models_core.ModuleAccountTypeVisibility.allowed_account_type, + ).where(models_core.ModuleAccountTypeVisibility.root == root), ) - return result.unique().scalars().first() + + resultList = result.unique().scalars().all() + + return resultList -async def create_module_visibility( - module_visibility: models_core.ModuleVisibility, +async def create_module_group_visibility( + module_visibility: models_core.ModuleGroupVisibility, db: AsyncSession, -) -> models_core.ModuleVisibility: +) -> models_core.ModuleGroupVisibility: """Create a new module visibility in database and return it""" db.add(module_visibility) @@ -81,15 +99,46 @@ async def create_module_visibility( return module_visibility -async def delete_module_visibility( +async def create_module_account_type_visibility( + module_visibility: models_core.ModuleAccountTypeVisibility, + db: AsyncSession, +) -> models_core.ModuleAccountTypeVisibility: + """Create a new module visibility in database and return it""" + + db.add(module_visibility) + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise + else: + return module_visibility + + +async def delete_module_group_visibility( root: str, allowed_group_id: str, db: AsyncSession, ): await db.execute( - delete(models_core.ModuleVisibility).where( - models_core.ModuleVisibility.root == root, - models_core.ModuleVisibility.allowed_group_id == allowed_group_id, + delete(models_core.ModuleGroupVisibility).where( + models_core.ModuleGroupVisibility.root == root, + models_core.ModuleGroupVisibility.allowed_group_id == allowed_group_id, + ), + ) + await db.commit() + + +async def delete_module_account_type_visibility( + root: str, + allowed_account_type: AccountType, + db: AsyncSession, +): + await db.execute( + delete(models_core.ModuleAccountTypeVisibility).where( + models_core.ModuleAccountTypeVisibility.root == root, + models_core.ModuleAccountTypeVisibility.allowed_account_type + == allowed_account_type, ), ) await db.commit() diff --git a/app/core/endpoints_core.py b/app/core/endpoints_core.py index 8ef7e8c9b..188ff76df 100644 --- a/app/core/endpoints_core.py +++ b/app/core/endpoints_core.py @@ -8,7 +8,7 @@ from app.core import cruds_core, models_core, schemas_core from app.core.config import Settings -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.dependencies import ( get_db, get_settings, @@ -174,10 +174,15 @@ async def get_module_visibility( root=module.root, db=db, ) + allowed_account_types = await cruds_core.get_allowed_account_types_by_root( + root=module.root, + db=db, + ) return_module_visibilities.append( schemas_core.ModuleVisibility( root=module.root, allowed_group_ids=allowed_group_ids, + allowed_account_types=allowed_account_types, ), ) @@ -191,7 +196,7 @@ async def get_module_visibility( ) async def get_user_modules_visibility( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get group user accessible root @@ -213,40 +218,78 @@ async def add_module_visibility( user: models_core.CoreUser = Depends(is_user_a_member_of(GroupType.admin)), ): """ - Add a new group to a module + Add a new group or account type to a module **This endpoint is only usable by administrators** """ - - # We need to check that loaner.group_manager_id is a valid group - if not await is_group_id_valid(module_visibility.allowed_group_id, db=db): + if ( + module_visibility.allowed_group_id is None + and module_visibility.allowed_account_type is None + ): raise HTTPException( status_code=400, - detail="Invalid id, group_id must be a valid group id", + detail="allowed_group_id or allowed_account_type must be set", ) - try: - module_visibility_db = models_core.ModuleVisibility( + + if module_visibility.allowed_group_id is not None: + # We need to check that loaner.group_manager_id is a valid group + if not await is_group_id_valid(module_visibility.allowed_group_id, db=db): + raise HTTPException( + status_code=400, + detail="Invalid id, group_id must be a valid group id", + ) + module_group_visibility_db = models_core.ModuleGroupVisibility( root=module_visibility.root, allowed_group_id=module_visibility.allowed_group_id, ) - - return await cruds_core.create_module_visibility( - module_visibility=module_visibility_db, - db=db, + try: + return await cruds_core.create_module_group_visibility( + module_visibility=module_group_visibility_db, + db=db, + ) + except ValueError as error: + raise HTTPException(status_code=400, detail=str(error)) + + if module_visibility.allowed_account_type is not None: + module_account_visibility_db = models_core.ModuleAccountTypeVisibility( + root=module_visibility.root, + allowed_account_type=module_visibility.allowed_account_type, ) - except ValueError as error: - raise HTTPException(status_code=400, detail=str(error)) + try: + return await cruds_core.create_module_account_type_visibility( + module_visibility=module_account_visibility_db, + db=db, + ) + except ValueError as error: + raise HTTPException(status_code=400, detail=str(error)) -@router.delete("/module-visibility/{root}/{group_id}", status_code=204) -async def delete_session( +@router.delete("/module-visibility/{root}/groups/{group_id}", status_code=204) +async def delete_module_group_visibility( root: str, group_id: str, db: AsyncSession = Depends(get_db), user: models_core.CoreUser = Depends(is_user_a_member_of(GroupType.admin)), ): - await cruds_core.delete_module_visibility( + await cruds_core.delete_module_group_visibility( root=root, allowed_group_id=group_id, db=db, ) + + +@router.delete( + "/module-visibility/{root}/account-types/{account_type}", + status_code=204, +) +async def delete_module_account_type_visibility( + root: str, + account_type: AccountType, + db: AsyncSession = Depends(get_db), + user: models_core.CoreUser = Depends(is_user_a_member_of(GroupType.admin)), +): + await cruds_core.delete_module_account_type_visibility( + root=root, + allowed_account_type=account_type, + db=db, + ) diff --git a/app/core/groups/groups_type.py b/app/core/groups/groups_type.py index a5332ceba..1adb2557f 100644 --- a/app/core/groups/groups_type.py +++ b/app/core/groups/groups_type.py @@ -12,12 +12,12 @@ class GroupType(str, Enum): """ # Account types - student = "39691052-2ae5-4e12-99d0-7a9f5f2b0136" - formerstudent = "ab4c7503-41b3-11ee-8177-089798f1a4a5" - staff = "703056c4-be9d-475c-aa51-b7fc62a96aaa" - association = "29751438-103c-42f2-b09b-33fbb20758a7" - external = "b1cd979e-ecc1-4bd0-bc2b-4dad2ba8cded" - demo = "ae4d1866-e7d9-4d7f-bee7-e0dda24d8dd8" + # student = "39691052-2ae5-4e12-99d0-7a9f5f2b0136" + # former_student = "ab4c7503-41b3-11ee-8177-089798f1a4a5" + # staff = "703056c4-be9d-475c-aa51-b7fc62a96aaa" + # association = "29751438-103c-42f2-b09b-33fbb20758a7" + # external = "b1cd979e-ecc1-4bd0-bc2b-4dad2ba8cded" + # demo = "ae4d1866-e7d9-4d7f-bee7-e0dda24d8dd8" # Core groups admin = "0a25cb76-4b63-4fd3-b939-da6d9feabf28" @@ -45,22 +45,32 @@ class AccountType(str, Enum): These values should match GroupType's. They are the lower level groups in Hyperion """ - student = GroupType.student.value - formerstudent = GroupType.formerstudent.value - staff = GroupType.staff.value - association = GroupType.association.value - external = GroupType.external.value - demo = GroupType.demo.value + student = "student" + former_student = "former_student" + staff = "staff" + association = "association" + external = "external" + other_school_student = "other_school_student" + demo = "demo" def __str__(self): return f"{self.name}<{self.value}>" -def get_ecl_groups() -> list[GroupType]: +def get_ecl_account_types() -> list[AccountType]: return [ - GroupType.AE, - GroupType.staff, - GroupType.student, - GroupType.association, - GroupType.admin, + AccountType.student, + AccountType.former_student, + AccountType.staff, + AccountType.association, + ] + + +def get_account_types_except_externals() -> list[AccountType]: + return [ + AccountType.student, + AccountType.former_student, + AccountType.staff, + AccountType.association, + AccountType.demo, ] diff --git a/app/core/models_core.py b/app/core/models_core.py index 3bd93f82c..760a38c0e 100644 --- a/app/core/models_core.py +++ b/app/core/models_core.py @@ -5,6 +5,7 @@ from sqlalchemy import ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship +from app.core.groups.groups_type import AccountType from app.types.floors_type import FloorsType from app.types.membership import AvailableAssociationMembership from app.types.sqlalchemy import Base, PrimaryKey @@ -28,7 +29,14 @@ class CoreUser(Base): index=True, ) # Use UUID later email: Mapped[str] = mapped_column(unique=True, index=True) + school_id: Mapped[str] = mapped_column(ForeignKey("core_school.id")) password_hash: Mapped[str] + # Depending on the account type, the user may have different rights and access to different features + # External users may exist for: + # - accounts meant to be used by external services based on Hyperion SSO or Hyperion backend + # - new users that need to do additional steps before being able to all features, + # like using a specific email address, going through an inscription process or being manually validated + account_type: Mapped[AccountType] name: Mapped[str] firstname: Mapped[str] nickname: Mapped[str | None] @@ -38,13 +46,6 @@ class CoreUser(Base): floor: Mapped[FloorsType | None] created_on: Mapped[datetime | None] - # Users that are externals (not members) won't be able to use all features - # These, self registered, external users may exist for: - # - accounts meant to be used by external services based on Hyperion SSO or Hyperion backend - # - new users that need to do additional steps before being able to all features, - # like using a specific email address, going through an inscription process or being manually validated - external: Mapped[bool] = mapped_column(default=False) - # We use list["CoreGroup"] with quotes as CoreGroup is only defined after this class # Defining CoreUser after CoreGroup would cause a similar issue groups: Mapped[list["CoreGroup"]] = relationship( @@ -54,6 +55,12 @@ class CoreUser(Base): lazy="selectin", default_factory=list, ) + school: Mapped["CoreSchool"] = relationship( + "CoreSchool", + back_populates="students", + lazy="selectin", + init=False, + ) class CoreUserUnconfirmed(Base): @@ -65,11 +72,9 @@ class CoreUserUnconfirmed(Base): # for example after losing the previously received confirmation email. # For each user creation request, a row will be added in this table with a new token email: Mapped[str] - account_type: Mapped[str] activation_token: Mapped[str] created_on: Mapped[datetime] expire_on: Mapped[datetime] - external: Mapped[bool] = mapped_column(default=False) class CoreUserRecoverRequest(Base): @@ -120,6 +125,21 @@ class CoreGroup(Base): ) +class CoreSchool(Base): + __tablename__ = "core_school" + + id: Mapped[str] = mapped_column(String, primary_key=True, index=True) + name: Mapped[str] = mapped_column(String, index=True, nullable=False, unique=True) + email_regex: Mapped[str] = mapped_column(String, nullable=False) + + students: Mapped[list["CoreUser"]] = relationship( + "CoreUser", + back_populates="school", + lazy="selectin", + default_factory=list, + ) + + class CoreAssociationMembership(Base): __tablename__ = "core_association_membership" @@ -150,13 +170,20 @@ class CoreData(Base): data: Mapped[str] -class ModuleVisibility(Base): - __tablename__ = "module_visibility" +class ModuleGroupVisibility(Base): + __tablename__ = "module_group_visibility" root: Mapped[str] = mapped_column(primary_key=True) allowed_group_id: Mapped[str] = mapped_column(primary_key=True) +class ModuleAccountTypeVisibility(Base): + __tablename__ = "module_account_type_visibility" + + root: Mapped[str] = mapped_column(primary_key=True) + allowed_account_type: Mapped[AccountType] = mapped_column(primary_key=True) + + class AlembicVersion(Base): """ A table managed exclusively by Alembic, used to keep track of the database schema version. diff --git a/app/core/notification/endpoints_notification.py b/app/core/notification/endpoints_notification.py index e1478da79..49421c949 100644 --- a/app/core/notification/endpoints_notification.py +++ b/app/core/notification/endpoints_notification.py @@ -30,7 +30,7 @@ async def register_firebase_device( firebase_token: str = Body(embed=True), db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), notification_manager: NotificationManager = Depends(get_notification_manager), ): """ @@ -80,7 +80,7 @@ async def register_firebase_device( async def unregister_firebase_device( firebase_token: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), notification_manager: NotificationManager = Depends(get_notification_manager), ): """ @@ -105,7 +105,7 @@ async def get_messages( firebase_token: str, db: AsyncSession = Depends(get_db), # If we want to enable authentification for /messages/{firebase_token} endpoint, we may to uncomment the following line - # user: models_core.CoreUser = Depends(is_user), + # user: models_core.CoreUser = Depends(is_user()), ): """ Get all messages for a specific device from the user @@ -146,7 +146,7 @@ async def subscribe_to_topic( description="The topic to subscribe to. The Topic may be followed by an additional identifier (ex: cinema_4c029b5f-2bf7-4b70-85d4-340a4bd28653)", ), db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), notification_manager: NotificationManager = Depends(get_notification_manager), ): """ @@ -177,7 +177,7 @@ async def subscribe_to_topic( async def unsubscribe_to_topic( topic_str: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), notification_manager: NotificationManager = Depends(get_notification_manager), ): """ @@ -202,7 +202,7 @@ async def unsubscribe_to_topic( ) async def get_topic( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get topics the user is subscribed to @@ -231,7 +231,7 @@ async def get_topic( async def get_topic_identifier( topic: Topic, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get custom topic (with identifiers) the user is subscribed to diff --git a/app/core/schemas_core.py b/app/core/schemas_core.py index 7efed4820..03add666f 100644 --- a/app/core/schemas_core.py +++ b/app/core/schemas_core.py @@ -19,6 +19,44 @@ class CoreInformation(BaseModel): minimal_titan_version_code: int +class CoreGroupBase(BaseModel): + """Base schema for group's model""" + + name: str + description: str | None = None + + _normalize_name = field_validator("name")(validators.trailing_spaces_remover) + + +class CoreGroupSimple(CoreGroupBase): + """Simplified schema for group's model, used when getting all groups""" + + id: str + model_config = ConfigDict(from_attributes=True) + + +class CoreSchoolBase(BaseModel): + """Schema for school's model""" + + name: str + email_regex: str + + _normalize_name = field_validator("name")(validators.trailing_spaces_remover) + + +class CoreSchoolSimple(CoreSchoolBase): + id: str + + +class CoreSchoolUpdate(BaseModel): + """Schema for school update""" + + name: str | None = None + email_regex: str | None = None + + _normalize_name = field_validator("name")(validators.trailing_spaces_remover) + + class CoreUserBase(BaseModel): """Base schema for user's model""" @@ -35,39 +73,35 @@ class CoreUserBase(BaseModel): ) -class CoreGroupBase(BaseModel): - """Base schema for group's model""" - - name: str - description: str | None = None - - _normalize_name = field_validator("name")(validators.trailing_spaces_remover) - - class CoreUserSimple(CoreUserBase): """Simplified schema for user's model, used when getting all users""" id: str + account_type: AccountType model_config = ConfigDict(from_attributes=True) -class CoreGroupSimple(CoreGroupBase): - """Simplified schema for group's model, used when getting all groups""" +class CoreSchool(CoreSchoolSimple): + """Schema for school's model similar to core_school table in database""" - id: str model_config = ConfigDict(from_attributes=True) + students: list[CoreUserSimple] = [] + class CoreUser(CoreUserSimple): """Schema for user's model similar to core_user table in database""" email: str + school_id: str + account_type: AccountType birthday: date | None = None promo: int | None = None floor: FloorsType | None = None phone: str | None = None created_on: datetime | None = None groups: list[CoreGroupSimple] = [] + school: CoreSchoolSimple | None = None class CoreUserUpdate(BaseModel): @@ -95,6 +129,8 @@ class CoreUserFusionRequest(BaseModel): class CoreUserUpdateAdmin(BaseModel): + school_id: str | None = None + account_type: AccountType | None = None name: str | None = None firstname: str | None = None promo: int | None = None @@ -102,7 +138,6 @@ class CoreUserUpdateAdmin(BaseModel): birthday: date | None = None phone: str | None = None floor: FloorsType | None = None - external: bool | None = None _normalize_name = field_validator("name")(validators.trailing_spaces_remover) _normalize_firstname = field_validator("firstname")( @@ -123,8 +158,9 @@ class CoreUserCreateRequest(BaseModel): """ email: str - accept_external: bool = Field( - False, + accept_external: bool | None = Field( + None, + deprecated=True, description="Allow Hyperion to create an external user. Without this, Hyperion will only allow non external students to be created. The email address will be used to determine if the user should be external or not. An external user may not have an ECL email address, he won't be able to access most features.", ) @@ -139,12 +175,10 @@ class CoreUserCreateRequest(BaseModel): class CoreBatchUserCreateRequest(BaseModel): """ - The schema is used for batch account creation requests. An account type should be provided + The schema is used for batch account creation requests. """ email: str - account_type: AccountType - external: bool = False # Email normalization, this will modify the email variable # https://pydantic-docs.helpmanual.io/usage/validators/#reuse-validators @@ -163,7 +197,7 @@ class CoreUserActivateRequest(CoreUserBase): floor: FloorsType | None = None promo: int | None = Field( default=None, - description="Promotion of the student, an integer like 21", + description="Promotion of the student, an integer like 2021", ) # Password validator @@ -188,13 +222,6 @@ class CoreGroupCreate(CoreGroupBase): """Model for group creation schema""" -class CoreGroupInDB(CoreGroupBase): - """Schema for user activation""" - - id: str - model_config = ConfigDict(from_attributes=True) - - class CoreGroupUpdate(BaseModel): """Schema for group update""" @@ -272,10 +299,12 @@ class CoreMembershipDelete(BaseModel): class ModuleVisibility(BaseModel): root: str allowed_group_ids: list[str] + allowed_account_types: list[AccountType] model_config = ConfigDict(from_attributes=True) class ModuleVisibilityCreate(BaseModel): root: str - allowed_group_id: str + allowed_group_id: str | None = None + allowed_account_type: AccountType | None = None model_config = ConfigDict(from_attributes=True) diff --git a/app/core/schools/__init__.py b/app/core/schools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/core/schools/cruds_schools.py b/app/core/schools/cruds_schools.py new file mode 100644 index 000000000..110643d12 --- /dev/null +++ b/app/core/schools/cruds_schools.py @@ -0,0 +1,85 @@ +"""File defining the functions called by the endpoints, making queries to the table using the models""" + +from collections.abc import Sequence + +from sqlalchemy import delete, select, update +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.core import models_core, schemas_core + + +async def get_schools(db: AsyncSession) -> Sequence[models_core.CoreSchool]: + """Return all schools from database""" + + result = await db.execute(select(models_core.CoreSchool)) + return result.scalars().all() + + +async def get_school_by_id( + db: AsyncSession, + school_id: str, +) -> models_core.CoreSchool | None: + """Return school with id from database""" + result = await db.execute( + select(models_core.CoreSchool) + .where(models_core.CoreSchool.id == school_id) + .options( + selectinload(models_core.CoreSchool.students), + ), # needed to load the members from the relationship + ) + return result.scalars().first() + + +async def get_school_by_name( + db: AsyncSession, + school_name: str, +) -> models_core.CoreSchool | None: + """Return school with name from database""" + result = await db.execute( + select(models_core.CoreSchool) + .where(models_core.CoreSchool.name == school_name) + .options( + selectinload(models_core.CoreSchool.students), + ), # needed to load the members from the relationship + ) + return result.scalars().first() + + +async def create_school( + school: models_core.CoreSchool, + db: AsyncSession, +) -> models_core.CoreSchool: + """Create a new school in database and return it""" + + db.add(school) + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise + else: + return school + + +async def delete_school(db: AsyncSession, school_id: str): + """Delete a school from database by id""" + + await db.execute( + delete(models_core.CoreSchool).where(models_core.CoreSchool.id == school_id), + ) + await db.commit() + + +async def update_school( + db: AsyncSession, + school_id: str, + school_update: schemas_core.CoreSchoolUpdate, +): + await db.execute( + update(models_core.CoreSchool) + .where(models_core.CoreSchool.id == school_id) + .values(**school_update.model_dump(exclude_none=True)), + ) + await db.commit() diff --git a/app/core/schools/endpoints_schools.py b/app/core/schools/endpoints_schools.py new file mode 100644 index 000000000..8738be29d --- /dev/null +++ b/app/core/schools/endpoints_schools.py @@ -0,0 +1,185 @@ +""" +File defining the API itself, using fastAPI and schemas, and calling the cruds functions + +School management is part of the core of Hyperion. These endpoints allow managing schools. +""" + +import logging +import re +import uuid + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core import models_core, schemas_core +from app.core.groups.groups_type import AccountType, GroupType +from app.core.schools import cruds_schools +from app.core.schools.schools_type import SchoolType +from app.core.users import cruds_users +from app.dependencies import ( + get_db, + is_user_a_member_of, + is_user_an_ecl_member, +) + +router = APIRouter(tags=["Schools"]) + +hyperion_security_logger = logging.getLogger("hyperion.security") + + +@router.get( + "/schools/", + response_model=list[schemas_core.CoreSchoolSimple], + status_code=200, +) +async def read_schools( + db: AsyncSession = Depends(get_db), + user=Depends(is_user_an_ecl_member), +): + """ + Return all schools from database as a list of dictionaries + """ + + schools = await cruds_schools.get_schools(db) + return schools + + +@router.get( + "/schools/{school_id}", + response_model=schemas_core.CoreSchool, + status_code=200, +) +async def read_school( + school_id: str, + db: AsyncSession = Depends(get_db), + user=Depends(is_user_a_member_of(GroupType.admin)), +): + """ + Return school with id from database as a dictionary. This includes a list of users being members of the school. + + **This endpoint is only usable by administrators** + """ + + db_school = await cruds_schools.get_school_by_id(db=db, school_id=school_id) + if db_school is None: + raise HTTPException(status_code=404, detail="School not found") + return db_school + + +@router.post( + "/schools/", + response_model=schemas_core.CoreSchoolSimple, + status_code=201, +) +async def create_school( + school: schemas_core.CoreSchoolBase, + db: AsyncSession = Depends(get_db), + user: schemas_core.CoreUser = Depends(is_user_a_member_of(GroupType.admin)), +): + """ + Create a new school. + + **This endpoint is only usable by administrators** + """ + if ( # We can't have two schools with the same name + await cruds_schools.get_school_by_name(school_name=school.name, db=db) + is not None + ): + raise HTTPException( + status_code=400, + detail="A school with the name {school.name} already exist", + ) + + try: + db_school = models_core.CoreSchool( + id=str(uuid.uuid4()), + name=school.name, + email_regex=school.email_regex, + ) + new_school = await cruds_schools.create_school(school=db_school, db=db) + users = await cruds_users.get_users( + db=db, + included_account_types=[AccountType.external], + ) + for db_user in users: + if re.match(db_school.email_regex, db_user.email): + await cruds_users.update_user( + db, + db_user.id, + schemas_core.CoreUserUpdateAdmin( + school_id=db_school.id, + account_type=AccountType.other_school_student, + ), + ) + except ValueError as error: + raise HTTPException(status_code=422, detail=str(error)) + else: + return new_school + + +@router.patch( + "/schools/{school_id}", + status_code=204, +) +async def update_school( + school_id: str, + school_update: schemas_core.CoreSchoolUpdate, + db: AsyncSession = Depends(get_db), + user=Depends(is_user_a_member_of(GroupType.admin)), +): + """ + Update the name or the description of a school. + + **This endpoint is only usable by administrators** + """ + school = await cruds_schools.get_school_by_id(db=db, school_id=school_id) + + if not school: + raise HTTPException(status_code=404, detail="School not found") + + # If the request ask to update the school name, we need to check it is available + if school_update.name and school_update.name != school.name: + if ( + await cruds_schools.get_school_by_name( + school_name=school_update.name, + db=db, + ) + is not None + ): + raise HTTPException( + status_code=400, + detail="A school with the name {school.name} already exist", + ) + + await cruds_schools.update_school( + db=db, + school_id=school_id, + school_update=school_update, + ) + + +@router.delete( + "/schools/{school_id}", + status_code=204, +) +async def delete_school( + school_id: str, + db: AsyncSession = Depends(get_db), + user=Depends(is_user_a_member_of(GroupType.admin)), +): + """ + Delete school from database. + This will remove the school from all users but won't delete any user. + + `SchoolTypes` schools can not be deleted. + + **This endpoint is only usable by administrators** + """ + + if school_id in set(item.value for item in SchoolType): + raise HTTPException( + status_code=400, + detail="SchoolTypes schools can not be deleted", + ) + + await cruds_schools.delete_school(db=db, school_id=school_id) diff --git a/app/core/schools/schools_type.py b/app/core/schools/schools_type.py new file mode 100644 index 000000000..97b7cbf0d --- /dev/null +++ b/app/core/schools/schools_type.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class SchoolType(str, Enum): + """ + In Hyperion, each user must have a school. Belonging to a school gives access to a set of specific endpoints. + """ + + # Account types + no_school = "dce19aa2-8863-4c93-861e-fb7be8f610ed" + centrale_lyon = "d9772da7-1142-4002-8b86-b694b431dfed" + + # Auth related groups + + def __str__(self): + return f"{self.name}<{self.value}>" diff --git a/app/core/users/cruds_users.py b/app/core/users/cruds_users.py index 38439eb9f..7b3c4c560 100644 --- a/app/core/users/cruds_users.py +++ b/app/core/users/cruds_users.py @@ -2,13 +2,14 @@ from collections.abc import Sequence -from sqlalchemy import ForeignKey, and_, delete, not_, select, update +from sqlalchemy import ForeignKey, and_, delete, not_, or_, select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from sqlalchemy_utils import get_referencing_foreign_keys # type: ignore from app.core import models_core, schemas_core +from app.core.groups.groups_type import AccountType async def count_users(db: AsyncSession) -> int: @@ -20,14 +21,18 @@ async def count_users(db: AsyncSession) -> int: async def get_users( db: AsyncSession, + included_account_types: list[AccountType] | None = None, + excluded_account_types: list[AccountType] | None = None, included_groups: list[str] | None = None, excluded_groups: list[str] | None = None, ) -> Sequence[models_core.CoreUser]: """ Return all users from database. - Parameters `included_groups` and `excluded_groups` can be used to filter results. + Parameters `excluded_account_types` and `excluded_groups` can be used to filter results. """ + included_account_types = included_account_types or list(AccountType) + excluded_account_types = excluded_account_types or [] included_groups = included_groups or [] excluded_groups = excluded_groups or [] @@ -43,6 +48,19 @@ async def get_users( ) for group_id in included_groups ], + or_( + False, + *[ + models_core.CoreUser.account_type == account_type + for account_type in included_account_types + ], + ), + *[ + not_( + models_core.CoreUser.account_type == account_type, + ) + for account_type in excluded_account_types + ], # We want, for each group that should not be included # check that the following condition is false : # - at least one of the user's groups match the expected group @@ -235,24 +253,6 @@ async def delete_email_migration_code_by_token( await db.commit() -async def update_user_email_by_id( - user_id: str, - new_email: str, - db: AsyncSession, - make_user_external: bool = False, -): - try: - await db.execute( - update(models_core.CoreUser) - .where(models_core.CoreUser.id == user_id) - .values(email=new_email, external=make_user_external), - ) - await db.commit() - except IntegrityError: - await db.rollback() - raise - - async def delete_recover_request_by_email(db: AsyncSession, email: str): """Delete a user from database by id""" @@ -292,7 +292,7 @@ async def fusion_users( If the user_deleted_id is referenced in a table where the user_kept_id is also referenced, and the change would create a duplicate that violates a unique constraint, the row will be deleted - There are thre cases to consider: + There are three cases to consider: 1. The foreign key is not a primary key of the table: In this case, we can update the row 2. The foreign key is one of the primary keys of the table: diff --git a/app/core/users/endpoints_users.py b/app/core/users/endpoints_users.py index 103d7edc9..fe02e6a1d 100644 --- a/app/core/users/endpoints_users.py +++ b/app/core/users/endpoints_users.py @@ -24,6 +24,8 @@ from app.core.config import Settings from app.core.groups import cruds_groups from app.core.groups.groups_type import AccountType, GroupType +from app.core.schools import cruds_schools +from app.core.schools.schools_type import SchoolType from app.core.users import cruds_users from app.dependencies import ( get_db, @@ -50,15 +52,18 @@ templates = Jinja2Templates(directory="assets/templates") -ECL_EMAIL_REGEX = r"^[\w\-.]*@etu(-enise)?\.ec-lyon\.fr$" +ECL_STAFF_REGEX = r"^[\w\-.]*@(enise\.)?ec-lyon\.fr$" +ECL_STUDENT_REGEX = r"^[\w\-.]*@((etu(-enise)?)|(ecl\d{2}))\.ec-lyon\.fr$" +ECL_FORMER_STUDENT_REGEX = r"^[\w\-.]*@centraliens-lyon\.net$" @router.get( - "/users/", + "/users", response_model=list[schemas_core.CoreUserSimple], status_code=200, ) async def read_users( + accountTypes: list[AccountType] = Query(default=[]), db: AsyncSession = Depends(get_db), user: models_core.CoreUser = Depends(is_user_a_member_of(GroupType.admin)), ): @@ -67,8 +72,9 @@ async def read_users( **This endpoint is only usable by administrators** """ + accountTypes = accountTypes if len(accountTypes) != 0 else list(AccountType) - users = await cruds_users.get_users(db) + users = await cruds_users.get_users(db, included_account_types=accountTypes) return users @@ -82,7 +88,7 @@ async def count_users( user: models_core.CoreUser = Depends(is_user_a_member_of(GroupType.admin)), ): """ - Return all users from database as a list of `CoreUserSimple` + Return the number of users in the database **This endpoint is only usable by administrators** """ @@ -98,6 +104,8 @@ async def count_users( ) async def search_users( query: str, + includedAccountTypes: list[AccountType] = Query(default=[]), + excludedAccountTypes: list[AccountType] = Query(default=[]), includedGroups: list[str] = Query(default=[]), excludedGroups: list[str] = Query(default=[]), db: AsyncSession = Depends(get_db), @@ -113,6 +121,8 @@ async def search_users( users = await cruds_users.get_users( db, + included_account_types=includedAccountTypes, + excluded_account_types=excludedAccountTypes, included_groups=includedGroups, excluded_groups=excludedGroups, ) @@ -120,13 +130,29 @@ async def search_users( return sort_user(string.capwords(query), users) +@router.get( + "/users/account-types", + response_model=list[AccountType], + status_code=200, +) +async def get_account_types( + db: AsyncSession = Depends(get_db), + user: models_core.CoreUser = Depends(is_user_a_member_of(GroupType.admin)), +): + """ + Return all account types hardcoded in the system + """ + + return list(AccountType) + + @router.get( "/users/me", response_model=schemas_core.CoreUser, status_code=200, ) async def read_current_user( - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Return `CoreUser` representation of current user @@ -137,6 +163,11 @@ async def read_current_user( return user +ECL_STAFF_REGEX = r"^[\w\-.]*@(enise\.)?ec-lyon\.fr$" +ECL_STUDENT_REGEX = r"^[\w\-.]*@etu(-enise)?\.ec-lyon\.fr$" +ECL_FORMER_STUDENT_REGEX = r"^[\w\-.]*@centraliens-lyon\.net$" + + @router.post( "/users/create", response_model=standard_responses.Result, @@ -158,56 +189,6 @@ async def create_user_by_user( When creating **student** or **staff** account a valid ECL email is required. Only admin users can create other **account types**, contact ÉCLAIR for more information. """ - # By default we mark the user as external - # but if it has an ECL email address, we will mark it as member - external = True - # Check the account type - - # For staff and student - # ^[\w\-.]*@((etu(-enise)?|enise)\.)?ec-lyon\.fr$ - # For staff - # ^[\w\-.]*@(enise\.)?ec-lyon\.fr$ - # For student - # ^[\w\-.]*@etu(-enise)?\.ec-lyon\.fr$ - - # For former students - # ^[\w\-.]*@centraliens-lyon\.net$ - - # All accepted emails - # ^[\w\-.]*@(((etu(-enise)?|enise)\.)?ec-lyon\.fr|centraliens-lyon\.net)$ - - if re.match(r"^[\w\-.]*@(enise\.)?ec-lyon\.fr$", user_create.email): - # Its a staff email address - account_type = AccountType.staff - external = False - elif re.match( - r"^[\w\-.]*@etu(-enise)?\.ec-lyon\.fr$", - user_create.email, - ): - # Its a student email address - account_type = AccountType.student - external = False - elif re.match( - r"^[\w\-.]*@centraliens-lyon\.net$", - user_create.email, - ): - # Its a former student email address - account_type = AccountType.formerstudent - external = False - elif user_create.accept_external: - account_type = AccountType.external - else: - raise HTTPException( - status_code=400, - detail="Invalid ECL email address.", - ) - - if not user_create.accept_external: - if external: - raise HTTPException( - status_code=400, - detail="The user could not be marked as external. Try to add accept_external=True in the request or use an internal email address.", - ) # Make sure a confirmed account does not already exist db_user = await cruds_users.get_user_by_email(db=db, email=user_create.email) @@ -235,12 +216,10 @@ async def create_user_by_user( await create_user( email=user_create.email, - account_type=account_type, background_tasks=background_tasks, db=db, settings=settings, request_id=request_id, - external=external, ) return standard_responses.Result(success=True) @@ -276,12 +255,10 @@ async def batch_create_users( try: await create_user( email=user_create.email, - account_type=user_create.account_type, background_tasks=background_tasks, db=db, settings=settings, request_id=request_id, - external=user_create.external, ) except Exception as error: failed[user_create.email] = str(error) @@ -291,12 +268,10 @@ async def batch_create_users( async def create_user( email: str, - account_type: AccountType, background_tasks: BackgroundTasks, db: AsyncSession, settings: Settings, request_id: str, - external: bool, ) -> None: """ User creation process. This function is used by both `/users/create` and `/users/admin/create` endpoints @@ -316,12 +291,10 @@ async def create_user( user_unconfirmed = models_core.CoreUserUnconfirmed( id=str(uuid.uuid4()), email=email, - account_type=account_type, activation_token=activation_token, created_on=datetime.now(UTC), expire_on=datetime.now(UTC) + timedelta(hours=settings.USER_ACTIVATION_TOKEN_EXPIRE_HOURS), - external=external, ) await cruds_users.create_unconfirmed_user(user_unconfirmed=user_unconfirmed, db=db) @@ -331,7 +304,7 @@ async def create_user( calypsso_activate_url = settings.CLIENT_URL + calypsso.get_activate_relative_url( activation_token=activation_token, - external=external, + external=True, ) if settings.SMTP_ACTIVE: @@ -395,12 +368,70 @@ async def activate_user( detail=f"The account with the email {unconfirmed_user.email} is already confirmed", ) + # Check the account type + + # For staff and student + # ^[\w\-.]*@((etu(-enise)?|enise)\.)?ec-lyon\.fr$ + # For staff + # ^[\w\-.]*@(enise\.)?ec-lyon\.fr$ + # For student + # ^[\w\-.]*@etu(-enise)?\.ec-lyon\.fr$ + + # For former students + # ^[\w\-.]*@centraliens-lyon\.net$ + + # All accepted emails + # ^[\w\-.]*@(((etu(-enise)?|enise)\.)?ec-lyon\.fr|centraliens-lyon\.net)$ + + # By default we mark the user as external + # but if it has an ECL email address, we will mark it as member + account_type = AccountType.external + if re.match(ECL_STAFF_REGEX, unconfirmed_user.email): + # Its a staff email address + account_type = AccountType.staff + elif re.match( + ECL_STUDENT_REGEX, + unconfirmed_user.email, + ): + # Its a student email address + account_type = AccountType.student + elif re.match( + ECL_FORMER_STUDENT_REGEX, + unconfirmed_user.email, + ): + # Its a former student email address + account_type = AccountType.former_student + if account_type == AccountType.external: + schools = await cruds_schools.get_schools(db=db) + schools = [ + school + for school in schools + if school.id not in (SchoolType.no_school, SchoolType.centrale_lyon) + ] + school = next( + ( + school + for school in schools + if re.match(school.email_regex, unconfirmed_user.email) + ), + None, + ) + if school is not None: + account_type = AccountType.other_school_student + school_id = school.id + else: + account_type = AccountType.external + school_id = SchoolType.no_school + else: + school_id = SchoolType.centrale_lyon # A password should have been provided password_hash = security.get_password_hash(user.password) confirmed_user = models_core.CoreUser( id=unconfirmed_user.id, email=unconfirmed_user.email, + school_id=school_id, + account_type=account_type, password_hash=password_hash, name=user.name, firstname=user.firstname, @@ -410,18 +441,9 @@ async def activate_user( phone=user.phone, floor=user.floor, created_on=datetime.now(UTC), - external=unconfirmed_user.external, ) # We add the new user to the database await cruds_users.create_user(db=db, user=confirmed_user) - await cruds_groups.create_membership( - db=db, - membership=models_core.CoreMembership( - group_id=unconfirmed_user.account_type, - user_id=unconfirmed_user.id, - description=None, - ), - ) # We remove all unconfirmed users with the same email address await cruds_users.delete_unconfirmed_user_by_email( @@ -609,7 +631,7 @@ async def reset_password( async def migrate_mail( mail_migration: schemas_core.MailMigrationRequest, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), settings: Settings = Depends(get_settings), ): """ @@ -617,10 +639,10 @@ async def migrate_mail( This endpoint will send a confirmation code to the user's new email address. He will need to use this code to confirm the change with `/users/confirm-mail-migration` endpoint. """ - if not re.match(ECL_EMAIL_REGEX, mail_migration.new_email): + if not re.match(user.school.email_regex, mail_migration.new_email): raise HTTPException( status_code=400, - detail="The new email address must match the new ECL format for student users", + detail="The new email address must match the new email format for student users", ) existing_user = await cruds_users.get_user_by_email( @@ -701,29 +723,25 @@ async def migrate_mail_confirm( ) try: - await cruds_users.update_user_email_by_id( - db=db, - user_id=migration_object.user_id, - new_email=migration_object.new_email, - make_user_external=migration_object.make_user_external, - ) - if not migration_object.make_user_external: - group_ids = [group.id for group in user.groups] - if GroupType.external in group_ids: - await cruds_groups.delete_membership_by_group_and_user_id( - db=db, - group_id=GroupType.external, - user_id=migration_object.user_id, - ) - if GroupType.student not in group_ids: - await cruds_groups.create_membership( - db=db, - membership=models_core.CoreMembership( - group_id=GroupType.student, - user_id=migration_object.user_id, - description=None, - ), - ) + if user.school.id == SchoolType.centrale_lyon: + await cruds_users.update_user( + db, + migration_object.user_id, + user_update=schemas_core.CoreUserUpdateAdmin( + email=migration_object.new_email, + account_type=AccountType.student, + ), + ) + elif user.school.id != SchoolType.no_school: + await cruds_users.update_user( + db, + migration_object.user_id, + user_update=schemas_core.CoreUserUpdateAdmin( + email=migration_object.new_email, + account_type=AccountType.other_school_student, + ), + ) + except Exception as error: raise HTTPException(status_code=400, detail=str(error)) @@ -792,7 +810,7 @@ async def read_user( user: models_core.CoreUser = Depends(is_user_a_member_of(GroupType.admin)), ): """ - Return `CoreUserSimple` representation of user with id `user_id` + Return `CoreUser` representation of user with id `user_id` **The user must be authenticated to use this endpoint** """ @@ -820,7 +838,7 @@ async def read_user( status_code=204, ) async def delete_user( - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ This endpoint will ask administrators to process to the user deletion. @@ -838,7 +856,7 @@ async def delete_user( async def update_current_user( user_update: schemas_core.CoreUserUpdate, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Update the current user, the request should contain a JSON with the fields to change (not necessarily all fields) and their new value @@ -858,33 +876,31 @@ async def switch_external_user_internal( u: models_core.CoreUser = Depends(is_user_a_member_of(GroupType.admin)), ): """ - Switch all external users to internal users if they have an ECL email address. + Switch all external users to internal users if they have an email address corresponding to their school email regex. **This endpoint is only usable by administrators** """ - users = await cruds_users.get_users(db=db) + users = await cruds_users.get_users( + db=db, + included_account_types=[AccountType.external], + ) for user in users: - if re.match(ECL_EMAIL_REGEX, user.email): - await cruds_users.update_user( - db=db, - user_id=user.id, - user_update=schemas_core.CoreUserUpdateAdmin(external=False), - ) - group_ids = [group.id for group in user.groups] - if GroupType.external in group_ids: - await cruds_groups.delete_membership_by_group_and_user_id( - db=db, - group_id=GroupType.external, - user_id=user.id, + if re.match(user.school.email_regex, user.email): + if user.school.id == SchoolType.centrale_lyon: + await cruds_users.update_user( + db, + user.id, + user_update=schemas_core.CoreUserUpdateAdmin( + account_type=AccountType.student, + ), ) - if GroupType.student not in group_ids: - await cruds_groups.create_membership( - db=db, - membership=models_core.CoreMembership( - group_id=GroupType.student, - user_id=user.id, - description=None, + elif user.school.id != SchoolType.no_school: + await cruds_users.update_user( + db, + user.id, + user_update=schemas_core.CoreUserUpdateAdmin( + account_type=AccountType.other_school_student, ), ) @@ -955,6 +971,13 @@ async def update_user( db_user = await cruds_users.get_user_by_id(db=db, user_id=user_id) if not db_user: raise HTTPException(status_code=404, detail="User not found") + if user_update.school_id is not None: + school = await cruds_schools.get_school_by_id( + db=db, + school_id=user_update.school_id, + ) + if school is None: + raise HTTPException(status_code=404, detail="School not found") await cruds_users.update_user(db=db, user_id=user_id, user_update=user_update) @@ -966,7 +989,7 @@ async def update_user( ) async def create_current_user_profile_picture( image: UploadFile = File(...), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), request_id: str = Depends(get_request_id), ): """ @@ -997,7 +1020,7 @@ async def create_current_user_profile_picture( status_code=200, ) async def read_own_profile_picture( - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get the profile picture of the authenticated user. diff --git a/app/dependencies.py b/app/dependencies.py index da7c345ff..1e5c1f90d 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -24,7 +24,7 @@ async def get_users(db: AsyncSession = Depends(get_db)): from app.core import models_core, security from app.core.auth import schemas_auth from app.core.config import Settings, construct_prod_settings -from app.core.groups.groups_type import GroupType, get_ecl_groups +from app.core.groups.groups_type import AccountType, GroupType, get_ecl_account_types from app.core.payment.payment_tool import PaymentTool from app.modules.raid.utils.drive.drive_file_manager import DriveFileManager from app.types.scopes_type import ScopeType @@ -286,10 +286,12 @@ async def get_current_user( def is_user( - user: models_core.CoreUser = Depends( - get_user_from_token_with_scopes([[ScopeType.API]]), - ), -) -> models_core.CoreUser: + excluded_groups: list[GroupType] | None = None, + included_groups: list[GroupType] | None = None, + excluded_account_types: list[AccountType] | None = None, + included_account_types: list[AccountType] | None = None, + exclude_external: bool = False, +) -> Callable[[models_core.CoreUser], models_core.CoreUser]: """ A dependency that will: * check if the request header contains a valid API JWT token (a token that can be used to call endpoints from the API) @@ -298,12 +300,51 @@ def is_user( To check if the user is not external, use is_user_a_member dependency To check if the user is not external and is the member of a group, use is_user_a_member_of generator """ - return user + + excluded_groups = excluded_groups or [] + excluded_account_types = excluded_account_types or [] + included_account_types = included_account_types or list(AccountType) + + def is_user( + user: models_core.CoreUser = Depends( + get_user_from_token_with_scopes([[ScopeType.API]]), + ), + ): + groups_id: list[str] = [group.id for group in user.groups] + if GroupType.admin in groups_id: + return user + if user.account_type in excluded_account_types: + raise HTTPException( + status_code=403, + detail="Unauthorized, user account type is not allowed", + ) + if exclude_external and is_user_external(user): + raise HTTPException( + status_code=403, + detail="Unauthorized, user is an external user", + ) + if is_user_member_of_an_allowed_group(user, excluded_groups): + raise HTTPException( + status_code=403, + detail=f"Unauthorized, user is a member of any of the groups {excluded_groups}", + ) + if included_groups is not None and not is_user_member_of_an_allowed_group( + user, + included_groups, + ): + raise HTTPException( + status_code=403, + detail="Unauthorized, user is not a member of an allowed group", + ) + if user.account_type in included_account_types: + return user + + return is_user def is_user_a_member( user: models_core.CoreUser = Depends( - is_user, + is_user(exclude_external=True), ), request_id: str = Depends(get_request_id), ) -> models_core.CoreUser: @@ -315,22 +356,12 @@ def is_user_a_member( To check if the user is the member of a group, use is_user_a_member_of generator """ - if is_user_external(user): - hyperion_access_logger.warning( - f"Is_user_a_member: user is an external user ({request_id})", - ) - - raise HTTPException( - status_code=403, - detail="Unauthorized, user is an external user", - ) - return user def is_user_an_ecl_member( user: models_core.CoreUser = Depends( - is_user_a_member, + is_user(included_account_types=get_ecl_account_types()), ), request_id: str = Depends(get_request_id), ) -> models_core.CoreUser: @@ -344,37 +375,36 @@ def is_user_an_ecl_member( To check if the user is the member of a group, use is_user_a_member_of generator """ - if is_user_member_of_an_allowed_group( - user=user, - allowed_groups=get_ecl_groups(), - ): - # We know the user is a member of the group, we don't need to return an error and can return the CoreUser object - return user - - hyperion_access_logger.warning( - f"Is_user_a_member_of: Unauthorized, user is not a member of student, staff, association or AE group ({request_id})", - ) - - raise HTTPException( - status_code=403, - detail="Unauthorized, user is not a member of student, staff or AE group", - ) + return user def is_user_a_member_of( group_id: GroupType, + exclude_external: bool = False, ) -> Callable[[models_core.CoreUser], Coroutine[Any, Any, models_core.CoreUser]]: """ Generate a dependency which will: * check if the request header contains a valid API JWT token (a token that can be used to call endpoints from the API) * make sure the user making the request exists and is a member of the group with the given id - * make sure the user is not an external user + * make sure the user is not an external user if `exclude_external` is True * return the corresponding user `models_core.CoreUser` object """ - async def is_user_a_member_of( + async def is_user_a_member_of_without_external( user: models_core.CoreUser = Depends( - is_user_a_member, + is_user(included_groups=[group_id], exclude_external=exclude_external), + ), + request_id: str = Depends(get_request_id), + ) -> models_core.CoreUser: + """ + A dependency that checks that user is a member of the group with the given id then returns the corresponding user. + """ + + return user + + async def is_user_a_member_of_with_external( + user: models_core.CoreUser = Depends( + is_user(), ), request_id: str = Depends(get_request_id), ) -> models_core.CoreUser: @@ -395,4 +425,8 @@ async def is_user_a_member_of( detail=f"Unauthorized, user is not a member of the group {group_id}", ) - return is_user_a_member_of + return ( + is_user_a_member_of_with_external + if not exclude_external + else is_user_a_member_of_without_external + ) diff --git a/app/modules/advert/endpoints_advert.py b/app/modules/advert/endpoints_advert.py index 6366e30a1..d81b6b83d 100644 --- a/app/modules/advert/endpoints_advert.py +++ b/app/modules/advert/endpoints_advert.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core, standard_responses -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.notification.notification_types import CustomTopic, Topic from app.core.notification.schemas_notification import Message from app.dependencies import ( @@ -31,7 +31,7 @@ module = Module( root="advert", tag="Advert", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) hyperion_error_logger = logging.getLogger("hyperion.error") diff --git a/app/modules/amap/endpoints_amap.py b/app/modules/amap/endpoints_amap.py index 18a3f3098..324834eab 100644 --- a/app/modules/amap/endpoints_amap.py +++ b/app/modules/amap/endpoints_amap.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core, schemas_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.notification.notification_types import CustomTopic, Topic from app.core.notification.schemas_notification import Message from app.core.users import cruds_users @@ -30,7 +30,7 @@ module = Module( root="amap", tag="AMAP", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) hyperion_amap_logger = logging.getLogger("hyperion.amap") diff --git a/app/modules/booking/endpoints_booking.py b/app/modules/booking/endpoints_booking.py index acb2de352..32f35c965 100644 --- a/app/modules/booking/endpoints_booking.py +++ b/app/modules/booking/endpoints_booking.py @@ -8,7 +8,7 @@ from app.core import models_core from app.core.groups import cruds_groups -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.notification.schemas_notification import Message from app.dependencies import ( get_db, @@ -25,7 +25,7 @@ module = Module( root="booking", tag="Booking", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) hyperion_error_logger = logging.getLogger("hyperion.error") diff --git a/app/modules/calendar/endpoints_calendar.py b/app/modules/calendar/endpoints_calendar.py index 013cf8def..f2168b0e8 100644 --- a/app/modules/calendar/endpoints_calendar.py +++ b/app/modules/calendar/endpoints_calendar.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.dependencies import get_db, is_user_a_member_of, is_user_an_ecl_member from app.modules.calendar import cruds_calendar, models_calendar, schemas_calendar from app.modules.calendar.types_calendar import Decision @@ -16,7 +16,7 @@ module = Module( root="event", tag="Calendar", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) ical_file_path = "data/ics/ae_calendar.ics" diff --git a/app/modules/cdr/endpoints_cdr.py b/app/modules/cdr/endpoints_cdr.py index d7fe58086..f69f9a581 100644 --- a/app/modules/cdr/endpoints_cdr.py +++ b/app/modules/cdr/endpoints_cdr.py @@ -141,6 +141,7 @@ async def get_cdr_users_pending_validation( return [ schemas_cdr.CdrUser( + account_type=user.account_type, curriculum=curriculum_memberships_mapping.get(user.id, None), promo=user.promo, email=user.email, @@ -164,7 +165,7 @@ async def get_cdr_users_pending_validation( async def get_cdr_user( user_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a user. @@ -193,6 +194,7 @@ async def get_cdr_user( curriculum_complete = {c.id: c for c in await cruds_cdr.get_curriculums(db=db)} return schemas_cdr.CdrUser( + account_type=user_db.account_type, name=user_db.name, firstname=user_db.firstname, nickname=user_db.nickname, @@ -300,6 +302,7 @@ async def update_cdr_user( curriculum=schemas_cdr.CurriculumComplete( **curriculum.__dict__, ), + account_type=user_db.account_type, name=user_db.name, firstname=user_db.firstname, nickname=user_db.nickname, @@ -369,7 +372,7 @@ async def get_sellers_by_user_id( ) async def get_online_sellers( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get all sellers that has online available products. @@ -523,7 +526,7 @@ async def send_seller_results( ) async def get_all_available_online_products( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a seller's online available products. @@ -693,7 +696,7 @@ async def get_products_by_seller_id( async def get_available_online_products( seller_id: UUID, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a seller's online available products. @@ -1256,7 +1259,7 @@ async def delete_document( async def get_purchases_by_user_id( user_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a user's purchases. @@ -1311,7 +1314,7 @@ async def get_purchases_by_user_id( ) async def get_my_purchases( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): return await get_purchases_by_user_id(user.id, db, user) @@ -1325,7 +1328,7 @@ async def get_purchases_by_user_id_by_seller_id( seller_id: UUID, user_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a user's purchases. @@ -1382,7 +1385,7 @@ async def create_purchase( product_variant_id: UUID, purchase: schemas_cdr.PurchaseBase, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Create a purchase. @@ -1720,7 +1723,7 @@ async def delete_purchase( user_id: str, product_variant_id: UUID, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Delete a purchase. @@ -1831,7 +1834,7 @@ async def delete_purchase( async def get_signatures_by_user_id( user_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a user's signatures. @@ -1858,7 +1861,7 @@ async def get_signatures_by_user_id_by_seller_id( seller_id: UUID, user_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a user's signatures for a single seller. @@ -1885,7 +1888,7 @@ async def create_signature( document_id: UUID, signature: schemas_cdr.SignatureBase, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Create a signature. @@ -1990,7 +1993,7 @@ async def delete_signature( ) async def get_curriculums( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get all curriculums. @@ -2077,7 +2080,7 @@ async def create_curriculum_membership( user_id: str, curriculum_id: UUID, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ws_manager: WebsocketConnectionManager = Depends(get_websocket_connection_manager), ): """ @@ -2147,6 +2150,7 @@ async def create_curriculum_membership( await ws_manager.send_message_to_room( message=schemas_cdr.NewUserWSMessageModel( data=schemas_cdr.CdrUser( + account_type=db_user.account_type, curriculum=schemas_cdr.CurriculumComplete( id=wanted_curriculum.id, name=wanted_curriculum.name, @@ -2178,7 +2182,7 @@ async def update_curriculum_membership( user_id: str, curriculum_id: UUID, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ws_manager: WebsocketConnectionManager = Depends(get_websocket_connection_manager), ): """ @@ -2226,6 +2230,7 @@ async def update_curriculum_membership( await ws_manager.send_message_to_room( message=schemas_cdr.UpdateUserWSMessageModel( data=schemas_cdr.CdrUser( + account_type=db_user.account_type, curriculum=schemas_cdr.CurriculumComplete( id=curriculum.id, name=curriculum.name, @@ -2257,7 +2262,7 @@ async def delete_curriculum_membership( user_id: str, curriculum_id: UUID, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ws_manager: WebsocketConnectionManager = Depends(get_websocket_connection_manager), ): """ @@ -2305,6 +2310,7 @@ async def delete_curriculum_membership( await ws_manager.send_message_to_room( message=schemas_cdr.UpdateUserWSMessageModel( data=schemas_cdr.CdrUser( + account_type=db_user.account_type, curriculum=None, promo=db_user.promo, email=db_user.email, @@ -2333,7 +2339,7 @@ async def delete_curriculum_membership( async def get_payments_by_user_id( user_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a user's payments. @@ -2454,7 +2460,7 @@ async def delete_payment( ) async def get_payment_url( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), settings: Settings = Depends(get_settings), payment_tool: PaymentTool = Depends(get_payment_tool), ): @@ -2477,6 +2483,8 @@ async def get_payment_url( detail="Please give an amount in cents, greater than 1€.", ) user_schema = schemas_core.CoreUser( + account_type=user.account_type, + school_id=user.school_id, email=user.email, birthday=user.birthday, promo=user.promo, @@ -2525,7 +2533,7 @@ async def get_payment_url( async def get_memberships_by_user_id( user_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): if not ( user_id == user.id @@ -2636,7 +2644,7 @@ async def delete_membership( ) async def get_status( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): return await get_core_data(schemas_cdr.Status, db) @@ -2686,7 +2694,7 @@ async def update_status( ) async def get_my_tickets( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): return await cruds_cdr.get_tickets_of_user(db=db, user_id=user.id) @@ -2699,7 +2707,7 @@ async def get_my_tickets( async def get_tickets_of_user( user_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): if not ( is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) @@ -2720,7 +2728,7 @@ async def get_tickets_of_user( async def get_ticket_secret( ticket_id: UUID, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): ticket = await cruds_cdr.get_ticket(db=db, ticket_id=ticket_id) if not ticket: diff --git a/app/modules/centralisation/endpoints_centralisation.py b/app/modules/centralisation/endpoints_centralisation.py index 5fdc0129e..6dd32d93f 100644 --- a/app/modules/centralisation/endpoints_centralisation.py +++ b/app/modules/centralisation/endpoints_centralisation.py @@ -1,8 +1,8 @@ -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType from app.types.module import Module module = Module( root="centralisation", tag="Centralisation", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) diff --git a/app/modules/cinema/endpoints_cinema.py b/app/modules/cinema/endpoints_cinema.py index 07e09b19a..82b3c9d39 100644 --- a/app/modules/cinema/endpoints_cinema.py +++ b/app/modules/cinema/endpoints_cinema.py @@ -9,7 +9,7 @@ from app.core import models_core, standard_responses from app.core.config import Settings -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.notification.notification_types import CustomTopic, Topic from app.core.notification.schemas_notification import Message from app.dependencies import ( @@ -34,7 +34,7 @@ module = Module( root="cinema", tag="Cinema", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) hyperion_error_logger = logging.getLogger("hyperion.error") diff --git a/app/modules/flappybird/endpoints_flappybird.py b/app/modules/flappybird/endpoints_flappybird.py index 0c7d8d8e0..a457bb8d3 100644 --- a/app/modules/flappybird/endpoints_flappybird.py +++ b/app/modules/flappybird/endpoints_flappybird.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.dependencies import get_db, is_user_a_member, is_user_a_member_of from app.modules.flappybird import ( cruds_flappybird, @@ -17,7 +17,7 @@ module = Module( root="flappybird", tag="Flappy Bird", - default_allowed_groups_ids=[GroupType.student], + default_allowed_account_types=[AccountType.student], ) diff --git a/app/modules/home/endpoints_home.py b/app/modules/home/endpoints_home.py index b3c905877..1e717783f 100644 --- a/app/modules/home/endpoints_home.py +++ b/app/modules/home/endpoints_home.py @@ -1,8 +1,8 @@ -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType from app.types.module import Module module = Module( root="home", tag="Home", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) diff --git a/app/modules/loan/endpoints_loan.py b/app/modules/loan/endpoints_loan.py index 2a8940e41..2919ee3f1 100644 --- a/app/modules/loan/endpoints_loan.py +++ b/app/modules/loan/endpoints_loan.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.notification.schemas_notification import Message from app.dependencies import ( get_db, @@ -31,7 +31,7 @@ module = Module( root="loan", tag="Loans", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) diff --git a/app/modules/ph/endpoints_ph.py b/app/modules/ph/endpoints_ph.py index 218f25b8b..0395c02ba 100644 --- a/app/modules/ph/endpoints_ph.py +++ b/app/modules/ph/endpoints_ph.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.notification.notification_types import CustomTopic, Topic from app.core.notification.schemas_notification import Message from app.dependencies import ( @@ -30,7 +30,7 @@ module = Module( root="ph", tag="ph", - default_allowed_groups_ids=[GroupType.student], + default_allowed_account_types=[AccountType.student], ) diff --git a/app/modules/phonebook/endpoints_phonebook.py b/app/modules/phonebook/endpoints_phonebook.py index 82c18f899..4f7e8f0c1 100644 --- a/app/modules/phonebook/endpoints_phonebook.py +++ b/app/modules/phonebook/endpoints_phonebook.py @@ -7,7 +7,7 @@ from app.core import models_core, standard_responses from app.core.groups import cruds_groups -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.users import cruds_users from app.dependencies import ( get_db, @@ -28,7 +28,7 @@ module = Module( root="phonebook", tag="Phonebook", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) hyperion_error_logger = logging.getLogger("hyperion.error") diff --git a/app/modules/raffle/endpoints_raffle.py b/app/modules/raffle/endpoints_raffle.py index d29f4975b..f2a28dc17 100644 --- a/app/modules/raffle/endpoints_raffle.py +++ b/app/modules/raffle/endpoints_raffle.py @@ -8,7 +8,7 @@ from app.core import models_core, standard_responses from app.core.groups import cruds_groups -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.users import cruds_users from app.core.users.endpoints_users import read_user from app.dependencies import ( @@ -33,7 +33,7 @@ module = Module( root="tombola", tag="Raffle", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) hyperion_raffle_logger = logging.getLogger("hyperion.raffle") diff --git a/app/modules/raid/cruds_raid.py b/app/modules/raid/cruds_raid.py index d633452be..95c925958 100644 --- a/app/modules/raid/cruds_raid.py +++ b/app/modules/raid/cruds_raid.py @@ -11,9 +11,9 @@ async def create_participant( - participant: models_raid.Participant, + participant: models_raid.RaidParticipant, db: AsyncSession, -) -> models_raid.Participant: +) -> models_raid.RaidParticipant: db.add(participant) try: await db.commit() @@ -26,10 +26,10 @@ async def create_participant( async def get_all_participants( db: AsyncSession, -) -> Sequence[models_raid.Participant]: +) -> Sequence[models_raid.RaidParticipant]: participants = await db.execute( - select(models_raid.Participant).options( - # Since there is nested classes in the Participant model, we need to load all the related data + select(models_raid.RaidParticipant).options( + # Since there is nested classes in the RaidParticipant model, we need to load all the related data selectinload("*"), ), ) @@ -38,13 +38,13 @@ async def get_all_participants( async def update_participant( participant_id: str, - participant: schemas_raid.ParticipantUpdate, + participant: schemas_raid.RaidParticipantUpdate, is_minor: bool | None, db: AsyncSession, ) -> None: query = ( - update(models_raid.Participant) - .where(models_raid.Participant.id == participant_id) + update(models_raid.RaidParticipant) + .where(models_raid.RaidParticipant.id == participant_id) .values(**participant.model_dump(exclude_none=True)) ) @@ -60,8 +60,8 @@ async def update_participant_minority( db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Participant) - .where(models_raid.Participant.id == participant_id) + update(models_raid.RaidParticipant) + .where(models_raid.RaidParticipant.id == participant_id) .values(is_minor=is_minor), ) await db.commit() @@ -72,7 +72,9 @@ async def is_user_a_participant( db: AsyncSession, ) -> bool: participant = await db.execute( - select(models_raid.Participant).where(models_raid.Participant.id == user_id), + select(models_raid.RaidParticipant).where( + models_raid.RaidParticipant.id == user_id + ), ) return bool(participant.scalars().first()) @@ -80,17 +82,17 @@ async def is_user_a_participant( async def get_team_by_participant_id( participant_id: str, db: AsyncSession, -) -> models_raid.Team | None: +) -> models_raid.RaidTeam | None: team = await db.execute( - select(models_raid.Team) + select(models_raid.RaidTeam) .where( or_( - models_raid.Team.captain_id == participant_id, - models_raid.Team.second_id == participant_id, + models_raid.RaidTeam.captain_id == participant_id, + models_raid.RaidTeam.second_id == participant_id, ), ) .options( - # Since there is nested classes in the Team model, we need to load all the related data + # Since there is nested classes in the RaidTeam model, we need to load all the related data selectinload("*"), ), ) @@ -99,10 +101,10 @@ async def get_team_by_participant_id( async def get_all_teams( db: AsyncSession, -) -> Sequence[models_raid.Team]: +) -> Sequence[models_raid.RaidTeam]: teams = await db.execute( - select(models_raid.Team).options( - # Since there is nested classes in the Team model, we need to load all the related data + select(models_raid.RaidTeam).options( + # Since there is nested classes in the RaidTeam model, we need to load all the related data selectinload("*"), ), ) @@ -111,10 +113,10 @@ async def get_all_teams( async def get_all_validated_teams( db: AsyncSession, -) -> list[models_raid.Team]: +) -> list[models_raid.RaidTeam]: teams = await db.execute( - select(models_raid.Team).options( - # Since there is nested classes in the Team model, we need to load all the related data + select(models_raid.RaidTeam).options( + # Since there is nested classes in the RaidTeam model, we need to load all the related data selectinload("*"), ), ) @@ -127,12 +129,12 @@ async def get_all_validated_teams( async def get_team_by_id( team_id: str, db: AsyncSession, -) -> models_raid.Team | None: +) -> models_raid.RaidTeam | None: team = await db.execute( - select(models_raid.Team) - .where(models_raid.Team.id == team_id) + select(models_raid.RaidTeam) + .where(models_raid.RaidTeam.id == team_id) .options( - # Since there is nested classes in the Team model, we need to load all the related data + # Since there is nested classes in the RaidTeam model, we need to load all the related data selectinload("*"), ), ) @@ -140,7 +142,7 @@ async def get_team_by_id( async def create_team( - team: models_raid.Team, + team: models_raid.RaidTeam, db: AsyncSession, ) -> None: db.add(team) @@ -153,12 +155,12 @@ async def create_team( async def update_team( team_id: str, - team: schemas_raid.TeamUpdate, + team: schemas_raid.RaidTeamUpdate, db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Team) - .where(models_raid.Team.id == team_id) + update(models_raid.RaidTeam) + .where(models_raid.RaidTeam.id == team_id) .values(**team.model_dump(exclude_none=True)), ) await db.commit() @@ -170,8 +172,8 @@ async def update_team_captain_id( db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Team) - .where(models_raid.Team.id == team_id) + update(models_raid.RaidTeam) + .where(models_raid.RaidTeam.id == team_id) .values(captain_id=captain_id), ) await db.commit() @@ -183,8 +185,8 @@ async def update_team_second_id( db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Team) - .where(models_raid.Team.id == team_id) + update(models_raid.RaidTeam) + .where(models_raid.RaidTeam.id == team_id) .values(second_id=second_id), ) await db.commit() @@ -195,8 +197,8 @@ async def delete_participant( db: AsyncSession, ) -> None: await db.execute( - delete(models_raid.Participant).where( - models_raid.Participant.id == participant_id, + delete(models_raid.RaidParticipant).where( + models_raid.RaidParticipant.id == participant_id, ), ) await db.commit() @@ -218,14 +220,16 @@ async def delete_team( team_id: str, db: AsyncSession, ) -> None: - await db.execute(delete(models_raid.Team).where(models_raid.Team.id == team_id)) + await db.execute( + delete(models_raid.RaidTeam).where(models_raid.RaidTeam.id == team_id) + ) await db.commit() async def delete_all_teams( db: AsyncSession, ) -> None: - await db.execute(delete(models_raid.Team)) + await db.execute(delete(models_raid.RaidTeam)) await db.commit() @@ -288,8 +292,8 @@ async def assign_security_file( db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Participant) - .where(models_raid.Participant.id == participant_id) + update(models_raid.RaidParticipant) + .where(models_raid.RaidParticipant.id == participant_id) .values(security_file_id=security_file_id), ) await db.commit() @@ -317,8 +321,8 @@ async def assign_document( db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Participant) - .where(models_raid.Participant.id == participant_id) + update(models_raid.RaidParticipant) + .where(models_raid.RaidParticipant.id == participant_id) .values({document_key: document_id}), ) await db.commit() @@ -350,14 +354,14 @@ async def get_document_by_id( async def get_user_by_document_id( document_id: str, db: AsyncSession, -) -> models_raid.Participant | None: +) -> models_raid.RaidParticipant | None: document = await db.execute( - select(models_raid.Participant).where( + select(models_raid.RaidParticipant).where( or_( - models_raid.Participant.id_card_id == document_id, - models_raid.Participant.medical_certificate_id == document_id, - models_raid.Participant.student_card_id == document_id, - models_raid.Participant.raid_rules_id == document_id, + models_raid.RaidParticipant.id_card_id == document_id, + models_raid.RaidParticipant.medical_certificate_id == document_id, + models_raid.RaidParticipant.student_card_id == document_id, + models_raid.RaidParticipant.raid_rules_id == document_id, ), ), ) @@ -409,8 +413,8 @@ async def confirm_payment( db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Participant) - .where(models_raid.Participant.id == participant_id) + update(models_raid.RaidParticipant) + .where(models_raid.RaidParticipant.id == participant_id) .values(payment=True), ) await db.commit() @@ -421,8 +425,8 @@ async def confirm_t_shirt_payment( db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Participant) - .where(models_raid.Participant.id == participant_id) + update(models_raid.RaidParticipant) + .where(models_raid.RaidParticipant.id == participant_id) .values(t_shirt_payment=True), ) await db.commit() @@ -433,8 +437,8 @@ async def validate_attestation_on_honour( db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Participant) - .where(models_raid.Participant.id == participant_id) + update(models_raid.RaidParticipant) + .where(models_raid.RaidParticipant.id == participant_id) .values(attestation_on_honour=True), ) await db.commit() @@ -443,10 +447,10 @@ async def validate_attestation_on_honour( async def get_participant_by_id( participant_id: str, db: AsyncSession, -) -> models_raid.Participant | None: +) -> models_raid.RaidParticipant | None: participant = await db.execute( - select(models_raid.Participant) - .where(models_raid.Participant.id == participant_id) + select(models_raid.RaidParticipant) + .where(models_raid.RaidParticipant.id == participant_id) .options( selectinload("*"), ), @@ -457,7 +461,7 @@ async def get_participant_by_id( async def get_number_of_teams( db: AsyncSession, ) -> int: - result = await db.execute(select(models_raid.Team)) + result = await db.execute(select(models_raid.RaidTeam)) return len(result.scalars().all()) @@ -538,7 +542,7 @@ async def get_team_if_users_in_the_same_team( participant_id_1: str, participant_id_2: str, db: AsyncSession, -) -> models_raid.Team | None: +) -> models_raid.RaidTeam | None: team_1 = await get_team_by_participant_id(participant_id_1, db) team_2 = await get_team_by_participant_id(participant_id_2, db) if team_1 is None or team_2 is None: @@ -554,8 +558,8 @@ async def update_team_file_id( db: AsyncSession, ) -> None: await db.execute( - update(models_raid.Team) - .where(models_raid.Team.id == team_id) + update(models_raid.RaidTeam) + .where(models_raid.RaidTeam.id == team_id) .values(file_id=file_id), ) await db.commit() @@ -566,15 +570,15 @@ async def get_number_of_team_by_difficulty( db: AsyncSession, ) -> int: result = await db.execute( - select(func.count()).where(models_raid.Team.difficulty == difficulty), + select(func.count()).where(models_raid.RaidTeam.difficulty == difficulty), ) return result.scalar() or 0 async def create_participant_checkout( - checkout: models_raid.ParticipantCheckout, + checkout: models_raid.RaidParticipantCheckout, db: AsyncSession, -) -> models_raid.ParticipantCheckout: +) -> models_raid.RaidParticipantCheckout: db.add(checkout) try: await db.commit() @@ -589,10 +593,10 @@ async def get_participant_checkout_by_checkout_id( # TODO: use UUID checkout_id: str, db: AsyncSession, -) -> models_raid.ParticipantCheckout | None: +) -> models_raid.RaidParticipantCheckout | None: checkout = await db.execute( - select(models_raid.ParticipantCheckout).where( - models_raid.ParticipantCheckout.checkout_id == checkout_id, + select(models_raid.RaidParticipantCheckout).where( + models_raid.RaidParticipantCheckout.checkout_id == checkout_id, ), ) return checkout.scalars().first() diff --git a/app/modules/raid/endpoints_raid.py b/app/modules/raid/endpoints_raid.py index aceacb1dd..711d56392 100644 --- a/app/modules/raid/endpoints_raid.py +++ b/app/modules/raid/endpoints_raid.py @@ -9,7 +9,7 @@ from app.core import models_core, schemas_core from app.core.config import Settings from app.core.google_api.google_api import DriveGoogleAPI -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.payment.payment_tool import PaymentTool from app.dependencies import ( get_db, @@ -48,19 +48,19 @@ root="raid", tag="Raid", payment_callback=validate_payment, - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) @module.router.get( "/raid/participants/{participant_id}", - response_model=schemas_raid.Participant, + response_model=schemas_raid.RaidParticipant, status_code=200, ) async def get_participant_by_id( participant_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a participant by id @@ -80,12 +80,12 @@ async def get_participant_by_id( @module.router.post( "/raid/participants", - response_model=schemas_raid.Participant, + response_model=schemas_raid.RaidParticipant, status_code=201, ) async def create_participant( - participant: schemas_raid.ParticipantBase, - user: models_core.CoreUser = Depends(is_user), + participant: schemas_raid.RaidParticipantBase, + user: models_core.CoreUser = Depends(is_user()), db: AsyncSession = Depends(get_db), ): """ @@ -104,7 +104,7 @@ async def create_participant( raid_start_date=raid_information.raid_start_date, ) - db_participant = models_raid.Participant( + db_participant = models_raid.RaidParticipant( **participant.__dict__, id=user.id, is_minor=is_minor, @@ -118,8 +118,8 @@ async def create_participant( ) async def update_participant( participant_id: str, - participant: schemas_raid.ParticipantUpdate, - user: models_core.CoreUser = Depends(is_user), + participant: schemas_raid.RaidParticipantUpdate, + user: models_core.CoreUser = Depends(is_user()), db: AsyncSession = Depends(get_db), drive_file_manager: DriveFileManager = Depends(get_drive_file_manager), settings: Settings = Depends(get_settings), @@ -225,12 +225,12 @@ async def update_participant( @module.router.post( "/raid/teams", - response_model=schemas_raid.Team, + response_model=schemas_raid.RaidTeam, status_code=201, ) async def create_team( - team: schemas_raid.TeamBase, - user: models_core.CoreUser = Depends(is_user), + team: schemas_raid.RaidTeamBase, + user: models_core.CoreUser = Depends(is_user()), db: AsyncSession = Depends(get_db), drive_file_manager: DriveFileManager = Depends(get_drive_file_manager), settings: Settings = Depends(get_settings), @@ -246,7 +246,7 @@ async def create_team( if await cruds_raid.get_team_by_participant_id(user.id, db): raise HTTPException(status_code=403, detail="You already have a team.") - db_team = models_raid.Team( + db_team = models_raid.RaidTeam( id=str(uuid.uuid4()), name=team.name, number=None, @@ -293,13 +293,13 @@ async def generate_teams_pdf( @module.router.get( "/raid/participants/{participant_id}/team", - response_model=schemas_raid.Team, + response_model=schemas_raid.RaidTeam, status_code=200, ) async def get_team_by_participant_id( participant_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get a team by participant id @@ -320,7 +320,7 @@ async def get_team_by_participant_id( @module.router.get( "/raid/teams", - response_model=list[schemas_raid.TeamPreview], + response_model=list[schemas_raid.RaidTeamPreview], status_code=200, ) async def get_all_teams( @@ -335,7 +335,7 @@ async def get_all_teams( @module.router.get( "/raid/teams/{team_id}", - response_model=schemas_raid.Team, + response_model=schemas_raid.RaidTeam, status_code=200, ) async def get_team_by_id( @@ -355,9 +355,9 @@ async def get_team_by_id( ) async def update_team( team_id: str, - team: schemas_raid.TeamUpdate, + team: schemas_raid.RaidTeamUpdate, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), drive_file_manager: DriveFileManager = Depends(get_drive_file_manager), settings: Settings = Depends(get_settings), ): @@ -366,7 +366,7 @@ async def update_team( """ existing_team = await cruds_raid.get_team_by_participant_id(user.id, db) if existing_team is None: - raise HTTPException(status_code=404, detail="Team not found.") + raise HTTPException(status_code=404, detail="RaidTeam not found.") if existing_team.id != team_id: raise HTTPException(status_code=403, detail="You can only edit your own team.") await cruds_raid.update_team(team_id, team, db) @@ -434,7 +434,7 @@ async def upload_document( document_type: DocumentType, file: UploadFile = File(...), request_id: str = Depends(get_request_id), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), db: AsyncSession = Depends(get_db), ): """ @@ -490,7 +490,7 @@ async def upload_document( async def read_document( document_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Read a document @@ -513,7 +513,7 @@ async def read_document( ) raise HTTPException( status_code=404, - detail="Participant owning the document not found.", + detail="RaidParticipant owning the document not found.", ) if not await cruds_raid.are_user_in_the_same_team( @@ -567,7 +567,7 @@ async def set_security_file( security_file: schemas_raid.SecurityFileBase, participant_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), drive_file_manager: DriveFileManager = Depends(get_drive_file_manager), settings: Settings = Depends(get_settings), ): @@ -698,7 +698,7 @@ async def confirm_t_shirt_payment( async def validate_attestation_on_honour( participant_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), drive_file_manager: DriveFileManager = Depends(get_drive_file_manager), settings: Settings = Depends(get_settings), ): @@ -725,7 +725,7 @@ async def validate_attestation_on_honour( async def create_invite_token( team_id: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Create an invite token @@ -733,7 +733,7 @@ async def create_invite_token( team = await cruds_raid.get_team_by_participant_id(user.id, db) if not team: - raise HTTPException(status_code=404, detail="Team not found.") + raise HTTPException(status_code=404, detail="RaidTeam not found.") if team.id != team_id: raise HTTPException(status_code=403, detail="You are not in the team.") @@ -759,7 +759,7 @@ async def create_invite_token( async def join_team( token: str, db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), drive_file_manager: DriveFileManager = Depends(get_drive_file_manager), settings: Settings = Depends(get_settings), ): @@ -787,10 +787,10 @@ async def join_team( team = await cruds_raid.get_team_by_id(invite_token.team_id, db) if not team: - raise HTTPException(status_code=404, detail="Team not found.") + raise HTTPException(status_code=404, detail="RaidTeam not found.") if team.second_id: - raise HTTPException(status_code=403, detail="Team is already full.") + raise HTTPException(status_code=403, detail="RaidTeam is already full.") if team.captain_id == user.id: raise HTTPException( @@ -810,7 +810,7 @@ async def join_team( @module.router.post( "/raid/teams/{team_id}/kick/{participant_id}", - response_model=schemas_raid.Team, + response_model=schemas_raid.RaidTeam, status_code=201, ) async def kick_team_member( @@ -826,7 +826,7 @@ async def kick_team_member( """ team = await cruds_raid.get_team_by_id(team_id, db) if not team: - raise HTTPException(status_code=404, detail="Team not found.") + raise HTTPException(status_code=404, detail="RaidTeam not found.") if team.captain_id == participant_id: if not team.second_id: raise HTTPException( @@ -839,7 +839,7 @@ async def kick_team_member( db, ) elif team.second_id != participant_id: - raise HTTPException(status_code=404, detail="Participant not found.") + raise HTTPException(status_code=404, detail="RaidParticipant not found.") await cruds_raid.update_team_second_id(team_id, None, db) await post_update_actions( team, @@ -852,7 +852,7 @@ async def kick_team_member( @module.router.post( "/raid/teams/merge", - response_model=schemas_raid.Team, + response_model=schemas_raid.RaidTeam, status_code=201, ) async def merge_teams( @@ -869,11 +869,11 @@ async def merge_teams( team1 = await cruds_raid.get_team_by_id(team1_id, db) team2 = await cruds_raid.get_team_by_id(team2_id, db) if not team1 or not team2: - raise HTTPException(status_code=404, detail="Team not found.") + raise HTTPException(status_code=404, detail="RaidTeam not found.") if team1.second_id or team2.second_id: raise HTTPException(status_code=403, detail="One of the team is full.") if team1.captain_id == team2.captain_id: - raise HTTPException(status_code=403, detail="Teams are the same.") + raise HTTPException(status_code=403, detail="RaidTeams are the same.") new_name = f"{team1.name} & {team2.name}" new_difficulty = team1.difficulty if team1.difficulty == team2.difficulty else None new_meeting_place = ( @@ -882,7 +882,7 @@ async def merge_teams( new_number = ( min(team1.number, team2.number) if team1.number and team2.number else None ) - team_update: schemas_raid.TeamUpdate = schemas_raid.TeamUpdate( + team_update: schemas_raid.RaidTeamUpdate = schemas_raid.RaidTeamUpdate( name=new_name, difficulty=new_difficulty, meeting_place=new_meeting_place, @@ -914,7 +914,7 @@ async def merge_teams( ) async def get_raid_information( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get raid information @@ -1031,7 +1031,7 @@ async def get_drive_folders( ) async def get_raid_price( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), ): """ Get raid price @@ -1061,7 +1061,7 @@ async def update_raid_price( ) async def get_payment_url( db: AsyncSession = Depends(get_db), - user: models_core.CoreUser = Depends(is_user), + user: models_core.CoreUser = Depends(is_user()), settings: Settings = Depends(get_settings), payment_tool: PaymentTool = Depends(get_payment_tool), ): @@ -1087,18 +1087,20 @@ async def get_payment_url( if not participant.payment: checkout_name += " + " checkout_name += "T Shirt taille" + participant.t_shirt_size.value + user_dict = user.__dict__ + user_dict.pop("school", None) checkout = await payment_tool.init_checkout( module=module.root, helloasso_slug="AEECL", checkout_amount=price, checkout_name=checkout_name, redirection_uri=settings.RAID_PAYMENT_REDIRECTION_URL or "", - payer_user=schemas_core.CoreUser(**user.__dict__), + payer_user=schemas_core.CoreUser(**user_dict), db=db, ) hyperion_error_logger.info(f"RAID: Logging Checkout id {checkout.id}") await cruds_raid.create_participant_checkout( - models_raid.ParticipantCheckout( + models_raid.RaidParticipantCheckout( id=str(uuid.uuid4()), participant_id=user.id, # TODO: use UUID diff --git a/app/modules/raid/models_raid.py b/app/modules/raid/models_raid.py index dc2dcda71..33e4bc249 100644 --- a/app/modules/raid/models_raid.py +++ b/app/modules/raid/models_raid.py @@ -45,7 +45,7 @@ class SecurityFile(Base): surgical_operation: Mapped[str | None] trauma: Mapped[str | None] family: Mapped[str | None] - participant: Mapped["Participant"] = relationship( + participant: Mapped["RaidParticipant"] = relationship( back_populates="security_file", init=False, ) @@ -70,7 +70,7 @@ def validation(self) -> DocumentValidation: return DocumentValidation.temporary -class Participant(Base): +class RaidParticipant(Base): __tablename__ = "raid_participant" id: Mapped[str] = mapped_column( primary_key=True, @@ -245,7 +245,7 @@ def validation_progress(self) -> float: return (number_validated / number_total) * 100 -class Team(Base): +class RaidTeam(Base): __tablename__ = "raid_team" id: Mapped[str] = mapped_column( primary_key=True, @@ -261,13 +261,13 @@ class Team(Base): default=None, ) number: Mapped[int | None] = mapped_column(default=None) - captain: Mapped[Participant] = relationship( - "Participant", + captain: Mapped[RaidParticipant] = relationship( + "RaidParticipant", foreign_keys=[captain_id], init=False, ) - second: Mapped[Participant] = relationship( - "Participant", + second: Mapped[RaidParticipant] = relationship( + "RaidParticipant", foreign_keys=[second_id], init=False, ) @@ -298,7 +298,7 @@ class InviteToken(Base): token: Mapped[str] -class ParticipantCheckout(Base): +class RaidParticipantCheckout(Base): __tablename__ = "raid_participant_checkout" id: Mapped[str] = mapped_column( primary_key=True, diff --git a/app/modules/raid/schemas_raid.py b/app/modules/raid/schemas_raid.py index ded45db23..043b52d1e 100644 --- a/app/modules/raid/schemas_raid.py +++ b/app/modules/raid/schemas_raid.py @@ -53,7 +53,7 @@ class SecurityFile(SecurityFileBase): id: str -class ParticipantBase(BaseModel): +class RaidParticipantBase(BaseModel): name: str firstname: str birthday: date @@ -61,7 +61,7 @@ class ParticipantBase(BaseModel): email: str -class ParticipantPreview(ParticipantBase): +class RaidParticipantPreview(RaidParticipantBase): id: str bike_size: Size | None t_shirt_size: Size | None @@ -73,7 +73,7 @@ class ParticipantPreview(ParticipantBase): number_of_validated_document: int -class Participant(ParticipantPreview): +class RaidParticipant(RaidParticipantPreview): address: str | None other_school: str | None = None company: str | None = None @@ -88,7 +88,7 @@ class Participant(ParticipantPreview): is_minor: bool -class ParticipantUpdate(BaseModel): +class RaidParticipantUpdate(BaseModel): name: str | None = None firstname: str | None = None birthday: date | None = None @@ -110,32 +110,32 @@ class ParticipantUpdate(BaseModel): parent_authorization_id: str | None = None -class TeamBase(BaseModel): +class RaidTeamBase(BaseModel): name: str -class TeamPreview(TeamBase): +class RaidTeamPreview(RaidTeamBase): id: str number: int | None - captain: ParticipantPreview - second: ParticipantPreview | None + captain: RaidParticipantPreview + second: RaidParticipantPreview | None difficulty: Difficulty | None meeting_place: MeetingPlace | None validation_progress: float -class Team(TeamBase): +class RaidTeam(RaidTeamBase): id: str number: int | None - captain: Participant - second: Participant | None + captain: RaidParticipant + second: RaidParticipant | None difficulty: Difficulty | None meeting_place: MeetingPlace | None validation_progress: float file_id: str | None -class TeamUpdate(BaseModel): +class RaidTeamUpdate(BaseModel): name: str | None = None number: int | None = None difficulty: Difficulty | None = None @@ -161,6 +161,6 @@ class PaymentUrl(BaseModel): url: str -class ParticipantCheckout(BaseModel): +class RaidParticipantCheckout(BaseModel): participant_id: str checkout_id: str diff --git a/app/modules/raid/utils/pdf/pdf_writer.py b/app/modules/raid/utils/pdf/pdf_writer.py index d4659e896..c99bb1372 100644 --- a/app/modules/raid/utils/pdf/pdf_writer.py +++ b/app/modules/raid/utils/pdf/pdf_writer.py @@ -14,7 +14,12 @@ from pypdf import PdfReader, PdfWriter from app.modules.raid import coredata_raid -from app.modules.raid.models_raid import Document, Participant, SecurityFile, Team +from app.modules.raid.models_raid import ( + Document, + RaidParticipant, + RaidTeam, + SecurityFile, +) from app.modules.raid.utils.pdf.conversion_utils import ( date_to_string, get_difficulty_label, @@ -92,7 +97,7 @@ def footer(self): new_y=YPos.BMARGIN, ) - def write_team(self, team: Team) -> str: + def write_team(self, team: RaidTeam) -> str: self.pdf_indexes: list[int] = [] self.pdf_paths: list[str] = [] self.pdf_pages = [] @@ -141,7 +146,7 @@ def write_empty_participant(self): new_y=YPos.NEXT, ) - def write_participant_document(self, participant: Participant): + def write_participant_document(self, participant: RaidParticipant): for document in [ participant.id_card, participant.medical_certificate, @@ -165,7 +170,7 @@ def write_participant_document(self, participant: Participant): self.add_page() self.pdf_paths.append(document.id) - def write_document_header(self, document: Document, participant: Participant): + def write_document_header(self, document: Document, participant: RaidParticipant): self.add_page() self.set_y(self.get_y() + 6) self.set_font("times", "B", 12) @@ -204,7 +209,7 @@ def write_document_header(self, document: Document, participant: Participant): def write_security_file( self, security_file: SecurityFile, - participant: Participant, + participant: RaidParticipant, ): self.add_page() self.set_y(self.get_y() + 6) @@ -220,11 +225,11 @@ def write_security_file( self.set_font("times", "", 12) data: list[list[str] | None] = [ ["Allergie", security_file.allergy if security_file.allergy else "Aucune"], - ["Asthme", security_file.asthma and "Oui" or "Non"], + ["Asthme", (security_file.asthma and "Oui") or "Non"], ( [ "Service de réanimation", - security_file.intensive_care_unit and "Oui" or "Non", + (security_file.intensive_care_unit and "Oui") or "Non", ] if security_file.allergy else None @@ -312,7 +317,7 @@ def write_security_file( if data_row: self.write_key_label(data_row[0], data_row[1]) - def write_document(self, document: Document, participant: Participant): + def write_document(self, document: Document, participant: RaidParticipant): self.write_document_header(document, participant) self.set_y(self.get_y() + 6) file = get_file_path_from_data("raid", document.id, "documents") @@ -321,7 +326,7 @@ def write_document(self, document: Document, participant: Participant): x = ((self.epw * 2.85 - image_width) / 2) / 2.85 + 10 self.image(image, x=x) - def write_team_summary(self, team: Team): + def write_team_summary(self, team: RaidTeam): self.set_font("times", "", 12) data = [ ["Parcours", "Lieu de rendez-vous", "Numéro", "Inscription"], @@ -350,7 +355,7 @@ def write_team_summary(self, team: Team): def write_participant_summary( self, - participant: Participant, + participant: RaidParticipant, is_second: bool = False, ): self.set_font("times", "B", 12) @@ -358,7 +363,7 @@ def write_participant_summary( self.cell( w=0, h=12, - text=is_second and "Coéquipier" or "Capitaine", + text=(is_second and "Coéquipier") or "Capitaine", align="C", new_x=XPos.LMARGIN, new_y=YPos.NEXT, @@ -423,7 +428,7 @@ def __init__(self): def write_participant_security_file( self, - participant: Participant, + participant: RaidParticipant, information: coredata_raid.RaidInformation, team_number: int | None, ): diff --git a/app/modules/raid/utils/utils_raid.py b/app/modules/raid/utils/utils_raid.py index 062e1491a..c2d533b0b 100644 --- a/app/modules/raid/utils/utils_raid.py +++ b/app/modules/raid/utils/utils_raid.py @@ -12,8 +12,8 @@ from app.core.payment import schemas_payment from app.modules.raid import coredata_raid, cruds_raid, models_raid, schemas_raid from app.modules.raid.schemas_raid import ( - ParticipantBase, - ParticipantUpdate, + RaidParticipantBase, + RaidParticipantUpdate, ) from app.modules.raid.utils.drive.drive_file_manager import DriveFileManager from app.modules.raid.utils.pdf.pdf_writer import HTMLPDFWriter, PDFWriter @@ -30,7 +30,9 @@ def __init__(self, checkout_id): def will_participant_be_minor_on( - participant: ParticipantUpdate | models_raid.Participant | ParticipantBase, + participant: RaidParticipantUpdate + | models_raid.RaidParticipant + | RaidParticipantBase, raid_start_date: date | None, ) -> bool: """ @@ -92,14 +94,16 @@ async def validate_payment( async def write_teams_csv( - teams: Sequence[models_raid.Team], + teams: Sequence[models_raid.RaidTeam], db: AsyncSession, drive_file_manager: DriveFileManager, settings: Settings, ) -> None: file_name = "Équipes - " + datetime.now(UTC).strftime("%Y-%m-%d_%H_%M_%S") + ".csv" file_path = "data/raid/" + file_name - data: list[list[str]] = [["Team name", "Captain", "Second", "Difficulty", "Number"]] + data: list[list[str]] = [ + ["RaidTeam name", "Captain", "Second", "Difficulty", "Number"], + ] data.extend( [ [ @@ -132,21 +136,21 @@ async def write_teams_csv( Path(file_path).unlink() -async def set_team_number(team: models_raid.Team, db: AsyncSession) -> None: +async def set_team_number(team: models_raid.RaidTeam, db: AsyncSession) -> None: if team.difficulty is None: return new_team_number = await cruds_raid.get_number_of_team_by_difficulty( team.difficulty, db, ) - updated_team: schemas_raid.TeamUpdate = schemas_raid.TeamUpdate( + updated_team: schemas_raid.RaidTeamUpdate = schemas_raid.RaidTeamUpdate( number=new_team_number, ) await cruds_raid.update_team(team.id, updated_team, db) async def save_team_info( - team: models_raid.Team, + team: models_raid.RaidTeam, db: AsyncSession, drive_file_manager: DriveFileManager, settings: Settings, @@ -184,7 +188,7 @@ async def save_team_info( async def post_update_actions( - team: models_raid.Team | None, + team: models_raid.RaidTeam | None, db: AsyncSession, drive_file_manager: DriveFileManager, settings: Settings, @@ -213,7 +217,7 @@ async def post_update_actions( async def save_security_file( - participant: models_raid.Participant, + participant: models_raid.RaidParticipant, information: coredata_raid.RaidInformation, team_number: int | None, db: AsyncSession, @@ -264,8 +268,8 @@ async def save_security_file( async def get_participant( participant_id: str, db: AsyncSession, -) -> models_raid.Participant: +) -> models_raid.RaidParticipant: participant = await cruds_raid.get_participant_by_id(participant_id, db) if not participant: - raise HTTPException(status_code=404, detail="Participant not found.") + raise HTTPException(status_code=404, detail="RaidParticipant not found.") return participant diff --git a/app/modules/recommendation/endpoints_recommendation.py b/app/modules/recommendation/endpoints_recommendation.py index 79eec6137..44be37e7a 100644 --- a/app/modules/recommendation/endpoints_recommendation.py +++ b/app/modules/recommendation/endpoints_recommendation.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core, standard_responses -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.dependencies import ( get_db, get_request_id, @@ -28,7 +28,7 @@ module = Module( root="recommendation", tag="Recommendation", - default_allowed_groups_ids=[GroupType.student, GroupType.staff], + default_allowed_account_types=[AccountType.student, AccountType.staff], ) diff --git a/app/modules/sport_competition/__init__.py b/app/modules/sport_competition/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/modules/sport_competition/cruds_sport_competition.py b/app/modules/sport_competition/cruds_sport_competition.py new file mode 100644 index 000000000..1dbda5d24 --- /dev/null +++ b/app/modules/sport_competition/cruds_sport_competition.py @@ -0,0 +1,1182 @@ +import logging + +from sqlalchemy import and_, delete, select, update +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.core import models_core, schemas_core +from app.modules.sport_competition import models_sport_competition as models_competition +from app.modules.sport_competition import ( + schemas_sport_competition as schemas_competition, +) +from app.modules.sport_competition.types_sport_competition import ( + MultipleEditions, +) + +logger = logging.getLogger("challenger") + + +async def store_group( + group: schemas_competition.Group, + db: AsyncSession, +): + stored_group = await load_group_by_id(group.id, "", db) + if stored_group is None: + db.add(models_competition.CompetitionGroup(**group.model_dump())) + else: + await db.execute( + update(models_competition.CompetitionGroup) + .where(models_competition.CompetitionGroup.id == group.id) + .values(**group.model_dump()), + ) + try: + await db.commit() + except IntegrityError as error: + logger.exception("Could not store group") + await db.rollback() + raise error from None + + +async def delete_group_by_id( + group_id: str, + db: AsyncSession, +): + await db.execute( + delete(models_competition.CompetitionGroup).where( + models_competition.CompetitionGroup.id == group_id, + ), + ) + await db.commit() + + +async def load_group_by_id( + group_id: str, + edition_id: str, + db: AsyncSession, +) -> schemas_competition.GroupComplete | None: + group = ( + ( + await db.execute( + select(models_competition.CompetitionGroup) + .where( + models_competition.CompetitionGroup.id == group_id, + ) + .options( + selectinload(models_competition.CompetitionGroup.members), + ) + .filter( + models_competition.AnnualGroupMembership.edition_id == edition_id, + ), + ) + ) + .scalars() + .first() + ) + return ( + schemas_competition.GroupComplete( + id=group.id, + name=group.name, + description=group.description, + members=[ + schemas_core.CoreUser(**member.__dict__) for member in group.members + ], + ) + if group + else None + ) + + +async def load_group_by_name( + name: str, + db: AsyncSession, +) -> schemas_competition.GroupComplete | None: + group = await db.get(models_competition.CompetitionGroup, name) + if group is None: + return None + + return ( + schemas_competition.GroupComplete( + id=group.id, + name=group.name, + description=group.description, + members=[], + ) + if group + else None + ) + + +async def load_all_groups( + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.Group]: + groups = await db.execute(select(models_competition.CompetitionGroup)) + return [ + schemas_competition.Group(**group.__dict__) for group in groups.scalars().all() + ] + + +async def load_user_membership_with_group_id( + user_id: str, + group_id: str, + edition_id: str, + db: AsyncSession, +) -> schemas_competition.UserGroupMembership | None: + membership = await db.get( + models_competition.AnnualGroupMembership, + (user_id, group_id, edition_id), + ) + return ( + schemas_competition.UserGroupMembership( + user_id=membership.user_id, + group_id=membership.group_id, + edition_id=membership.edition_id, + ) + if membership + else None + ) + + +async def load_active_user_memberships( + user_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.UserGroupMembership]: + memberships = ( + ( + await db.execute( + select(models_competition.AnnualGroupMembership).where( + models_competition.AnnualGroupMembership.user_id == user_id, + models_competition.AnnualGroupMembership.edition_id == edition_id, + ), + ) + ) + .scalars() + .all() + ) + return [ + schemas_competition.UserGroupMembership( + user_id=membership.user_id, + group_id=membership.group_id, + edition_id=membership.edition_id, + ) + for membership in memberships + ] + + +async def add_user_to_group( + user_id: str, + group_id: str, + edition_id: str, + db: AsyncSession, +) -> None: + db.add( + models_competition.AnnualGroupMembership( + user_id=user_id, + group_id=group_id, + edition_id=edition_id, + ), + ) + try: + await db.commit() + except IntegrityError as error: + await db.rollback() + raise error from None + + +async def remove_user_from_group( + user_id: str, + group_id: str, + edition_id: str, + db: AsyncSession, +) -> None: + await db.delete( + models_competition.AnnualGroupMembership( + user_id=user_id, + group_id=group_id, + edition_id=edition_id, + ), + ) + try: + await db.commit() + except IntegrityError as error: + await db.rollback() + raise error from None + + +async def store_school( + school: schemas_competition.SchoolExtension, + db: AsyncSession, +): + stored_school = await load_school_by_id(school.id, "", db) + if stored_school is None: + db.add(models_competition.SchoolExtension(**school.model_dump())) + else: + await db.execute( + update(models_competition.SchoolExtension) + .where(models_competition.SchoolExtension.school_id == school.id) + .values(**school.model_dump()), + ) + try: + await db.commit() + except IntegrityError as error: + await db.rollback() + raise error from None + + +async def delete_school_by_id( + school_id: str, + db: AsyncSession, +): + await db.execute( + delete(models_competition.SchoolExtension).where( + models_competition.SchoolExtension.school_id == school_id, + ), + ) + await db.commit() + + +async def load_school_by_id( + school_id: str, + edition_id: str, + db: AsyncSession, +) -> schemas_competition.SchoolExtension | None: + school_extension = ( + ( + await db.execute( + select(models_competition.SchoolExtension) + .where( + models_competition.SchoolExtension.school_id == school_id, + ) + .options( + selectinload(models_competition.SchoolExtension.school), + selectinload(models_competition.SchoolExtension.general_quota), + ) + .filter( + models_competition.SchoolGeneralQuota.edition_id == edition_id, + ), + ) + ) + .scalars() + .first() + ) + if school_extension is None: + return None + school_dict = school_extension.__dict__ + school = schemas_core.CoreSchool(**school_dict["school"].__dict__) + general_quota = schemas_competition.SchoolGeneralQuota( + **school_dict["general_quota"].__dict__, + ) + return schemas_competition.SchoolExtension( + id=school_id, + from_lyon=school_dict["from_lyon"], + activated=school_dict["activated"], + school=school, + general_quota=general_quota, + ) + + +async def load_school_by_name( + name: str, + edition_id: str, + db: AsyncSession, +) -> schemas_competition.SchoolExtension | None: + school_extension = ( + ( + await db.execute( + select(models_competition.SchoolExtension) + .where( + models_competition.SchoolExtension.school_id == name, + ) + .options( + selectinload(models_competition.SchoolExtension.school), + selectinload(models_competition.SchoolExtension.general_quota), + ) + .filter( + models_competition.SchoolGeneralQuota.edition_id == edition_id, + ), + ) + ) + .scalars() + .first() + ) + return ( + schemas_competition.SchoolExtension( + id=name, + from_lyon=school_extension.from_lyon, + activated=school_extension.activated, + school=schemas_core.CoreSchool(**school_extension.school.__dict__), + general_quota=schemas_competition.SchoolGeneralQuota( + **school_extension.general_quota.__dict__, + ), + ) + if school_extension + else None + ) + + +async def load_all_schools( + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.SchoolExtension]: + school_extensions = await db.execute( + select(models_competition.SchoolExtension) + .options( + selectinload(models_competition.SchoolExtension.school), + selectinload(models_competition.SchoolExtension.general_quota), + ) + .filter( + models_competition.SchoolGeneralQuota.edition_id == edition_id, + ), + ) + + return [ + schemas_competition.SchoolExtension( + id=school_extension.school_id, + from_lyon=school_extension.from_lyon, + activated=school_extension.activated, + school=schemas_core.CoreSchool(**school_extension.school.__dict__), + general_quota=schemas_competition.SchoolGeneralQuota( + **school_extension.general_quota.__dict__, + ), + ) + for school_extension in school_extensions.scalars().all() + ] + + +async def store_participant( + participant: schemas_competition.Participant, + db: AsyncSession, +): + stored_participant = await load_participant_by_ids( + participant.user_id, + participant.sport_id, + participant.edition_id, + db, + ) + if stored_participant is None: + db.add(models_competition.Participant(**participant.model_dump())) + else: + await db.execute( + update(models_competition.Participant) + .where( + models_competition.Participant.user_id == participant.user_id, + models_competition.Participant.sport_id == participant.sport_id, + ) + .values(**participant.model_dump()), + ) + try: + await db.commit() + except IntegrityError as error: + await db.rollback() + raise error from None + + +async def delete_participant_by_ids( + user_id: str, + sport_id: str, + edition_id: str, + db: AsyncSession, +): + await db.execute( + delete(models_competition.Participant).where( + models_competition.Participant.user_id == user_id, + models_competition.Participant.sport_id == sport_id, + models_competition.Participant.edition_id == edition_id, + ), + ) + await db.commit() + + +async def load_participant_by_ids( + user_id: str, + sport_id: str, + edition_id: str, + db: AsyncSession, +) -> schemas_competition.ParticipantComplete | None: + participant = ( + ( + await db.execute( + select(models_competition.Participant) + .where( + models_competition.Participant.user_id == user_id, + models_competition.Participant.sport_id == sport_id, + models_competition.Participant.edition_id == edition_id, + ) + .options( + selectinload(models_competition.Participant.user), + selectinload(models_competition.Participant.sport), + selectinload(models_competition.Participant.team), + ), + ) + ) + .scalars() + .first() + ) + return ( + schemas_competition.ParticipantComplete( + user_id=participant.user_id, + sport_id=participant.sport_id, + edition_id=participant.edition_id, + team_id=participant.team_id, + captain=participant.captain, + substitute=participant.substitute, + license=participant.license, + user=schemas_core.CoreUser(**participant.user.__dict__), + sport=schemas_competition.Sport(**participant.sport.__dict__), + team=schemas_competition.Team(**participant.team.__dict__), + ) + if participant + else None + ) + + +async def load_all_participants( + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.ParticipantComplete]: + participants = await db.execute( + select(models_competition.Participant) + .where( + models_competition.Participant.edition_id == edition_id, + ) + .options( + selectinload(models_competition.Participant.user), + selectinload(models_competition.Participant.sport), + selectinload(models_competition.Participant.team), + ), + ) + return [ + schemas_competition.ParticipantComplete( + user_id=participant.user_id, + sport_id=participant.sport_id, + edition_id=participant.edition_id, + team_id=participant.team_id, + captain=participant.captain, + substitute=participant.substitute, + license=participant.license, + user=schemas_core.CoreUser(**participant.user.__dict__), + sport=schemas_competition.Sport(**participant.sport.__dict__), + team=schemas_competition.Team(**participant.team.__dict__), + ) + for participant in participants.scalars().all() + ] + + +async def load_participants_by_school_id( + school_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.ParticipantComplete]: + participants = ( + ( + await db.execute( + select(models_competition.Participant) + .where( + models_competition.Participant.edition_id == edition_id, + models_competition.Participant.user_id.in_( + select(models_core.CoreUser.id).where( + models_core.CoreUser.school_id == school_id, + ), + ), + ) + .options( + selectinload(models_competition.Participant.user), + selectinload(models_competition.Participant.sport), + selectinload(models_competition.Participant.team), + ), + ) + ) + .scalars() + .all() + ) + return [ + schemas_competition.ParticipantComplete( + user_id=participant.user_id, + sport_id=participant.sport_id, + edition_id=participant.edition_id, + team_id=participant.team_id, + captain=participant.captain, + substitute=participant.substitute, + license=participant.license, + user=schemas_core.CoreUser(**participant.user.__dict__), + sport=schemas_competition.Sport(**participant.sport.__dict__), + team=schemas_competition.Team(**participant.team.__dict__), + ) + for participant in participants + ] + + +async def load_participants_by_sport_id( + sport_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.ParticipantComplete]: + participants = ( + ( + await db.execute( + select(models_competition.Participant) + .where( + models_competition.Participant.sport_id == sport_id, + models_competition.Participant.edition_id == edition_id, + ) + .options( + selectinload(models_competition.Participant.user), + selectinload(models_competition.Participant.sport), + selectinload(models_competition.Participant.team), + ), + ) + ) + .scalars() + .all() + ) + return [ + schemas_competition.ParticipantComplete( + user_id=participant.user_id, + sport_id=participant.sport_id, + edition_id=participant.edition_id, + team_id=participant.team_id, + captain=participant.captain, + substitute=participant.substitute, + license=participant.license, + user=schemas_core.CoreUser(**participant.user.__dict__), + sport=schemas_competition.Sport(**participant.sport.__dict__), + team=schemas_competition.Team(**participant.team.__dict__), + ) + for participant in participants + ] + + +async def load_participants_by_school_and_sport_ids( + school_id: str, + sport_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.ParticipantComplete]: + participants = ( + ( + await db.execute( + select(models_competition.Participant) + .where( + models_competition.Participant.sport_id == sport_id, + models_competition.Participant.edition_id == edition_id, + models_competition.Participant.user_id.in_( + select(models_core.CoreUser.id).where( + models_core.CoreUser.school_id == school_id, + ), + ), + ) + .options( + selectinload(models_competition.Participant.user), + selectinload(models_competition.Participant.sport), + selectinload(models_competition.Participant.team), + ), + ) + ) + .scalars() + .all() + ) + return [ + schemas_competition.ParticipantComplete( + user_id=participant.user_id, + sport_id=participant.sport_id, + edition_id=participant.edition_id, + team_id=participant.team_id, + captain=participant.captain, + substitute=participant.substitute, + license=participant.license, + user=schemas_core.CoreUser(**participant.user.__dict__), + sport=schemas_competition.Sport(**participant.sport.__dict__), + team=schemas_competition.Team(**participant.team.__dict__), + ) + for participant in participants + ] + + +async def store_quota( + quota: schemas_competition.Quota, + db: AsyncSession, +): + stored_quota = await load_quota_by_ids( + quota.school_id, + quota.sport_id, + quota.edition_id, + db, + ) + if stored_quota is None: + db.add(models_competition.SchoolSportQuota(**quota.model_dump())) + else: + await db.execute( + update(models_competition.SchoolSportQuota) + .where( + models_competition.SchoolSportQuota.school_id == quota.school_id, + models_competition.SchoolSportQuota.sport_id == quota.sport_id, + ) + .values(**quota.model_dump()), + ) + try: + await db.commit() + except IntegrityError as error: + await db.rollback() + raise error from None + + +async def delete_quota_by_ids( + school_id: str, + sport_id: str, + edition_id: str, + db: AsyncSession, +): + await db.execute( + delete(models_competition.SchoolSportQuota).where( + models_competition.SchoolSportQuota.school_id == school_id, + models_competition.SchoolSportQuota.sport_id == sport_id, + models_competition.SchoolSportQuota.edition_id == edition_id, + ), + ) + await db.commit() + + +async def load_quota_by_ids( + school_id: str, + sport_id: str, + edition_id: str, + db: AsyncSession, +) -> schemas_competition.Quota | None: + quota = await db.get( + models_competition.SchoolSportQuota, + (school_id, sport_id, edition_id), + ) + return schemas_competition.Quota(**quota.__dict__) if quota else None + + +async def load_all_quotas_by_school_id( + school_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.Quota]: + quotas = await db.execute( + select(models_competition.SchoolSportQuota).where( + models_competition.SchoolSportQuota.school_id == school_id, + models_competition.SchoolSportQuota.edition_id == edition_id, + ), + ) + return [ + schemas_competition.Quota(**quota.__dict__) for quota in quotas.scalars().all() + ] + + +async def load_all_quotas_by_sport_id( + sport_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.Quota]: + quotas = await db.execute( + select(models_competition.SchoolSportQuota).where( + models_competition.SchoolSportQuota.sport_id == sport_id, + models_competition.SchoolSportQuota.edition_id == edition_id, + ), + ) + return [ + schemas_competition.Quota(**quota.__dict__) for quota in quotas.scalars().all() + ] + + +async def store_sport( + sport: schemas_competition.Sport, + db: AsyncSession, +): + stored_sport = await load_sport_by_id(sport.id, db) + if stored_sport is None: + db.add(models_competition.Sport(**sport.model_dump())) + else: + await db.execute( + update(models_competition.Sport) + .where(models_competition.Sport.id == sport.id) + .values(**sport.model_dump()), + ) + try: + await db.commit() + except IntegrityError as error: + await db.rollback() + raise error from None + + +async def delete_sport_by_id( + sport_id: str, + db: AsyncSession, +): + await db.execute( + delete(models_competition.Sport).where(models_competition.Sport.id == sport_id), + ) + await db.commit() + + +async def load_sport_by_id( + sport_id: str, + db: AsyncSession, +) -> schemas_competition.Sport | None: + sport = await db.get(models_competition.Sport, sport_id) + return schemas_competition.Sport(**sport.__dict__) if sport else None + + +async def load_sport_by_name( + name: str, + db: AsyncSession, +) -> schemas_competition.Sport | None: + sport = await db.get(models_competition.Sport, name) + return schemas_competition.Sport(**sport.__dict__) if sport else None + + +async def load_all_sports( + db: AsyncSession, +) -> list[schemas_competition.Sport]: + sports = await db.execute(select(models_competition.Sport)) + return [ + schemas_competition.Sport(**sport.__dict__) for sport in sports.scalars().all() + ] + + +async def store_team( + team: schemas_competition.Team, + db: AsyncSession, +): + stored_team = await load_team_by_id(team.id, db) + if stored_team is None: + db.add(models_competition.Team(**team.model_dump())) + else: + await db.execute( + update(models_competition.Team) + .where(models_competition.Team.id == team.id) + .values(**team.model_dump()), + ) + try: + await db.commit() + except IntegrityError as error: + await db.rollback() + raise error from None + + +async def delete_team_by_id( + team_id: str, + db: AsyncSession, +): + await db.execute( + delete(models_competition.Team).where(models_competition.Team.id == team_id), + ) + await db.commit() + + +async def load_team_by_id( + team_id, + db: AsyncSession, +) -> schemas_competition.TeamComplete | None: + team = ( + ( + await db.execute( + select(models_competition.Team) + .where(models_competition.Team.id == team_id) + .options( + selectinload(models_competition.Team.captain), + selectinload(models_competition.Team.users), + ), + ) + ) + .scalars() + .first() + ) + return ( + schemas_competition.TeamComplete( + name=team.name, + school_id=team.school_id, + sport_id=team.sport_id, + edition_id=team.edition_id, + captain_id=team.captain_id, + id=team.id, + captain=schemas_core.CoreUser(**team.captain.__dict__), + users=[schemas_core.CoreUser(**user.__dict__) for user in team.users], + ) + if team + else None + ) + + +async def load_team_by_name( + name: str, + edition_id: str, + db: AsyncSession, +) -> schemas_competition.TeamComplete | None: + team = ( + ( + await db.execute( + select(models_competition.Team) + .where( + models_competition.Team.name == name, + models_competition.Team.edition_id == edition_id, + ) + .options( + selectinload(models_competition.Team.captain), + selectinload(models_competition.Team.users), + ), + ) + ) + .scalars() + .first() + ) + return ( + schemas_competition.TeamComplete( + name=team.name, + school_id=team.school_id, + sport_id=team.sport_id, + edition_id=team.edition_id, + captain_id=team.captain_id, + id=team.id, + captain=schemas_core.CoreUser(**team.captain.__dict__), + users=[schemas_core.CoreUser(**user.__dict__) for user in team.users], + ) + if team + else None + ) + + +async def load_all_teams_by_sport_id( + sport_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.TeamComplete]: + teams = await db.execute( + select(models_competition.Team) + .where( + models_competition.Team.sport_id == sport_id, + models_competition.Team.edition_id == edition_id, + ) + .options( + selectinload(models_competition.Team.captain), + selectinload(models_competition.Team.users), + ), + ) + return [ + schemas_competition.TeamComplete( + name=team.name, + school_id=team.school_id, + sport_id=team.sport_id, + edition_id=team.edition_id, + captain_id=team.captain_id, + id=team.id, + captain=schemas_core.CoreUser(**team.captain.__dict__), + users=[schemas_core.CoreUser(**user.__dict__) for user in team.users], + ) + for team in teams.scalars().all() + ] + + +async def load_all_teams_by_school_id( + school_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.TeamComplete]: + teams = await db.execute( + select(models_competition.Team) + .where( + models_competition.Team.school_id == school_id, + models_competition.Team.edition_id == edition_id, + ) + .options( + selectinload(models_competition.Team.captain), + selectinload(models_competition.Team.users), + ), + ) + return [ + schemas_competition.TeamComplete( + name=team.name, + school_id=team.school_id, + sport_id=team.sport_id, + edition_id=team.edition_id, + captain_id=team.captain_id, + id=team.id, + captain=schemas_core.CoreUser(**team.captain.__dict__), + users=[schemas_core.CoreUser(**user.__dict__) for user in team.users], + ) + for team in teams.scalars().all() + ] + + +async def load_all_teams_by_school_and_sport_ids( + school_id: str, + team_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.TeamComplete]: + teams = await db.execute( + select(models_competition.Team) + .where( + models_competition.Team.school_id == school_id, + models_competition.Team.sport_id == team_id, + models_competition.Team.edition_id == edition_id, + ) + .options( + selectinload(models_competition.Team.captain), + selectinload(models_competition.Team.users), + ), + ) + return [ + schemas_competition.TeamComplete( + name=team.name, + school_id=team.school_id, + sport_id=team.sport_id, + edition_id=team.edition_id, + captain_id=team.captain_id, + id=team.id, + captain=schemas_core.CoreUser(**team.captain.__dict__), + users=[schemas_core.CoreUser(**user.__dict__) for user in team.users], + ) + for team in teams.scalars().all() + ] + + +async def store_edition( + edition: schemas_competition.CompetitionEdition, + db: AsyncSession, +): + stored_edition = await load_edition_by_id(edition.id, db) + if stored_edition is None: + db.add(models_competition.CompetitionEdition(**edition.model_dump())) + else: + await db.execute( + update(models_competition.CompetitionEdition) + .where(models_competition.CompetitionEdition.id == edition.id) + .values(**edition.model_dump()), + ) + active = await load_active_edition(db) + if active and active.id != edition.id and edition.activated: + raise MultipleEditions + + try: + await db.commit() + except IntegrityError as error: + await db.rollback() + raise error from None + + +async def delete_edition_by_id( + edition_id: str, + db: AsyncSession, +): + await db.execute( + delete(models_competition.CompetitionEdition).where( + models_competition.CompetitionEdition.id == edition_id, + ), + ) + await db.commit() + + +async def load_edition_by_id( + edition_id, + db: AsyncSession, +) -> schemas_competition.CompetitionEdition | None: + edition = await db.get(models_competition.CompetitionEdition, edition_id) + return ( + schemas_competition.CompetitionEdition(**edition.__dict__) if edition else None + ) + + +async def load_active_edition( + db: AsyncSession, +) -> schemas_competition.CompetitionEdition | None: + edition = ( + ( + await db.execute( + select(models_competition.CompetitionEdition).where( + models_competition.CompetitionEdition.activated, + ), + ) + ) + .scalars() + .first() + ) + return ( + schemas_competition.CompetitionEdition(**edition.__dict__) if edition else None + ) + + +async def load_all_editions( + db: AsyncSession, +) -> list[schemas_competition.CompetitionEdition]: + editions = await db.execute(select(models_competition.CompetitionEdition)) + return [ + schemas_competition.CompetitionEdition(**edition.__dict__) + for edition in editions.scalars().all() + ] + + +async def store_match( + match: schemas_competition.Match, + db: AsyncSession, +): + stored_match = await load_match_by_id(match.id, db) + if stored_match is None: + db.add(models_competition.Match(**match.model_dump())) + else: + await db.execute( + update(models_competition.Match) + .where(models_competition.Match.id == match.id) + .values(**match.model_dump()), + ) + try: + await db.commit() + except IntegrityError as error: + await db.rollback() + raise error from None + + +async def delete_match_by_id( + match_id: str, + db: AsyncSession, +): + await db.execute( + delete(models_competition.Match).where(models_competition.Match.id == match_id), + ) + await db.commit() + + +async def load_match_by_id( + match_id, + db: AsyncSession, +) -> schemas_competition.Match | None: + match = await db.get(models_competition.Match, match_id) + return schemas_competition.Match(**match.__dict__) if match else None + + +async def load_all_matches_by_sport_id( + sport_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.Match]: + matches = await db.execute( + select(models_competition.Match) + .where( + models_competition.Match.sport_id == sport_id, + models_competition.Match.edition_id == edition_id, + ) + .options( + selectinload(models_competition.Match.team1), + selectinload(models_competition.Match.team2), + ), + ) + + return [ + schemas_competition.Match( + id=match.id, + name=match.name, + edition_id=match.edition_id, + sport_id=match.sport_id, + team1_id=match.team1_id, + team2_id=match.team2_id, + datetime=match.date, + score_team1=match.score_team1, + score_team2=match.score_team2, + location=match.location, + winner_id=match.winner_id, + team1=schemas_competition.Team(**match.team1.__dict__), + team2=schemas_competition.Team(**match.team2.__dict__), + ) + for match in matches.scalars().all() + ] + + +async def load_all_matches_by_school_id( + school_id: str, + edition_id: str, + db: AsyncSession, +) -> list[schemas_competition.Match]: + matches = await db.execute( + select(models_competition.Match) + .where( + models_competition.Match.edition_id == edition_id, + ( + models_competition.Match.team1_id.in_( + select(models_competition.Team.id).where( + models_competition.Team.school_id == school_id, + ), + ) + | models_competition.Match.team2_id.in_( + select(models_competition.Team.id).where( + models_competition.Team.school_id == school_id, + ), + ) + ), + ) + .options( + selectinload(models_competition.Match.team1), + selectinload(models_competition.Match.team2), + ), + ) + + return [ + schemas_competition.Match( + id=match.id, + name=match.name, + edition_id=match.edition_id, + sport_id=match.sport_id, + team1_id=match.team1_id, + team2_id=match.team2_id, + date=match.date, + score_team1=match.score_team1, + score_team2=match.score_team2, + location=match.location, + winner_id=match.winner_id, + team1=schemas_competition.Team(**match.team1.__dict__), + team2=schemas_competition.Team(**match.team2.__dict__), + ) + for match in matches.scalars().all() + ] + + +async def load_match_by_teams_ids( + team1_id: str, + team2_id: str, + db: AsyncSession, +) -> schemas_competition.Match | None: + match = ( + ( + await db.execute( + select(models_competition.Match) + .where( + and_( + models_competition.Match.team1_id == team1_id, + models_competition.Match.team2_id == team2_id, + ) + | and_( + (models_competition.Match.team1_id == team2_id), + (models_competition.Match.team2_id == team1_id), + ), + ) + .options( + selectinload(models_competition.Match.team1), + selectinload(models_competition.Match.team2), + ), + ) + ) + .scalars() + .first() + ) + return ( + schemas_competition.Match( + id=match.id, + name=match.name, + edition_id=match.edition_id, + sport_id=match.sport_id, + team1_id=match.team1_id, + team2_id=match.team2_id, + date=match.date, + score_team1=match.score_team1, + score_team2=match.score_team2, + location=match.location, + winner_id=match.winner_id, + team1=schemas_competition.Team(**match.team1.__dict__), + team2=schemas_competition.Team(**match.team2.__dict__), + ) + if match + else None + ) diff --git a/app/modules/sport_competition/dependencies_sport_competition.py b/app/modules/sport_competition/dependencies_sport_competition.py new file mode 100644 index 000000000..caad3eaba --- /dev/null +++ b/app/modules/sport_competition/dependencies_sport_competition.py @@ -0,0 +1,78 @@ +import logging +from collections.abc import Callable, Coroutine +from typing import Any + +from fastapi import Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core import models_core +from app.core.groups.groups_type import GroupType +from app.dependencies import get_db, is_user +from app.modules.sport_competition import ( + cruds_sport_competition, + schemas_sport_competition, +) +from app.modules.sport_competition.types_sport_competition import CompetitionGroupType + +hyperion_access_logger = logging.getLogger("hyperion.access") + + +def is_user_a_member_of_extended( + group_id: GroupType | None = None, + comptition_group_id: CompetitionGroupType | None = None, + exclude_external: bool = False, +) -> Callable[[models_core.CoreUser], Coroutine[Any, Any, models_core.CoreUser]]: + """ + Generate a dependency which will: + * check if the request header contains a valid API JWT token (a token that can be used to call endpoints from the API) + * make sure the user making the request exists and is a member of the group with the given id + * make sure the user is not an external user if `exclude_external` is True + * return the corresponding user `models_core.CoreUser` object + """ + + async def is_user_a_member_of( + user: models_core.CoreUser = Depends( + is_user( + included_groups=[group_id] if group_id else None, + exclude_external=exclude_external, + ), + ), + edition: schemas_sport_competition.CompetitionEdition = Depends( + get_current_edition, + ), + db: AsyncSession = Depends(get_db), + ) -> models_core.CoreUser: + """ + A dependency that checks that user is a member of the group with the given id then returns the corresponding user. + """ + if comptition_group_id is None: + return user + membership = cruds_sport_competition.load_user_membership_with_group_id( + user_id=user.id, + group_id=comptition_group_id, + edition_id=edition.id, + db=db, + ) + if not membership: + raise HTTPException( + status_code=403, + detail="User is not a member of the group", + ) + return user + + return is_user_a_member_of + + +async def get_current_edition( + db: AsyncSession = Depends(get_db), +) -> schemas_sport_competition.CompetitionEdition: + """ + Dependency that returns the current edition + """ + current_edition = await cruds_sport_competition.load_active_edition(db) + if not current_edition: + raise HTTPException( + status_code=404, + detail="No current edition", + ) + return current_edition diff --git a/app/modules/sport_competition/endpoints_sport_competition.py b/app/modules/sport_competition/endpoints_sport_competition.py new file mode 100644 index 000000000..e06b0c867 --- /dev/null +++ b/app/modules/sport_competition/endpoints_sport_competition.py @@ -0,0 +1,799 @@ +import logging +import uuid + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core import schemas_core +from app.core.users import cruds_users +from app.dependencies import get_db, is_user +from app.modules.sport_competition import cruds_sport_competition as competition_cruds +from app.modules.sport_competition import ( + schemas_sport_competition as competition_schemas, +) +from app.modules.sport_competition.dependencies_sport_competition import ( + get_current_edition, + is_user_a_member_of_extended, +) +from app.modules.sport_competition.types_sport_competition import CompetitionGroupType + +router = APIRouter(tags=["Sport Competition"]) +hyperion_error_logger = logging.getLogger("hyperion.error") + + +@router.get("/competition/sports", response_model=list[competition_schemas.Sport]) +async def get_sports(db: AsyncSession = Depends(get_db)): + return await competition_cruds.load_all_sports(db) + + +@router.post( + "/competition/sports", + status_code=201, + response_model=competition_schemas.Sport, +) +async def create_sport( + sport: competition_schemas.SportBase, + db: AsyncSession = Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), +): + stored = await competition_cruds.load_sport_by_name(sport.name, db) + if stored is not None: + raise HTTPException( + status_code=400, + detail="A sport with this name already exists", + ) from None + sport = competition_schemas.Sport(**sport.model_dump(), id=str(uuid.uuid4())) + await competition_cruds.store_sport(sport, db) + return sport + + +@router.patch("/competition/sports") +async def edit_sport( + sport: competition_schemas.SportEdit, + db: AsyncSession = Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), +): + stored = await competition_cruds.load_sport_by_id(sport.id, db) + if stored is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + stored.model_copy(update=sport.model_dump()) + await competition_cruds.store_sport(stored, db) + return stored + + +@router.delete("/competition/sports/{sport_id}") +async def delete_sport( + sport_id: str, + db: AsyncSession = Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), +): + stored = await competition_cruds.load_sport_by_id(sport_id, db) + if stored is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + if stored.activated: + raise HTTPException( + status_code=400, + detail="Sport is activated and cannot be deleted", + ) from None + await competition_cruds.delete_sport_by_id(sport_id, db) + + +@router.get("/competition/groups", response_model=list[competition_schemas.Group]) +async def get_groups( + db=Depends(get_db), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + return await competition_cruds.load_all_groups(edition.id, db) + + +@router.post( + "/competition/groups", + status_code=201, + response_model=competition_schemas.Group, +) +async def create_group( + group: competition_schemas.GroupBase, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), +): + stored = await competition_cruds.load_group_by_name(group.name, db) + if stored is not None: + raise HTTPException(status_code=400, detail="Group already exists") from None + group = competition_schemas.Group(**group.model_dump(), id=str(uuid.uuid4())) + await competition_cruds.store_group(group, db) + return group + + +@router.patch("/competition/groups") +async def edit_group( + group: competition_schemas.GroupEdit, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + stored = await competition_cruds.load_group_by_id(group.id, edition.id, db) + if stored is None: + raise HTTPException(status_code=404, detail="Group not found") from None + stored.model_copy(update=group.model_dump()) + await competition_cruds.store_group(stored, db) + + +@router.delete("/competition/groups/{group_id}") +async def delete_group( + group_id: str, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + stored = await competition_cruds.load_group_by_id(group_id, edition.id, db) + if stored is None: + raise HTTPException(status_code=404, detail="Group not found") from None + + # await competition_cruds.delete_group_membership_by_group_id(group_id, db) + + await competition_cruds.delete_group_by_id(group_id, db) + + +@router.get( + "/competition/sport/{sport_id}/quotas", + response_model=list[competition_schemas.Quota], +) +async def get_quotas_for_sport( + sport_id: str, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + sport = await competition_cruds.load_sport_by_id(sport_id, db) + if sport is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + return await competition_cruds.load_all_quotas_by_sport_id( + sport_id, + edition.id, + db, + ) + + +@router.get( + "/competition/school/{school_id}/quotas", + response_model=list[competition_schemas.Quota], +) +async def get_quotas_for_school( + school_id: str, + db=Depends(get_db), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + school = await competition_cruds.load_school_by_id(school_id, edition.id, db) + if school is None: + raise HTTPException( + status_code=404, + detail="School not found in the database", + ) from None + return await competition_cruds.load_all_quotas_by_school_id( + school_id, + edition.id, + db, + ) + + +@router.post("/competition/school/{school_id}/sport/{sport_id}/quotas", status_code=201) +async def create_quota( + school_id: str, + sport_id: str, + quota_info: competition_schemas.QuotaInfo, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + school = await competition_cruds.load_school_by_id(school_id, edition.id, db) + if school is None: + raise HTTPException( + status_code=404, + detail="School not found in the database", + ) from None + sport = await competition_cruds.load_sport_by_id(sport_id, db) + if sport is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + stored = await competition_cruds.load_quota_by_ids( + school_id, + sport_id, + edition.id, + db, + ) + if stored is not None: + raise HTTPException(status_code=400, detail="Quota already exists") from None + quota = competition_schemas.Quota( + school_id=school_id, + sport_id=sport_id, + participant_quota=quota_info.participant_quota, + team_quota=quota_info.team_quota, + edition_id=edition.id, + ) + await competition_cruds.store_quota(quota, db) + + +@router.patch("/competition/school/{school_id}/sport/{sport_id}/quotas") +async def edit_quota( + school_id: str, + sport_id: str, + quota_info: competition_schemas.QuotaEdit, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + school = await competition_cruds.load_school_by_id(school_id, edition.id, db) + if school is None: + raise HTTPException( + status_code=404, + detail="School not found in the database", + ) from None + sport = await competition_cruds.load_sport_by_id(sport_id, db) + if sport is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + stored = await competition_cruds.load_quota_by_ids( + school_id, + sport_id, + edition.id, + db, + ) + if stored is None: + raise HTTPException( + status_code=404, + detail="Quota not found in the database", + ) from None + stored.model_copy(update=quota_info.model_dump()) + await competition_cruds.store_quota(stored, db) + + +@router.delete("/competition/school/{school_id}/sport/{sport_id}/quotas") +async def delete_quota( + school_id: str, + sport_id: str, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), +): + edition = await competition_cruds.load_active_edition(db) + if edition is None: + raise HTTPException( + status_code=404, + detail="No active edition found in the database", + ) from None + stored = await competition_cruds.load_quota_by_ids( + school_id, + sport_id, + edition.id, + db, + ) + if stored is None: + raise HTTPException( + status_code=404, + detail="Quota not found in the database", + ) from None + await competition_cruds.delete_quota_by_ids(school_id, sport_id, edition.id, db) + + +@router.get( + "/competition/schools", + response_model=list[competition_schemas.SchoolExtension], +) +async def get_schools( + db=Depends(get_db), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + return await competition_cruds.load_all_schools(edition.id, db) + + +@router.post( + "/competition/schools", + status_code=201, + response_model=competition_schemas.SchoolExtension, +) +async def create_school( + school: competition_schemas.SchoolExtension, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + stored = await competition_cruds.load_school_by_id(school.id, edition.id, db) + if stored is not None: + raise HTTPException( + status_code=400, + detail="School extension already exists", + ) from None + await competition_cruds.store_school(school, db) + return school + + +@router.patch("/competition/schools") +async def edit_school( + school: competition_schemas.SchoolExtensionEdit, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + stored = await competition_cruds.load_school_by_id(school.id, edition.id, db) + if stored is None: + raise HTTPException( + status_code=404, + detail="School not found in the database", + ) from None + stored.model_copy(update=school.model_dump()) + await competition_cruds.store_school(stored, db) + + +@router.delete("/competition/schools/{school_id}") +async def delete_school( + school_id: str, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + stored = await competition_cruds.load_school_by_id(school_id, edition.id, db) + if stored is None: + raise HTTPException( + status_code=404, + detail="School not found in the database", + ) from None + if stored.activated: + raise HTTPException( + status_code=400, + detail="School is activated and cannot be deleted", + ) from None + await competition_cruds.delete_school_by_id(school_id, db) + + +async def check_team_consistency( + school_id: str, + sport_id: str, + team_id: str, + db: AsyncSession, + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +) -> competition_schemas.TeamComplete: + school = await competition_cruds.load_school_by_id(school_id, edition.id, db) + if school is None: + raise HTTPException( + status_code=404, + detail="School not found in the database", + ) from None + sport = await competition_cruds.load_sport_by_id(sport_id, db) + if sport is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + team = await competition_cruds.load_team_by_id(team_id, db) + if team is None: + raise HTTPException( + status_code=404, + detail="Team not found in the database", + ) from None + if team.school_id != school_id or team.sport_id != sport_id: + raise HTTPException( + status_code=403, + detail="Team given does not belong to school and sport", + ) from None + return team + + +@router.get( + "/competition/sport/{sport_id}/teams", + response_model=list[competition_schemas.Team], +) +async def get_teams_for_sport( + sport_id: str, + db=Depends(get_db), + user=Depends( + is_user_a_member_of_extended( + comptition_group_id=CompetitionGroupType.competition_admin, + ), + ), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + if edition is None: + raise HTTPException( + status_code=404, + detail="No active edition found in the database", + ) from None + sport = await competition_cruds.load_sport_by_id(sport_id, db) + if sport is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + return await competition_cruds.load_all_teams_by_sport_id(sport_id, edition.id, db) + + +@router.get( + "/competition/school/{school_id}/sports/{sport_id}/teams", + response_model=list[competition_schemas.Team], +) +async def get_sport_teams_for_school( + school_id: str, + sport_id: str, + db=Depends(get_db), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + school = await competition_cruds.load_school_by_id(school_id, edition.id, db) + if school is None: + raise HTTPException( + status_code=404, + detail="School not found in the database", + ) from None + sport = await competition_cruds.load_sport_by_id(sport_id, db) + if sport is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + return await competition_cruds.load_all_teams_by_school_and_sport_ids( + school_id, + sport_id, + edition.id, + db, + ) + + +@router.post( + "/competition/school/{school_id}/sport/{sport_id}/teams", + status_code=201, + response_model=competition_schemas.Team, +) +async def create_team( + school_id: str, + sport_id: str, + team_info: competition_schemas.TeamInfo, + db=Depends(get_db), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), + user: schemas_core.CoreUser = Depends(is_user), +): + if ( + user.id != team_info.captain_id + and CompetitionGroupType.competition_admin.value + not in [group.id for group in user.groups] + ): + raise HTTPException(status_code=403, detail="Unauthorized action") from None + global_team = await competition_cruds.load_team_by_name( + team_info.name, + edition.id, + db, + ) + if global_team is not None: + raise HTTPException(status_code=400, detail="Team already exists") from None + stored = await competition_cruds.load_all_teams_by_school_and_sport_ids( + school_id, + sport_id, + edition.id, + db, + ) + quotas = await competition_cruds.load_quota_by_ids( + school_id, + sport_id, + edition.id, + db, + ) + if quotas is not None: + if len(stored) >= quotas.team_quota: + raise HTTPException(status_code=400, detail="Team quota reached") from None + team = competition_schemas.Team( + id=uuid.uuid4(), + school_id=school_id, + sport_id=sport_id, + name=team_info.name, + captain_id=team_info.captain_id, + edition_id=edition.id, + ) + await competition_cruds.store_team(team, db) + + +@router.patch("/competition/school/{school_id}/sport/{sport_id}/teams") +async def edit_team( + school_id: str, + sport_id: str, + team_info: competition_schemas.TeamEdit, + db=Depends(get_db), + user: schemas_core.CoreUser = Depends(is_user), +): + stored = await check_team_consistency( + school_id, + sport_id, + team_info.id, + db, + ) + if ( + user.id != stored.captain_id + and CompetitionGroupType.competition_admin.value + not in [group.id for group in user.groups] + ): + raise HTTPException(status_code=403, detail="Unauthorized action") from None + if team_info.captain_id is not None: + captain = await cruds_users.get_user_by_id( + db, + team_info.captain_id, + ) + if captain is None: + raise HTTPException( + status_code=404, + detail="Captain user not found", + ) from None + stored.model_copy(update=team_info.model_dump()) + await competition_cruds.store_team(stored, db) + + +@router.delete("/competition/school/{school_id}/sport/{sport_id}/teams/{team_id}") +async def delete_team( + school_id: str, + sport_id: str, + team_id: str, + db=Depends(get_db), + user: schemas_core.CoreUser = Depends(is_user), +): + stored = await check_team_consistency( + school_id, + sport_id, + team_id, + db, + ) + if ( + user.id != stored.captain_id + and CompetitionGroupType.competition_admin.value + not in [group.id for group in user.groups] + ): + raise HTTPException(status_code=403, detail="Unauthorized action") from None + await competition_cruds.delete_team_by_id(stored.id, db) + + +@router.get( + "/competition/sport/{sport_id}/matches", + response_model=list[competition_schemas.Match], +) +async def get_matches_for_sport_and_edition( + sport_id: str, + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), + db=Depends(get_db), +): + sport = await competition_cruds.load_sport_by_id(sport_id, db) + if sport is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + return await competition_cruds.load_all_matches_by_sport_id( + sport_id, + edition.id, + db, + ) + + +@router.get( + "/competition/school/{school_id}/matches", + response_model=list[competition_schemas.Match], +) +async def get_matches_for_school_sport_and_edition( + school_id: str, + sport_id: str, + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), + db=Depends(get_db), +): + school = await competition_cruds.load_school_by_id(school_id, edition.id, db) + if school is None: + raise HTTPException( + status_code=404, + detail="School not found in the database", + ) from None + sport = await competition_cruds.load_sport_by_id(sport_id, db) + if sport is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + return await competition_cruds.load_all_matches_by_school_id( + school_id, + edition.id, + db, + ) + + +def check_match_consistency( + sport_id: str, + match_info: competition_schemas.MatchBase, + team1: competition_schemas.TeamComplete, + team2: competition_schemas.TeamComplete, + edition: competition_schemas.CompetitionEdition, +) -> None: + if match_info.edition_id != edition.id: + raise HTTPException( + status_code=400, + detail="Match edition does not match the current edition", + ) from None + if team1.sport_id != sport_id or team2.sport_id != sport_id: + raise HTTPException( + status_code=403, + detail="Teams do not belong to the sport", + ) from None + if team1.edition_id != edition.id or team2.edition_id != edition.id: + raise HTTPException( + status_code=403, + detail="Teams do not belong to the current edition", + ) from None + if match_info.team1_id == match_info.team2_id: + raise HTTPException( + status_code=400, + detail="Teams cannot play against themselves", + ) from None + if match_info.date is not None: + if match_info.date < edition.start_date: + raise HTTPException( + status_code=400, + detail="Match date is before the edition start date", + ) from None + if match_info.date > edition.end_date: + raise HTTPException( + status_code=400, + detail="Match date is after the edition end date", + ) from None + + +@router.post( + "/competition/sport/{sport_id}/matches", + status_code=201, + response_model=competition_schemas.Match, +) +async def create_match( + sport_id: str, + match_info: competition_schemas.MatchBase, + db=Depends(get_db), + user: schemas_core.CoreUser = Depends(is_user), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + sport = await competition_cruds.load_sport_by_id(sport_id, db) + if sport is None: + raise HTTPException( + status_code=404, + detail="Sport not found in the database", + ) from None + team1 = await competition_cruds.load_team_by_id(match_info.team1_id, db) + if team1 is None: + raise HTTPException( + status_code=404, + detail="Team 1 not found in the database", + ) from None + team2 = await competition_cruds.load_team_by_id(match_info.team2_id, db) + if team2 is None: + raise HTTPException( + status_code=404, + detail="Team 2 not found in the database", + ) from None + + check_match_consistency(sport_id, match_info, team1, team2, edition) + + match = competition_schemas.Match( + id=uuid.uuid4(), + sport_id=sport_id, + edition_id=match_info.edition_id, + datetime=match_info.date, + location=match_info.location, + name=match_info.name, + team1_id=match_info.team1_id, + team2_id=match_info.team2_id, + winner_id=None, + score_team1=None, + score_team2=None, + team1=team1, + team2=team2, + ) + await competition_cruds.store_match(match, db) + + +@router.patch("/competition/matches/{match_id}") +async def edit_match( + match_id: str, + match_info: competition_schemas.MatchBase, + db=Depends(get_db), + user: schemas_core.CoreUser = Depends(is_user), + edition: competition_schemas.CompetitionEdition = Depends(get_current_edition), +): + match = await competition_cruds.load_match_by_id(match_id, db) + if match is None: + raise HTTPException( + status_code=404, + detail="Match not found in the database", + ) from None + team1 = await competition_cruds.load_team_by_id(match_info.team1_id, db) + if team1 is None: + raise HTTPException( + status_code=404, + detail="Team 1 not found in the database", + ) from None + team2 = await competition_cruds.load_team_by_id(match_info.team2_id, db) + if team2 is None: + raise HTTPException( + status_code=404, + detail="Team 2 not found in the database", + ) from None + + check_match_consistency(match.sport_id, match_info, team1, team2, edition) + + match.model_copy(update=match_info.model_dump()) + await competition_cruds.store_match(match, db) + + +@router.delete("/competition/matches/{match_id}") +async def delete_match( + match_id: str, + db=Depends(get_db), + user: schemas_core.CoreUser = Depends(is_user), +): + match = await competition_cruds.load_match_by_id(match_id, db) + if match is None: + raise HTTPException( + status_code=404, + detail="Match not found in the database", + ) from None + await competition_cruds.delete_match_by_id(match_id, db) diff --git a/app/modules/sport_competition/models_sport_competition.py b/app/modules/sport_competition/models_sport_competition.py new file mode 100644 index 000000000..7361c152d --- /dev/null +++ b/app/modules/sport_competition/models_sport_competition.py @@ -0,0 +1,212 @@ +from datetime import datetime + +from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.core.models_core import CoreSchool, CoreUser +from app.modules.sport_competition.types_sport_competition import SportCategory +from app.types.sqlalchemy import Base + + +class CompetitionGroup(Base): + __tablename__ = "competition_group" + + id: Mapped[str] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(unique=True) + description: Mapped[str | None] + + members: Mapped[list[CoreUser]] = relationship( + "CoreUser", + secondary="competition_annual_group_membership", + lazy="selectin", + viewonly=True, + init=False, + ) + + +class AnnualGroupMembership(Base): + __tablename__ = "competition_annual_group_membership" + + user_id: Mapped[str] = mapped_column(ForeignKey("core_user.id"), primary_key=True) + group_id: Mapped[str] = mapped_column( + ForeignKey("competition_group.id"), + primary_key=True, + ) + edition_id: Mapped[str] = mapped_column( + ForeignKey("competition_edition.id"), + primary_key=True, + ) + + +class SchoolExtension(Base): + __tablename__ = "competition_school_extension" + + school_id: Mapped[str] = mapped_column( + ForeignKey("core_school.id"), + primary_key=True, + ) + from_lyon: Mapped[bool] + activated: Mapped[bool] + + school: Mapped[CoreSchool] = relationship( + "CoreSchool", + lazy="selectin", + init=False, + ) + general_quota: Mapped["SchoolGeneralQuota"] = relationship( + "SchoolGeneralQuota", + lazy="selectin", + init=False, + ) + + +class SchoolGeneralQuota(Base): + __tablename__ = "competition_school_general_quota" + + school_id: Mapped[str] = mapped_column( + ForeignKey("competition_school_extension.school_id"), + primary_key=True, + ) + edition_id: Mapped[str] = mapped_column( + ForeignKey("competition_edition.id"), + primary_key=True, + ) + athlete_quota: Mapped[int | None] + cameraman_quota: Mapped[int | None] + pompom_quota: Mapped[int | None] + fanfare_quota: Mapped[int | None] + non_athlete_quota: Mapped[int | None] + + +class CompetitionEdition(Base): + __tablename__ = "competition_edition" + + id: Mapped[str] = mapped_column(primary_key=True) + year: Mapped[int] + name: Mapped[str] + start_date: Mapped[datetime] + end_date: Mapped[datetime] + activated: Mapped[bool] + + +class Sport(Base): + __tablename__ = "competition_sport" + + id: Mapped[str] = mapped_column(primary_key=True) + activated: Mapped[bool] + name: Mapped[str] + team_size: Mapped[int] + substitute_max: Mapped[int | None] + sport_category: Mapped[SportCategory | None] + + +class Team(Base): + __tablename__ = "competition_team" + + id: Mapped[str] = mapped_column(primary_key=True) + sport_id: Mapped[str] = mapped_column(ForeignKey("competition_sport.id")) + school_id: Mapped[str] = mapped_column( + ForeignKey("competition_school_extension.school_id"), + ) + edition_id: Mapped[str] = mapped_column( + ForeignKey("competition_edition.id"), + ) + name: Mapped[str] + captain_id: Mapped[str] = mapped_column(ForeignKey("core_user.id")) + + users: Mapped[list[CoreUser]] = relationship( + "CoreUser", + secondary="competition_participant", + lazy="selectin", + viewonly=True, + init=False, + ) + captain: Mapped[CoreUser] = relationship( + "CoreUser", + lazy="selectin", + init=False, + ) + + +class Participant(Base): + __tablename__ = "competition_participant" + + user_id: Mapped[str] = mapped_column(ForeignKey("core_user.id"), primary_key=True) + sport_id: Mapped[str] = mapped_column( + ForeignKey("competition_sport.id"), + primary_key=True, + ) + edition_id: Mapped[str] = mapped_column( + ForeignKey("competition_edition.id"), + primary_key=True, + ) + team_id: Mapped[str | None] = mapped_column(ForeignKey("competition_team.id")) + captain: Mapped[bool] + substitute: Mapped[bool] + license: Mapped[str] + + user: Mapped[CoreUser] = relationship( + "CoreUser", + lazy="selectin", + init=False, + ) + sport: Mapped[Sport] = relationship( + "Sport", + lazy="selectin", + init=False, + ) + team: Mapped[Team] = relationship( + "Team", + lazy="selectin", + init=False, + ) + + +class SchoolSportQuota(Base): + __tablename__ = "competition_sport_quota" + + sport_id: Mapped[str] = mapped_column( + ForeignKey("competition_sport.id"), + primary_key=True, + ) + school_id: Mapped[str] = mapped_column( + ForeignKey("competition_school_extension.school_id"), + primary_key=True, + ) + edition_id: Mapped[str] = mapped_column( + ForeignKey("competition_edition.id"), + primary_key=True, + ) + participant_quota: Mapped[int] + team_quota: Mapped[int] + + +class Match(Base): + __tablename__ = "competition_match" + + id: Mapped[str] = mapped_column(primary_key=True) + sport_id: Mapped[str] = mapped_column(ForeignKey("competition_sport.id")) + edition_id: Mapped[str] = mapped_column( + ForeignKey("competition_edition.id"), + ) + name: Mapped[str] + team1_id: Mapped[str] = mapped_column(ForeignKey("competition_team.id")) + team2_id: Mapped[str] = mapped_column(ForeignKey("competition_team.id")) + date: Mapped[datetime] + location: Mapped[str] + score_team1: Mapped[int | None] + score_team2: Mapped[int | None] + winner_id: Mapped[str | None] = mapped_column(ForeignKey("competition_team.id")) + + team1: Mapped[Team] = relationship( + "Team", + foreign_keys=[team1_id], + lazy="selectin", + init=False, + ) + team2: Mapped[Team] = relationship( + "Team", + foreign_keys=[team2_id], + lazy="selectin", + init=False, + ) diff --git a/app/modules/sport_competition/schemas_sport_competition.py b/app/modules/sport_competition/schemas_sport_competition.py new file mode 100644 index 000000000..d27707c68 --- /dev/null +++ b/app/modules/sport_competition/schemas_sport_competition.py @@ -0,0 +1,191 @@ +from datetime import datetime + +from pydantic import BaseModel + +from app.core import schemas_core +from app.core.models_core import CoreUser +from app.modules.sport_competition.types_sport_competition import SportCategory + + +class CompetitionEditionBase(BaseModel): + name: str + year: int + start_date: datetime + end_date: datetime + activated: bool = True + + +class CompetitionEdition(CompetitionEditionBase): + id: str + + +class SchoolExtension(BaseModel): + id: str + from_lyon: bool + activated: bool = True + + school: schemas_core.CoreSchool + general_quota: "SchoolGeneralQuota" + + +class SchoolExtensionEdit(BaseModel): + id: str + from_lyon: bool | None + activated: bool | None + + +class SchoolGeneralQuota(BaseModel): + school_id: str + edition: str + athlete_quota: int | None + cameraman_quota: int | None + pompom_quota: int | None + fanfare_quota: int | None + non_athlete_quota: int | None + + +class GroupMembership(BaseModel): + user: schemas_core.CoreUser + edition_id: str + + +class UserGroupMembership(BaseModel): + user_id: str + group_id: str + edition_id: str + + +class GroupBase(BaseModel): + name: str + description: str | None + + +class Group(GroupBase): + id: str + + +class GroupComplete(Group): + members: list[CoreUser] + + +class GroupEdit(BaseModel): + id: str + name: str | None + description: str | None + + +class SportBase(BaseModel): + name: str + team_size: int + substitute_max: int | None + sport_category: SportCategory | None + activated: bool = True + + +class Sport(SportBase): + id: str + + +class SportEdit(BaseModel): + id: str + name: str | None + team_size: int | None + substitute_max: int | None + sport_category: SportCategory | None + activated: bool | None + + +class QuotaInfo(BaseModel): + participant_quota: int + team_quota: int + + +class Quota(BaseModel): + school_id: str + sport_id: str + edition_id: str + participant_quota: int + team_quota: int + + +class QuotaEdit(BaseModel): + participant_quota: int | None + team_quota: int | None + + +class TeamBase(BaseModel): + name: str + edition_id: str + school_id: str + sport_id: str + captain_id: str + + +class TeamInfo(BaseModel): + name: str + school_id: str | None + sport_id: str | None + captain_id: str | None + + +class Team(TeamBase): + id: str + + +class TeamEdit(BaseModel): + id: str + name: str | None + captain_id: str | None + + +class TeamComplete(Team): + captain: schemas_core.CoreUser + users: list[schemas_core.CoreUser] + + +class ParticipantInfo(BaseModel): + license: str + substitute: bool = False + team_id: str | None + + +class Participant(BaseModel): + user_id: str + sport_id: str + edition_id: str + license: str + substitute: bool = False + team_id: str | None + + +class ParticipantEdit(BaseModel): + license: str | None + team_id: str | None + sport_id: str | None + user_id: str | None + substitute: bool | None + + +class ParticipantComplete(Participant): + user: schemas_core.CoreUser + sport: Sport + team: Team | None + + +class MatchBase(BaseModel): + edition_id: str + name: str + sport_id: str + team1_id: str + team2_id: str + date: datetime | None = None + location: str | None = None + score_team1: int | None = None + score_team2: int | None = None + winner_id: str | None = None + + +class Match(MatchBase): + id: str + team1: Team + team2: Team diff --git a/app/modules/sport_competition/types_sport_competition.py b/app/modules/sport_competition/types_sport_competition.py new file mode 100644 index 000000000..7b26d5d75 --- /dev/null +++ b/app/modules/sport_competition/types_sport_competition.py @@ -0,0 +1,42 @@ +from datetime import UTC, datetime +from enum import Enum + + +class InconsistentData(Exception): + def __init__(self, complement: str | None, *args: object) -> None: + super().__init__(*args) + self.complement = complement + + def __str__(self) -> str: + if self.complement: + return f"Inconsistent data: {self.complement}" + return "Inconsistent data" + + +class MultipleEditions(Exception): + def __str__(self) -> str: + return "Multiple activated editions" + + +class UnauthorizedAction(Exception): + def __str__(self) -> str: + return "Unauthorized action" + + +class SportCategory(str, Enum): + masculine = "masculine" + feminine = "feminine" + + +class CompetitionGroupType(str, Enum): + fanfaron = "c69ed623-1cf5-4769-acc7-ddbc6490fb07" + pompom = "22af0472-0a15-4f05-a670-fa02eda5e33f" + cameraman = "9cb535c7-ca19-4dbc-8b04-8b34017b5cff" + non_athlete = "84a9972c-56b0-4024-86ec-a284058e1cb1" + sport_manager = "a3f48a0b-ada1-4fe0-b987-f4170d8896c4" + schools_bds = "96f8ffb8-c585-4ca5-8360-dc3881f9f1e2" + competition_admin = "41b8735a-1702-4a9e-8486-a453572f6b5c" + + +class DefaultCoreData(str, Enum): + challenge_year = datetime.now(tz=UTC).year diff --git a/app/types/module.py b/app/types/module.py index e3e0c9d59..03dd891a6 100644 --- a/app/types/module.py +++ b/app/types/module.py @@ -3,7 +3,7 @@ from fastapi import APIRouter from sqlalchemy.ext.asyncio import AsyncSession -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.payment import schemas_payment @@ -12,7 +12,8 @@ def __init__( self, root: str, tag: str, - default_allowed_groups_ids: list[GroupType], + default_allowed_groups_ids: list[GroupType] | None = None, + default_allowed_account_types: list[AccountType] | None = None, router: APIRouter | None = None, payment_callback: Callable[ [schemas_payment.CheckoutPayment, AsyncSession], @@ -29,6 +30,7 @@ def __init__( """ self.root = root self.default_allowed_groups_ids = default_allowed_groups_ids + self.default_allowed_account_types = default_allowed_account_types self.router = router or APIRouter(tags=[tag]) self.payment_callback: ( Callable[[schemas_payment.CheckoutPayment, AsyncSession], Awaitable[None]] diff --git a/app/utils/auth/providers.py b/app/utils/auth/providers.py index 786b49680..8fd00b7c5 100644 --- a/app/utils/auth/providers.py +++ b/app/utils/auth/providers.py @@ -4,7 +4,12 @@ import unidecode from app.core import models_core -from app.core.groups.groups_type import GroupType, get_ecl_groups +from app.core.groups.groups_type import ( + AccountType, + GroupType, + get_account_types_except_externals, + get_ecl_account_types, +) from app.types.floors_type import FloorsType from app.types.scopes_type import ScopeType from app.utils.tools import get_display_name, is_user_member_of_an_allowed_group @@ -31,6 +36,11 @@ class BaseAuthClient: # Restrict the authentication to this client to specific Hyperion groups. # When set to `None`, users from any group can use the auth client allowed_groups: list[GroupType] | None = None + # Restrict the authentication to this client to specific Hyperion account types. + # When set to `None`, users from any account type can use the auth client + allowed_account_types: list[AccountType] | None = ( + get_account_types_except_externals() + ) # redirect_uri should alway match the one provided by the client redirect_uri: list[str] # Sometimes, when the client is wrongly configured, it may return an incorrect return_uri. This may also be useful for debugging clients. @@ -44,9 +54,6 @@ class BaseAuthClient: # By setting this parameter to True, the userinfo will be added to the id_token. return_userinfo_in_id_token: bool = False - # Some clients may allow external users to authenticate - allow_external_users: bool = False - # Some clients may require to enable token introspection to validate access tokens. # We don't want to enable token introspection for all clients as it may be a security risk, allowing attackers to do token fishing. allow_token_introspection: bool = False @@ -107,7 +114,9 @@ class AppAuthClient(BaseAuthClient): # WARNING: to be able to use openid connect, `ScopeType.openid` should always be allowed allowed_scopes: set[ScopeType | str] = {ScopeType.API} - allow_external_users: bool = True + allowed_account_types: list[AccountType] | None = ( + None # No restriction on account types + ) class APIToolAuthClient(BaseAuthClient): @@ -119,7 +128,7 @@ class APIToolAuthClient(BaseAuthClient): class NextcloudAuthClient(BaseAuthClient): - allowed_groups: list[GroupType] | None = get_ecl_groups() + allowed_account_types: list[AccountType] | None = get_ecl_account_types() # For Nextcloud: # Required iss : the issuer value form .well-known (corresponding code : https://github.com/pulsejet/nextcloud-oidc-login/blob/0c072ecaa02579384bb5e10fbb9d219bbd96cfb8/3rdparty/jumbojett/openid-connect-php/src/OpenIDConnectClient.php#L1255) @@ -138,9 +147,8 @@ def get_userinfo(cls, user: models_core.CoreUser): nickname=user.nickname, ), # TODO: should we use group ids instead of names? It would be less human readable but would guarantee uniqueness. Question: are group names unique? - "groups": [ - group.name for group in user.groups - ], # We may want to filter which groups are provided as they won't always all be useful + # We may want to filter which groups are provided as they won't always all be useful + "groups": [group.name for group in user.groups] + [user.account_type.value], "email": user.email, "picture": f"https://hyperion.myecl.fr/users/{user.id}/profile-picture", "is_admin": is_user_member_of_an_allowed_group(user, [GroupType.admin]), @@ -150,7 +158,7 @@ def get_userinfo(cls, user: models_core.CoreUser): class PiwigoAuthClient(BaseAuthClient): # Restrict the authentication to this client to specific Hyperion groups. # When set to `None`, users from any group can use the auth client - allowed_groups: list[GroupType] | None = get_ecl_groups() + allowed_account_types: list[AccountType] | None = get_ecl_account_types() def get_userinfo(self, user: models_core.CoreUser) -> dict[str, Any]: """ @@ -174,7 +182,7 @@ def get_userinfo(self, user: models_core.CoreUser) -> dict[str, Any]: name=user.name, nickname=user.nickname, ), - "groups": [group.name for group in user.groups], + "groups": [group.name for group in user.groups] + [user.account_type.value], "email": user.email, } @@ -184,6 +192,8 @@ class HedgeDocAuthClient(BaseAuthClient): # See app.types.scopes_type.ScopeType for possible values allowed_scopes: set[ScopeType | str] = {ScopeType.profile} + allowed_account_types: list[AccountType] | None = get_ecl_account_types() + @classmethod def get_userinfo(cls, user: models_core.CoreUser): return { @@ -200,6 +210,8 @@ class WikijsAuthClient(BaseAuthClient): # See app.types.scopes_type.ScopeType for possible values allowed_scopes: set[ScopeType | str] = {ScopeType.openid, ScopeType.profile} + allowed_account_types: list[AccountType] | None = get_ecl_account_types() + @classmethod def get_userinfo(cls, user: models_core.CoreUser): return { @@ -210,7 +222,7 @@ def get_userinfo(cls, user: models_core.CoreUser): nickname=user.nickname, ), "email": user.email, - "groups": [group.name for group in user.groups], + "groups": [group.name for group in user.groups] + [user.account_type.value], } @@ -219,6 +231,8 @@ class SynapseAuthClient(BaseAuthClient): # See app.types.scopes_type.ScopeType for possible values allowed_scopes: set[ScopeType | str] = {ScopeType.openid, ScopeType.profile} + allowed_account_types: list[AccountType] | None = get_ecl_account_types() + # https://github.com/matrix-org/matrix-authentication-service/issues/2088 return_userinfo_in_id_token: bool = True @@ -253,6 +267,8 @@ class MinecraftAuthClient(BaseAuthClient): # See app.types.scopes_type.ScopeType for possible values allowed_scopes: set[ScopeType | str] = {ScopeType.profile} + allowed_account_types: list[AccountType] | None = get_ecl_account_types() + @classmethod def get_userinfo(cls, user: models_core.CoreUser): return { @@ -283,8 +299,6 @@ class OpenProjectAuthClient(BaseAuthClient): # See app.types.scopes_type.ScopeType for possible values allowed_scopes: set[ScopeType | str] = {ScopeType.openid, ScopeType.profile} - allow_external_users: bool = True - @classmethod def get_userinfo(cls, user: models_core.CoreUser): return { @@ -360,17 +374,21 @@ class RAIDRegisteringAuthClient(BaseAuthClient): # WARNING: to be able to use openid connect, `ScopeType.openid` should always be allowed allowed_scopes: set[ScopeType | str] = {ScopeType.API} - allow_external_users: bool = True + allowed_account_types: list[AccountType] | None = ( + None # No restriction on account types + ) class SiarnaqAuthClient(BaseAuthClient): allowed_scopes: set[ScopeType | str] = {ScopeType.API} - allow_external_users: bool = True + allowed_account_types: list[AccountType] | None = ( + None # No restriction on account types + ) class OverleafAuthClient(BaseAuthClient): - allowed_groups: list[GroupType] | None = get_ecl_groups() + allowed_account_types: list[AccountType] | None = get_ecl_account_types() @classmethod def get_userinfo(cls, user: models_core.CoreUser): diff --git a/app/utils/initialization.py b/app/utils/initialization.py index 5d9523cc7..e89fcf785 100644 --- a/app/utils/initialization.py +++ b/app/utils/initialization.py @@ -1,4 +1,4 @@ -from sqlalchemy import Connection, MetaData, select +from sqlalchemy import Connection, MetaData, select, update from sqlalchemy.engine import Engine, create_engine from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, selectinload @@ -23,20 +23,47 @@ def get_sync_db_engine(settings: Settings) -> Engine: return engine -def get_all_module_visibility_membership_sync( +def get_all_module_group_visibility_membership_sync( db: Session, ): """ Return the every module with their visibility """ - result = db.execute(select(models_core.ModuleVisibility)) + result = db.execute(select(models_core.ModuleGroupVisibility)) return result.unique().scalars().all() -def create_module_visibility_sync( - module_visibility: models_core.ModuleVisibility, +def get_all_module_account_type_visibility_membership_sync( db: Session, -) -> models_core.ModuleVisibility: +): + """ + Return the every module with their visibility + """ + result = db.execute(select(models_core.ModuleAccountTypeVisibility)) + return result.unique().scalars().all() + + +def create_module_group_visibility_sync( + module_visibility: models_core.ModuleGroupVisibility, + db: Session, +) -> models_core.ModuleGroupVisibility: + """ + Create a new module visibility in database and return it + """ + db.add(module_visibility) + try: + db.commit() + except IntegrityError as error: + db.rollback() + raise ValueError(error) from error + else: + return module_visibility + + +def create_module_account_type_visibility_sync( + module_visibility: models_core.ModuleAccountTypeVisibility, + db: Session, +) -> models_core.ModuleAccountTypeVisibility: """ Create a new module visibility in database and return it """ @@ -81,6 +108,68 @@ def create_group_sync( return group +def get_core_data_sync(schema: str, db: Session) -> models_core.CoreData | None: + """ + Return core data with schema from database + """ + result = db.execute( + select(models_core.CoreData).where(models_core.CoreData.schema == schema), + ) + return result.scalars().first() + + +def set_core_data_sync( + core_data: models_core.CoreData, + db: Session, +) -> models_core.CoreData: + """ + Set core data in database and return it + """ + db_data = get_core_data_sync(core_data.schema, db) + if db_data is not None: + db.execute( + update(models_core.CoreData) + .where(models_core.CoreData.schema == core_data.schema) + .values(data=core_data.data), + ) + else: + db.add(core_data) + try: + db.commit() + except IntegrityError: + db.rollback() + raise + else: + return core_data + + +def get_school_by_id_sync(school_id: str, db: Session) -> models_core.CoreSchool | None: + """ + Return group with id from database + """ + result = db.execute( + select(models_core.CoreSchool).where(models_core.CoreSchool.id == school_id), + ) + return result.scalars().first() + + +def create_school_sync( + school: models_core.CoreSchool, + db: Session, +) -> models_core.CoreSchool: + """ + Create a new group in database and return it + """ + db.add(school) + try: + db.commit() + except IntegrityError: + db.rollback() + raise + else: + return school + + def drop_db_sync(conn: Connection): """ Drop all tables in the database diff --git a/app/utils/tools.py b/app/utils/tools.py index d868d2975..ec5ab5844 100644 --- a/app/utils/tools.py +++ b/app/utils/tools.py @@ -19,7 +19,7 @@ from app.core import cruds_core, models_core, security from app.core.groups import cruds_groups -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.models_core import CoreUser from app.core.users import cruds_users from app.types import core_data @@ -43,34 +43,13 @@ ) -def is_user_member_of_an_allowed_group( - user: CoreUser, - allowed_groups: list[GroupType] | list[str], -) -> bool: - """ - Test if the provided user is a member of at least one group from `allowed_groups`. - - When required groups can only be determined at runtime a list of strings (group UUIDs) can be provided. - This may be useful for modules that can be used multiple times like the loans module. - NOTE: if the provided string does not match a valid group, the function will return False - """ - # We can not directly test is group_id is in user.groups - # As user.groups is a list of CoreGroup and group_id is an UUID - for allowed_group in allowed_groups: - for user_group in user.groups: - if allowed_group == user_group.id: - # We know the user is a member of at least one allowed group - return True - return False - - def is_user_external( user: CoreUser, ): """ Users that are not members won't be able to use all features """ - return user.external is True + return user.account_type == AccountType.external def sort_user( @@ -111,12 +90,22 @@ def unaccent(s: str) -> str: return [user for user, _ in reversed(scored)] +def is_user_member_of_an_allowed_group( + user: models_core.CoreUser, + allowed_groups: list[str] | list[GroupType], +) -> bool: + """ + Check if the user is a member of the group. + """ + user_groups_id = [group.id for group in user.groups] + return any(group_id in user_groups_id for group_id in allowed_groups) + + async def is_group_id_valid(group_id: str, db: AsyncSession) -> bool: """ Test if the provided group_id is a valid group. The group may be - - an account type - a group type - a group created at runtime and stored in the database """ diff --git a/migrations/versions/26-account_types.py b/migrations/versions/26-account_types.py new file mode 100644 index 000000000..55a64d33a --- /dev/null +++ b/migrations/versions/26-account_types.py @@ -0,0 +1,694 @@ +"""account-types + +Create Date: 2024-11-08 14:00:56.598058 +""" + +import re +from collections.abc import Sequence +from enum import Enum +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pytest_alembic import MigrationContext + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "53c163acf327" +down_revision: str | None = "c73c7b821728" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +class AccountType(Enum): + student = "student" + former_student = "former_student" + staff = "staff" + association = "association" + external = "external" + demo = "demo" + + +STUDENT_GROUP_ID = "39691052-2ae5-4e12-99d0-7a9f5f2b0136" +FORMER_STUDENT_GROUP_ID = "ab4c7503-41b3-11ee-8177-089798f1a4a5" +STAFF_GROUP_ID = "703056c4-be9d-475c-aa51-b7fc62a96aaa" +ASSOCIATION_GROUP_ID = "29751438-103c-42f2-b09b-33fbb20758a7" +EXTERNAL_GROUP_ID = "b1cd979e-ecc1-4bd0-bc2b-4dad2ba8cded" +DEMO_GROUP_ID = "ae4d1866-e7d9-4d7f-bee7-e0dda24d8dd8" + +DEMO_ID = "9bccbd61-2af3-4bd6-adb7-2a5e48756f66" +ECLAIR_ID = "e68d744f-472f-49e5-896f-662d83be7b9a" + +ECL_STAFF_REGEX = r"^[\w\-.]*@(enise\.)?ec-lyon\.fr$" +ECL_STUDENT_REGEX = r"^[\w\-.]*@((etu(-enise)?)|(ecl\d{2}))\.ec-lyon\.fr$" +ECL_FORMER_STUDENT_REGEX = r"^[\w\-.]*@centraliens-lyon\.net$" + +user_t = sa.Table( + "core_user", + sa.MetaData(), + sa.Column("id", sa.String), + sa.Column("email", sa.String), + sa.Column("account_type", sa.Enum(AccountType, name="accounttype")), +) +group_t = sa.Table( + "core_group", + sa.MetaData(), + sa.Column("id", sa.String), + sa.Column("name", sa.String), +) +membership_t = sa.Table( + "core_membership", + sa.MetaData(), + sa.Column("user_id", sa.String), + sa.Column("group_id", sa.String), +) +module_group_visibility_t = sa.Table( + "module_group_visibility", + sa.MetaData(), + sa.Column("root", sa.String), + sa.Column("allowed_group_id", sa.String), +) +module_account_type_visibility_t = sa.Table( + "module_account_type_visibility", + sa.MetaData(), + sa.Column("root", sa.String), + sa.Column("allowed_account_type", sa.Enum(AccountType, name="accounttype")), +) + +core_data_t = sa.Table( + "core_data", + sa.MetaData(), + sa.Column("schema", sa.String), + sa.Column("data", sa.String), +) + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + op.rename_table("module_visibility", "module_group_visibility") + op.create_table( + "module_account_type_visibility", + sa.Column("root", sa.String(), nullable=False), + sa.Column( + "allowed_account_type", + sa.Enum(AccountType, name="accounttype"), + nullable=False, + ), + sa.PrimaryKeyConstraint("root", "allowed_account_type"), + ) + op.drop_column("core_user_unconfirmed", "account_type") + op.drop_column("core_user_unconfirmed", "external") + op.drop_column("core_user", "external") + op.add_column( + "core_user", + sa.Column( + "account_type", + sa.Enum( + AccountType, + name="accounttype", + ), + nullable=False, + server_default=AccountType.external.value, + ), + ) + + conn = op.get_bind() + conn.execute( + sa.update( + user_t, + ) + .where( + user_t.c.email.regexp_match(ECL_STUDENT_REGEX), + ) + .values( + account_type=AccountType.student, + ), + ) + conn.execute( + sa.update( + user_t, + ) + .where( + user_t.c.email.regexp_match(ECL_FORMER_STUDENT_REGEX), + ) + .values( + account_type=AccountType.former_student, + ), + ) + conn.execute( + sa.update( + user_t, + ) + .where( + user_t.c.email.regexp_match(ECL_STAFF_REGEX), + ) + .values( + account_type=AccountType.staff, + ), + ) + conn.execute( + sa.update( + user_t, + ) + .where(user_t.c.id == DEMO_ID) + .values( + account_type=AccountType.demo, + ), + ) + conn.execute( + sa.update( + user_t, + ) + .where(user_t.c.id == ECLAIR_ID) + .values( + account_type=AccountType.association, + ), + ) + + conn.execute( + sa.delete( + membership_t, + ).where( + (membership_t.c.group_id == STUDENT_GROUP_ID) + | (membership_t.c.group_id == FORMER_STUDENT_GROUP_ID) + | (membership_t.c.group_id == STAFF_GROUP_ID) + | (membership_t.c.group_id == ASSOCIATION_GROUP_ID) + | (membership_t.c.group_id == DEMO_GROUP_ID) + | (membership_t.c.group_id == EXTERNAL_GROUP_ID), + ), + ) + group_visibilities = conn.execute( + sa.select( + module_group_visibility_t, + ), + ) + group_to_account_type = { + STUDENT_GROUP_ID: AccountType.student, + FORMER_STUDENT_GROUP_ID: AccountType.former_student, + STAFF_GROUP_ID: AccountType.staff, + ASSOCIATION_GROUP_ID: AccountType.association, + DEMO_GROUP_ID: AccountType.demo, + EXTERNAL_GROUP_ID: AccountType.external, + } + + module_awareness = { + group_visibility.root for group_visibility in group_visibilities + } + + conn.execute( + sa.insert( + core_data_t, + ).values( + [ + { + "schema": "module_awareness", + "data": ",".join(module_awareness), + }, + ], + ), + ) + + for group_visibility in group_visibilities: + if group_visibility.allowed_group_id in group_to_account_type: + conn.execute( + sa.insert( + module_account_type_visibility_t, + ).values( + { + "root": group_visibility.root, + "allowed_account_type": group_to_account_type[ + group_visibility.allowed_group_id + ], + }, + ), + ) + + conn.execute( + sa.delete( + module_group_visibility_t, + ).where( + (module_group_visibility_t.c.allowed_group_id == STUDENT_GROUP_ID) + | (module_group_visibility_t.c.allowed_group_id == FORMER_STUDENT_GROUP_ID) + | (module_group_visibility_t.c.allowed_group_id == STAFF_GROUP_ID) + | (module_group_visibility_t.c.allowed_group_id == ASSOCIATION_GROUP_ID) + | (module_group_visibility_t.c.allowed_group_id == DEMO_GROUP_ID) + | (module_group_visibility_t.c.allowed_group_id == EXTERNAL_GROUP_ID), + ), + ) + + conn.execute( + sa.delete( + group_t, + ).where( + (group_t.c.id == STUDENT_GROUP_ID) + | (group_t.c.id == FORMER_STUDENT_GROUP_ID) + | (group_t.c.id == STAFF_GROUP_ID) + | (group_t.c.id == ASSOCIATION_GROUP_ID) + | (group_t.c.id == DEMO_GROUP_ID) + | (group_t.c.id == EXTERNAL_GROUP_ID), + ), + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "core_user_unconfirmed", + sa.Column( + "account_type", + sa.VARCHAR(), + server_default="external", + nullable=True, + ), + ) + op.add_column( + "core_user_unconfirmed", + sa.Column( + "external", + sa.BOOLEAN(), + server_default=sa.sql.true(), + nullable=False, + ), + ) + op.add_column( + "core_user", + sa.Column( + "external", + sa.Boolean(), + server_default=sa.sql.true(), + nullable=False, + ), + ) + + user_t2 = sa.Table( + "core_user", + sa.MetaData(), + sa.Column("id", sa.String), + sa.Column("external", sa.Boolean), + ) + + account_type_to_group = { + AccountType.student: STUDENT_GROUP_ID, + AccountType.former_student: FORMER_STUDENT_GROUP_ID, + AccountType.staff: STAFF_GROUP_ID, + AccountType.association: ASSOCIATION_GROUP_ID, + AccountType.demo: DEMO_GROUP_ID, + AccountType.external: EXTERNAL_GROUP_ID, + } + + conn = op.get_bind() + + conn.execute( + sa.insert( + group_t, + ).values( + [ + {"id": STUDENT_GROUP_ID, "name": "student"}, + {"id": FORMER_STUDENT_GROUP_ID, "name": "former_student"}, + {"id": STAFF_GROUP_ID, "name": "staff"}, + {"id": ASSOCIATION_GROUP_ID, "name": "association"}, + {"id": EXTERNAL_GROUP_ID, "name": "external"}, + {"id": DEMO_GROUP_ID, "name": "demo"}, + ], + ), + ) + + account_visibilities = conn.execute( + sa.select( + module_account_type_visibility_t, + ), + ) + + for account_visibility in account_visibilities: + conn.execute( + sa.insert( + module_group_visibility_t, + ).values( + { + "root": account_visibility.root, + "allowed_group_id": account_type_to_group[ + account_visibility.allowed_account_type + ], + }, + ), + ) + + def update_user(user_id: str, external: bool) -> None: + conn.execute( + sa.update( + user_t2, + ) + .where( + user_t2.c.id == user_id, + ) + .values( + {"external": external}, + ), + ) + + def insert_membership(user_id: str, group_id: str) -> None: + conn.execute( + sa.insert( + membership_t, + ).values( + {"user_id": user_id, "group_id": group_id}, + ), + ) + + users = conn.execute( + sa.select( + user_t2, + ), + ).fetchall() + + for user in users: + if re.match(ECL_STUDENT_REGEX, user.email): + insert_membership(user.id, STUDENT_GROUP_ID) + update_user(user.id, False) + elif re.match(ECL_FORMER_STUDENT_REGEX, user.email): + insert_membership(user.id, FORMER_STUDENT_GROUP_ID) + update_user(user.id, False) + elif re.match(ECL_STAFF_REGEX, user.email): + insert_membership(user.id, STAFF_GROUP_ID) + update_user(user.id, False) + else: + insert_membership(user.id, EXTERNAL_GROUP_ID) + update_user(user.id, True) + + demo = conn.execute( + sa.select( + user_t2, + ).where( + user_t2.c.id == DEMO_ID, + ), + ).fetchone() + if demo is not None: + insert_membership(DEMO_ID, DEMO_GROUP_ID) + update_user(DEMO_ID, False) + + eclair = conn.execute( + sa.select( + user_t2, + ).where( + user_t2.c.id == ECLAIR_ID, + ), + ).fetchone() + if eclair is not None: + insert_membership(ECLAIR_ID, ASSOCIATION_GROUP_ID) + update_user(ECLAIR_ID, False) + + op.rename_table("module_group_visibility", "module_visibility") + op.drop_table("module_account_type_visibility") + op.drop_column("core_user", "account_type") + sa.Enum(AccountType, name="accounttype").drop( + op.get_bind(), + ) + # ### end Alembic commands ### + + +def pre_test_upgrade( + alembic_runner: "MigrationContext", + alembic_connection: sa.Connection, +) -> None: + alembic_runner.insert_into( + "core_group", + { + "id": DEMO_GROUP_ID, + "name": "demo", + }, + ) + alembic_runner.insert_into( + "core_group", + { + "id": STUDENT_GROUP_ID, + "name": "student", + }, + ) + alembic_runner.insert_into( + "core_group", + { + "id": FORMER_STUDENT_GROUP_ID, + "name": "former_student", + }, + ) + alembic_runner.insert_into( + "core_group", + { + "id": STAFF_GROUP_ID, + "name": "staff", + }, + ) + alembic_runner.insert_into( + "core_group", + { + "id": ASSOCIATION_GROUP_ID, + "name": "association", + }, + ) + alembic_runner.insert_into( + "core_group", + { + "id": EXTERNAL_GROUP_ID, + "name": "external", + }, + ) + + alembic_runner.insert_into( + "core_user", + { + "id": DEMO_ID, + "email": "demo@myecl.fr", + "password_hash": "password_hash", + "name": "name", + "firstname": "firstname", + "nickname": "nickname", + "birthday": None, + "promo": 21, + "phone": "phone", + "floor": "Autre", + "created_on": None, + "external": False, + }, + ) + + alembic_runner.insert_into( + "core_membership", + { + "user_id": DEMO_ID, + "group_id": DEMO_GROUP_ID, + }, + ) + alembic_runner.insert_into( + "core_user", + { + "id": "5c5f9fdd-bedd-449b-91a3-f3437b95e36b", + "email": "test@etu.ec-lyon.fr", + "password_hash": "password_hash", + "name": "name", + "firstname": "firstname", + "nickname": "nickname", + "birthday": None, + "promo": 21, + "phone": "phone", + "floor": "Autre", + "created_on": None, + "external": False, + }, + ) + + alembic_runner.insert_into( + "core_membership", + { + "user_id": "5c5f9fdd-bedd-449b-91a3-f3437b95e36b", + "group_id": STUDENT_GROUP_ID, + }, + ) + alembic_runner.insert_into( + "core_user", + { + "id": "3d565333-aae1-4d70-a645-e6a3bc3ac198", + "email": "test@etu-enise.ec-lyon.fr", + "password_hash": "password_hash", + "name": "name", + "firstname": "firstname", + "nickname": "nickname", + "birthday": None, + "promo": 21, + "phone": "phone", + "floor": "Autre", + "created_on": None, + "external": False, + }, + ) + alembic_runner.insert_into( + "core_membership", + { + "user_id": "3d565333-aae1-4d70-a645-e6a3bc3ac198", + "group_id": EXTERNAL_GROUP_ID, + }, + ) + alembic_runner.insert_into( + "core_user", + { + "id": "1dd04834-700d-4960-ba68-2beab1fa8663", + "email": "test2@gmail.com", + "password_hash": "password_hash", + "name": "name", + "firstname": "firstname", + "nickname": "nickname", + "birthday": None, + "promo": 21, + "phone": "phone", + "floor": "Autre", + "created_on": None, + "external": False, + }, + ) + alembic_runner.insert_into( + "core_membership", + { + "user_id": "1dd04834-700d-4960-ba68-2beab1fa8663", + "group_id": EXTERNAL_GROUP_ID, + }, + ) + alembic_runner.insert_into( + "core_user", + { + "id": ECLAIR_ID, + "email": "eclair@myecl.fr", + "password_hash": "password_hash", + "name": "name", + "firstname": "firstname", + "nickname": "nickname", + "birthday": None, + "promo": 21, + "phone": "phone", + "floor": "Autre", + "created_on": None, + "external": False, + }, + ) + alembic_runner.insert_into( + "core_membership", + { + "user_id": ECLAIR_ID, + "group_id": ASSOCIATION_GROUP_ID, + }, + ) + alembic_runner.insert_into( + "core_user", + { + "id": "64b83ef5-4e13-4827-9f4f-ab4ce1244f4c", + "email": "test2@ec-lyon.fr", + "password_hash": "password_hash", + "name": "name", + "firstname": "firstname", + "nickname": "nickname", + "birthday": None, + "promo": 21, + "phone": "phone", + "floor": "Autre", + "created_on": None, + "external": False, + }, + ) + alembic_runner.insert_into( + "core_membership", + { + "user_id": "64b83ef5-4e13-4827-9f4f-ab4ce1244f4c", + "group_id": STAFF_GROUP_ID, + }, + ) + alembic_runner.insert_into( + "core_user", + { + "id": "f1265ac3-d3cc-4ce5-97f5-82d25b5063f2", + "email": "test2@enise.ec-lyon.fr", + "password_hash": "password_hash", + "name": "name", + "firstname": "firstname", + "nickname": "nickname", + "birthday": None, + "promo": 21, + "phone": "phone", + "floor": "Autre", + "created_on": None, + "external": False, + }, + ) + alembic_runner.insert_into( + "core_membership", + { + "user_id": "f1265ac3-d3cc-4ce5-97f5-82d25b5063f2", + "group_id": STAFF_GROUP_ID, + }, + ) + alembic_runner.insert_into( + "core_user", + { + "id": "b059f348-3678-4d7f-a90e-b1663d60de37", + "email": "test2@centraliens-lyon.net", + "password_hash": "password_hash", + "name": "name", + "firstname": "firstname", + "nickname": "nickname", + "birthday": None, + "promo": 21, + "phone": "phone", + "floor": "Autre", + "created_on": None, + "external": False, + }, + ) + alembic_runner.insert_into( + "core_membership", + { + "user_id": "b059f348-3678-4d7f-a90e-b1663d60de37", + "group_id": FORMER_STUDENT_GROUP_ID, + }, + ) + + +def test_upgrade( + alembic_runner: "MigrationContext", + alembic_connection: sa.Connection, +) -> None: + users = alembic_connection.execute( + sa.text("SELECT id, account_type from core_user"), + ).fetchall() + + duos = { + "1dd04834-700d-4960-ba68-2beab1fa8663": "external", + "3d565333-aae1-4d70-a645-e6a3bc3ac198": "student", + "5c5f9fdd-bedd-449b-91a3-f3437b95e36b": "student", + "64b83ef5-4e13-4827-9f4f-ab4ce1244f4c": "staff", + "f1265ac3-d3cc-4ce5-97f5-82d25b5063f2": "staff", + "b059f348-3678-4d7f-a90e-b1663d60de37": "former_student", + DEMO_ID: "demo", + ECLAIR_ID: "association", + } + + for row in users: + account = duos.get(row[0]) + if account is not None: + assert row[1] == account + + groups = alembic_connection.execute( + sa.text("SELECT id from core_group"), + ).fetchall() + + group_ids = { + STUDENT_GROUP_ID, + FORMER_STUDENT_GROUP_ID, + STAFF_GROUP_ID, + ASSOCIATION_GROUP_ID, + EXTERNAL_GROUP_ID, + DEMO_GROUP_ID, + } + + for group_id in group_ids: + assert (group_id,) not in groups diff --git a/migrations/versions/27-schools.py b/migrations/versions/27-schools.py new file mode 100644 index 000000000..f54f3b77e --- /dev/null +++ b/migrations/versions/27-schools.py @@ -0,0 +1,266 @@ +"""schools + +Create Date: 2024-10-26 19:04:51.089828 +""" + +import enum +from collections.abc import Sequence +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pytest_alembic import MigrationContext + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a1e6e8b52103" +down_revision: str | None = "53c163acf327" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +centrale_regex = r"^[\w\-.]*@(etu(-enise)?\.)?ec-lyon\.fr$" + + +class AccountType(enum.Enum): + student = "student" + former_student = "former_student" + staff = "staff" + association = "association" + external = "external" + demo = "demo" + + +class AccountType2(enum.Enum): + student = "student" + former_student = "former_student" + staff = "staff" + association = "association" + external = "external" + other_school_student = "other_school_student" + demo = "demo" + + +DEMO_ID = "9bccbd61-2af3-4bd6-adb7-2a5e48756f66" +ECLAIR_ID = "e68d744f-472f-49e5-896f-662d83be7b9a" + +ECL_STAFF_REGEX = r"^[\w\-.]*@(enise\.)?ec-lyon\.fr$" +ECL_STUDENT_REGEX = r"^[\w\-.]*@((etu(-enise)?)|(ecl\d{2}))\.ec-lyon\.fr$" +ECL_FORMER_STUDENT_REGEX = r"^[\w\-.]*@centraliens-lyon\.net$" + +school_table = sa.Table( + "core_school", + sa.MetaData(), + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("email_regex", sa.String(), nullable=False), +) +user_table = sa.Table( + "core_user", + sa.MetaData(), + sa.Column("id", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("account_type", sa.Enum(AccountType, name="accounttype"), nullable=False), + sa.Column("school_id", sa.String(), nullable=False), +) + +visibility_table = sa.Table( + "module_account_type_visibility", + sa.MetaData(), + sa.Column("root", sa.String(), nullable=False), + sa.Column( + "allowed_account_type", + sa.Enum(AccountType, name="accounttype"), + nullable=False, + ), +) + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "core_school", + sa.Column("id", sa.String(), nullable=False, index=True), + sa.Column("name", sa.String(), nullable=False, index=True, unique=True), + sa.Column("email_regex", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + + conn = op.get_bind() + + conn.execute( + sa.insert(school_table).values( + id="dce19aa2-8863-4c93-861e-fb7be8f610ed", + name="no_school", + email_regex=".*", + ), + ) + + users = conn.execute( + sa.select(user_table.c.id, user_table.c.account_type), + ).fetchall() + + visibilities = conn.execute( + sa.select(visibility_table.c.root, visibility_table.c.allowed_account_type), + ).fetchall() + + with op.batch_alter_table("core_user") as batch_op: + batch_op.add_column( + sa.Column( + "school_id", + sa.String(), + nullable=False, + server_default="dce19aa2-8863-4c93-861e-fb7be8f610ed", + ), + ) + batch_op.create_foreign_key( + "core_user_school_id", + "core_school", + ["school_id"], + ["id"], + ) + batch_op.drop_column("account_type") + + op.drop_table("module_account_type_visibility") + + sa.Enum(AccountType, name="accounttype").drop( + conn, + ) + + op.create_table( + "module_account_type_visibility", + sa.Column("root", sa.String(), nullable=False), + sa.Column( + "allowed_account_type", + sa.Enum( + AccountType2, + name="accounttype", + ), + nullable=False, + ), + sa.PrimaryKeyConstraint("root", "allowed_account_type"), + ) + + with op.batch_alter_table("core_user") as batch_op: + batch_op.add_column( + sa.Column( + "account_type", + sa.Enum(AccountType2, name="accounttype"), + nullable=False, + server_default="external", + ), + ) + + conn.execute( + sa.insert(school_table).values( + id="d9772da7-1142-4002-8b86-b694b431dfed", + name="Centrale Lyon", + email_regex=centrale_regex, + ), + ) + + for user in users: + conn.execute( + sa.update(user_table) + .where(user_table.c.id == user.id) + .values( + account_type=user.account_type, + school_id="d9772da7-1142-4002-8b86-b694b431dfed" + if user.account_type != AccountType.external + else "dce19aa2-8863-4c93-861e-fb7be8f610ed", + ), + ) + + for visibility in visibilities: + conn.execute( + sa.insert(visibility_table).values( + root=visibility.root, + allowed_account_type=visibility.allowed_account_type, + ), + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + conn = op.get_bind() + + users = conn.execute( + sa.select(user_table.c.id, user_table.c.account_type), + ).fetchall() + + visibilities = conn.execute( + sa.select(visibility_table.c.root, visibility_table.c.allowed_account_type), + ).fetchall() + + with op.batch_alter_table("core_user") as batch_op: + batch_op.drop_constraint("core_user_school_id", type_="foreignkey") + batch_op.drop_column("school_id") + batch_op.drop_column("account_type") + + op.drop_table("module_account_type_visibility") + + sa.Enum(AccountType2, name="accounttype").drop( + op.get_bind(), + ) + + op.create_table( + "module_account_type_visibility", + sa.Column("root", sa.String(), nullable=False), + sa.Column( + "allowed_account_type", + sa.Enum(AccountType, name="accounttype"), + nullable=False, + ), + sa.PrimaryKeyConstraint("root", "allowed_account_type"), + ) + + with op.batch_alter_table("core_user") as batch_op: + batch_op.add_column( + sa.Column( + "account_type", + sa.Enum(AccountType, name="accounttype"), + nullable=False, + server_default="external", + ), + ) + + op.drop_index(op.f("ix_core_school_name"), table_name="core_school") + op.drop_index(op.f("ix_core_school_id"), table_name="core_school") + op.drop_table("core_school") + + for user in users: + conn.execute( + sa.update(user_table) + .where(user_table.c.id == user.id) + .values( + account_type=user.account_type + if user.account_type != AccountType2.other_school_student + else AccountType.external, + ), + ) + + for visibility in visibilities: + conn.execute( + sa.insert(visibility_table).values( + root=visibility.root, + allowed_account_type=visibility.allowed_account_type, + ), + ) + # ### end Alembic commands ### + + +def pre_test_upgrade( + alembic_runner: "MigrationContext", + alembic_connection: sa.Connection, +) -> None: + pass + + +def test_upgrade( + alembic_runner: "MigrationContext", + alembic_connection: sa.Connection, +) -> None: + pass diff --git a/migrations/versions/28-sport_competition.py b/migrations/versions/28-sport_competition.py new file mode 100644 index 000000000..99fd7da91 --- /dev/null +++ b/migrations/versions/28-sport_competition.py @@ -0,0 +1,171 @@ +"""sport_competition + +Create Date: 2024-11-28 04:47:43.954804 +""" + +from collections.abc import Sequence +from enum import Enum +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pytest_alembic import MigrationContext + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "2bd545ef13b7" +down_revision: str | None = "a1e6e8b52103" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +class SportCategory(Enum): + masculine = "masculine" + feminine = "feminine" + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "competition_edition", + sa.Column("id", sa.String(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("start_date", sa.String(), nullable=False), + sa.Column("end_date", sa.String(), nullable=False), + sa.Column("activated", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "competition_group", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "competition_sport", + sa.Column("id", sa.String(), nullable=False), + sa.Column("activated", sa.Boolean(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("team_size", sa.Integer(), nullable=False), + sa.Column("substitute_max", sa.Integer(), nullable=True), + sa.Column( + "sport_category", + sa.Enum(SportCategory, name="sportcategory"), + nullable=True, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "competition_school_extension", + sa.Column("school_id", sa.String(), nullable=False), + sa.Column("from_lyon", sa.Boolean(), nullable=False), + sa.Column("activated", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint(["school_id"], ["core_school.id"]), + sa.PrimaryKeyConstraint("school_id"), + ) + op.create_table( + "competition_school_general_quota", + sa.Column("school_id", sa.String(), nullable=False), + sa.Column("edition_id", sa.String(), nullable=False), + sa.Column("athlete_quota", sa.Integer(), nullable=True), + sa.Column("cameraman_quota", sa.Integer(), nullable=True), + sa.Column("pompom_quota", sa.Integer(), nullable=True), + sa.Column("fanfare_quota", sa.Integer(), nullable=True), + sa.Column("non_athlete_quota", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["edition_id"], ["competition_edition.id"]), + sa.ForeignKeyConstraint( + ["school_id"], + ["competition_school_extension.school_id"], + ), + sa.PrimaryKeyConstraint("school_id", "edition_id"), + ) + op.create_table( + "competition_sport_quota", + sa.Column("sport_id", sa.String(), nullable=False), + sa.Column("school_id", sa.String(), nullable=False), + sa.Column("edition_id", sa.String(), nullable=False), + sa.Column("participant_quota", sa.Integer(), nullable=False), + sa.Column("team_quota", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["edition_id"], ["competition_edition.id"]), + sa.ForeignKeyConstraint( + ["school_id"], + ["competition_school_extension.school_id"], + ), + sa.ForeignKeyConstraint(["sport_id"], ["competition_sport.id"]), + sa.PrimaryKeyConstraint("sport_id", "school_id", "edition_id"), + ) + op.create_table( + "competition_annual_group_membership", + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("group_id", sa.String(), nullable=False), + sa.Column("edition_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["edition_id"], ["competition_edition.id"]), + sa.ForeignKeyConstraint(["group_id"], ["competition_group.id"]), + sa.ForeignKeyConstraint(["user_id"], ["core_user.id"]), + sa.PrimaryKeyConstraint("user_id", "group_id", "edition_id"), + ) + op.create_table( + "competition_team", + sa.Column("id", sa.String(), nullable=False), + sa.Column("sport_id", sa.String(), nullable=False), + sa.Column("school_id", sa.String(), nullable=False), + sa.Column("edition_id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("captain_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["captain_id"], ["core_user.id"]), + sa.ForeignKeyConstraint(["edition_id"], ["competition_edition.id"]), + sa.ForeignKeyConstraint( + ["school_id"], + ["competition_school_extension.school_id"], + ), + sa.ForeignKeyConstraint(["sport_id"], ["competition_sport.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "competition_participant", + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("sport_id", sa.String(), nullable=False), + sa.Column("edition_id", sa.String(), nullable=False), + sa.Column("team_id", sa.String(), nullable=True), + sa.Column("captain", sa.Boolean(), nullable=False), + sa.Column("substitute", sa.Boolean(), nullable=False), + sa.Column("license", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["edition_id"], ["competition_edition.id"]), + sa.ForeignKeyConstraint(["sport_id"], ["competition_sport.id"]), + sa.ForeignKeyConstraint(["team_id"], ["competition_team.id"]), + sa.ForeignKeyConstraint(["user_id"], ["core_user.id"]), + sa.PrimaryKeyConstraint("user_id", "sport_id", "edition_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("competition_participant") + op.drop_table("competition_team") + op.drop_table("competition_annual_group_membership") + op.drop_table("competition_sport_quota") + op.drop_table("competition_school_general_quota") + op.drop_table("competition_school_extension") + op.drop_table("competition_sport") + op.drop_table("competition_group") + op.drop_table("competition_edition") + sa.Enum(name="sportcategory").drop(op.get_bind()) + # ### end Alembic commands ### + + +def pre_test_upgrade( + alembic_runner: "MigrationContext", + alembic_connection: sa.Connection, +) -> None: + pass + + +def test_upgrade( + alembic_runner: "MigrationContext", + alembic_connection: sa.Connection, +) -> None: + pass diff --git a/tests/commons.py b/tests/commons.py index c271b45a8..123d882b5 100644 --- a/tests/commons.py +++ b/tests/commons.py @@ -14,9 +14,10 @@ from app.core.auth import schemas_auth from app.core.config import Settings from app.core.groups import cruds_groups -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.core.payment import cruds_payment, models_payment, schemas_payment from app.core.payment.payment_tool import PaymentTool +from app.core.schools.schools_type import SchoolType from app.core.users import cruds_users from app.dependencies import get_settings from app.types.exceptions import RedisConnectionError @@ -108,13 +109,14 @@ def change_redis_client_status(activated: bool) -> None: async def create_user_with_groups( groups: list[GroupType], + account_type: AccountType = AccountType.external, + school_id: SchoolType = SchoolType.centrale_lyon, user_id: str | None = None, email: str | None = None, password: str | None = None, name: str | None = None, firstname: str | None = None, floor: FloorsType | None = None, - external: bool = False, ) -> models_core.CoreUser: """ Add a dummy user to the database @@ -129,11 +131,12 @@ async def create_user_with_groups( user = models_core.CoreUser( id=user_id, email=email or (get_random_string() + "@etu.ec-lyon.fr"), + school_id=school_id, password_hash=password_hash, name=name or get_random_string(), firstname=firstname or get_random_string(), floor=floor, - external=external, + account_type=account_type, nickname=None, birthday=None, promo=None, diff --git a/tests/test_PH.py b/tests/test_PH.py index 911854f5f..247c485d7 100644 --- a/tests/test_PH.py +++ b/tests/test_PH.py @@ -6,7 +6,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.ph import models_ph from tests.commons import ( add_object_to_db, @@ -33,7 +33,10 @@ async def init_objects() -> None: token_ph = create_api_access_token(ph_user_ph) global ph_user_simple - ph_user_simple = await create_user_with_groups([GroupType.student]) + ph_user_simple = await create_user_with_groups( + [], + AccountType.student, + ) global token_simple token_simple = create_api_access_token(ph_user_simple) diff --git a/tests/test_advert.py b/tests/test_advert.py index 998b5ce40..0f86a709f 100644 --- a/tests/test_advert.py +++ b/tests/test_advert.py @@ -6,7 +6,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.advert import models_advert from tests.commons import ( add_object_to_db, @@ -41,13 +41,19 @@ async def init_objects() -> None: ) await add_object_to_db(advertiser) - user_advertiser = await create_user_with_groups([GroupType.student, GroupType.CAA]) + user_advertiser = await create_user_with_groups( + [GroupType.CAA], + AccountType.student, + ) global token_advertiser token_advertiser = create_api_access_token(user_advertiser) global user_simple - user_simple = await create_user_with_groups([GroupType.student]) + user_simple = await create_user_with_groups( + [], + AccountType.student, + ) global token_simple token_simple = create_api_access_token(user_simple) diff --git a/tests/test_amap.py b/tests/test_amap.py index 15dd14cc2..1eaa93562 100644 --- a/tests/test_amap.py +++ b/tests/test_amap.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.amap import models_amap from app.modules.amap.types_amap import AmapSlotType, DeliveryStatusType from tests.commons import ( @@ -39,8 +39,11 @@ async def init_objects() -> None: order, \ deletable_order_by_admin - amap_user = await create_user_with_groups([GroupType.amap]) - student_user = await create_user_with_groups([GroupType.student]) + amap_user = await create_user_with_groups([GroupType.amap], AccountType.student) + student_user = await create_user_with_groups( + [], + AccountType.student, + ) product = models_amap.Product( id=str(uuid.uuid4()), diff --git a/tests/test_auth.py b/tests/test_auth.py index 0c3f0c938..284b1ee62 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -8,7 +8,7 @@ from app.core import models_core from app.core.auth import models_auth -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from tests.commons import ( add_object_to_db, create_api_access_token, @@ -30,13 +30,15 @@ async def init_objects() -> None: global user user = await create_user_with_groups( groups=[], + account_type=AccountType.student, email="email@myecl.fr", password="azerty", ) global ecl_user ecl_user = await create_user_with_groups( - groups=[GroupType.student], + groups=[GroupType.eclair], + account_type=AccountType.student, email="email@etu.ec-lyon.fr", password="azerty", ) @@ -46,7 +48,6 @@ async def init_objects() -> None: groups=[], email="external@myecl.fr", password="azerty", - external=True, ) global access_token @@ -474,7 +475,7 @@ def test_authorization_code_flow_with_auth_client_restricting_allowed_groups_and ) -> None: # For an user that is a member of a required group # data_with_invalid_client_id = { - "client_id": "AcceptingOnlyECLUsersAuthClient", + "client_id": "RestrictingUsersGroupsAuthClient", "client_secret": "secret", "redirect_uri": "http://127.0.0.1:8000/docs", "response_type": "code", @@ -501,13 +502,13 @@ def test_authorization_code_flow_with_auth_client_restricting_allowed_groups_and ) -> None: # For an user that is not a member of a required group # data_with_invalid_client_id = { - "client_id": "AcceptingOnlyECLUsersAuthClient", + "client_id": "RestrictingUsersGroupsAuthClient", "client_secret": "secret", "redirect_uri": "http://127.0.0.1:8000/docs", "response_type": "code", "scope": "API openid", "state": "azerty", - "email": "email@myecl.fr", + "email": "external@myecl.fr", "password": "azerty", } response = client.post( @@ -546,7 +547,7 @@ def test_authorization_code_flow_with_auth_client_restricting_external_users_and assert response.next_request is not None assert str(response.next_request.url).endswith( - "calypsso/error?message=External+users+are+not+allowed", + "calypsso/error?message=User+account+type+is+not+allowed", ) diff --git a/tests/test_booking.py b/tests/test_booking.py index 77ed72646..132bc411d 100644 --- a/tests/test_booking.py +++ b/tests/test_booking.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.booking import models_booking from tests.commons import ( add_object_to_db, @@ -39,13 +39,16 @@ async def init_objects() -> None: token_admin = create_api_access_token(admin_user) global manager_user - manager_user = await create_user_with_groups([GroupType.student, GroupType.BDE]) + manager_user = await create_user_with_groups( + [GroupType.BDE], + AccountType.student, + ) global token_manager token_manager = create_api_access_token(manager_user) global simple_user - simple_user = await create_user_with_groups([GroupType.student]) + simple_user = await create_user_with_groups([], AccountType.student) global token_simple token_simple = create_api_access_token(simple_user) diff --git a/tests/test_calendar.py b/tests/test_calendar.py index 1dcfac35e..d3b2d477f 100644 --- a/tests/test_calendar.py +++ b/tests/test_calendar.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.booking.types_booking import Decision from app.modules.calendar import models_calendar from app.modules.calendar.types_calendar import CalendarEventType @@ -27,14 +27,15 @@ async def init_objects() -> None: global calendar_user_bde calendar_user_bde = await create_user_with_groups( - [GroupType.student, GroupType.BDE], + [GroupType.BDE], + AccountType.student, ) global token_bde token_bde = create_api_access_token(calendar_user_bde) global calendar_user_simple - calendar_user_simple = await create_user_with_groups([GroupType.student]) + calendar_user_simple = await create_user_with_groups([], AccountType.student) global token_simple token_simple = create_api_access_token(calendar_user_simple) diff --git a/tests/test_campaign.py b/tests/test_campaign.py index b4fdaa911..e51ef6a16 100644 --- a/tests/test_campaign.py +++ b/tests/test_campaign.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.campaign import models_campaign from app.modules.campaign.types_campaign import ListType from tests.commons import ( @@ -28,8 +28,11 @@ async def init_objects() -> None: global CAA_user, AE_user - CAA_user = await create_user_with_groups([GroupType.CAA, GroupType.AE]) - AE_user = await create_user_with_groups([GroupType.AE]) + CAA_user = await create_user_with_groups( + [GroupType.CAA, GroupType.AE], + AccountType.student, + ) + AE_user = await create_user_with_groups([GroupType.AE], AccountType.student) global section global campaign_list diff --git a/tests/test_cdr.py b/tests/test_cdr.py index 07c19d88f..221890afb 100644 --- a/tests/test_cdr.py +++ b/tests/test_cdr.py @@ -6,7 +6,7 @@ from pytest_mock import MockerFixture from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.cdr import models_cdr from app.modules.cdr.types_cdr import ( CdrStatus, @@ -71,7 +71,8 @@ async def init_objects(): global cdr_admin cdr_admin = await create_user_with_groups( - [GroupType.student, GroupType.admin_cdr], + [GroupType.admin_cdr], + AccountType.student, email="cdr_admin@etu.ec-lyon.fr", ) @@ -79,13 +80,19 @@ async def init_objects(): token_admin = create_api_access_token(cdr_admin) global cdr_bde - cdr_bde = await create_user_with_groups([GroupType.student, GroupType.BDE]) + cdr_bde = await create_user_with_groups( + [GroupType.BDE], + AccountType.student, + ) global token_bde token_bde = create_api_access_token(cdr_bde) global cdr_user - cdr_user = await create_user_with_groups([GroupType.student]) + cdr_user = await create_user_with_groups( + [], + AccountType.student, + ) global token_user token_user = create_api_access_token(cdr_user) @@ -238,7 +245,8 @@ async def init_objects(): await add_object_to_db(unused_curriculum) cdr_user_with_curriculum_without_purchase = await create_user_with_groups( - [GroupType.student], + [], + AccountType.student, ) curriculum_membership = models_cdr.CurriculumMembership( user_id=cdr_user_with_curriculum_without_purchase.id, @@ -249,7 +257,8 @@ async def init_objects(): global cdr_user_with_curriculum_with_non_validated_purchase cdr_user_with_curriculum_with_non_validated_purchase = ( await create_user_with_groups( - [GroupType.student], + [], + AccountType.student, ) ) curriculum_membership_for_user_with_curriculum_with_non_validated_purchase = ( diff --git a/tests/test_cdr_result.py b/tests/test_cdr_result.py index 6798fe027..dc54bc018 100644 --- a/tests/test_cdr_result.py +++ b/tests/test_cdr_result.py @@ -4,7 +4,7 @@ import pytest_asyncio from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.cdr import models_cdr from app.modules.cdr.utils_cdr import construct_dataframe_from_users_purchases from tests.commons import ( @@ -53,7 +53,8 @@ async def init_objects(): global cdr_admin cdr_admin = await create_user_with_groups( - [GroupType.student, GroupType.admin_cdr], + [GroupType.admin_cdr], + AccountType.student, email="cdr_admin@etu.ec-lyon.fr", ) @@ -62,7 +63,8 @@ async def init_objects(): global cdr_user1 cdr_user1 = await create_user_with_groups( - [GroupType.student], + [], + AccountType.student, email="demo@demo.fr", name="Demo", firstname="Oui", @@ -72,10 +74,16 @@ async def init_objects(): token_user = create_api_access_token(cdr_user1) global cdr_user2 - cdr_user2 = await create_user_with_groups([GroupType.student]) + cdr_user2 = await create_user_with_groups( + [], + AccountType.student, + ) global cdr_user3 - cdr_user3 = await create_user_with_groups([GroupType.student]) + cdr_user3 = await create_user_with_groups( + [], + AccountType.student, + ) global seller1 seller1 = models_cdr.Seller( diff --git a/tests/test_cinema.py b/tests/test_cinema.py index 9a2c6f260..ed4b81965 100644 --- a/tests/test_cinema.py +++ b/tests/test_cinema.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.cinema import models_cinema from tests.commons import ( add_object_to_db, @@ -29,7 +29,10 @@ async def init_objects() -> None: token_cinema = create_api_access_token(cinema_user_cinema) global cinema_user_simple - cinema_user_simple = await create_user_with_groups([GroupType.student]) + cinema_user_simple = await create_user_with_groups( + [], + AccountType.student, + ) global token_simple token_simple = create_api_access_token(cinema_user_simple) diff --git a/tests/test_core.py b/tests/test_core.py index 7bacb379f..4ab7bf43f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -26,7 +26,7 @@ async def init_objects() -> None: user_simple = await create_user_with_groups([GroupType.AE]) global token_simple token_simple = create_api_access_token(user_simple) - module_visibility = models_core.ModuleVisibility( + module_visibility = models_core.ModuleGroupVisibility( root=root, allowed_group_id=group_id, ) @@ -61,9 +61,9 @@ def test_add_module_visibility(client: TestClient) -> None: assert response.status_code == 201 -def test_delete_loaners(client: TestClient) -> None: +def test_delete_visibility(client: TestClient) -> None: response = client.delete( - f"/module-visibility/{root}/{group_id}", + f"/module-visibility/{root}/groups/{group_id}", headers={"Authorization": f"Bearer {token_admin}"}, ) assert response.status_code == 204 diff --git a/tests/test_flappybird.py b/tests/test_flappybird.py index cd5d23fee..e37758c4a 100644 --- a/tests/test_flappybird.py +++ b/tests/test_flappybird.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.flappybird import models_flappybird from tests.commons import ( add_object_to_db, @@ -23,7 +23,10 @@ @pytest_asyncio.fixture(scope="module", autouse=True) async def init_objects() -> None: global user - user = await create_user_with_groups([GroupType.student]) + user = await create_user_with_groups( + [], + AccountType.student, + ) global admin_user admin_user = await create_user_with_groups([GroupType.admin]) diff --git a/tests/test_loan.py b/tests/test_loan.py index c894dad6c..91865220d 100644 --- a/tests/test_loan.py +++ b/tests/test_loan.py @@ -6,7 +6,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.loan import models_loan from tests.commons import ( add_object_to_db, @@ -39,7 +39,10 @@ async def init_objects() -> None: global item_to_delete admin_user = await create_user_with_groups([GroupType.admin]) - loan_user_loaner = await create_user_with_groups([GroupType.CAA]) + loan_user_loaner = await create_user_with_groups( + [GroupType.CAA], + AccountType.student, + ) loaner = models_loan.Loaner( id=str(uuid.uuid4()), name="CAA", @@ -47,7 +50,10 @@ async def init_objects() -> None: ) await add_object_to_db(loaner) - loan_user_simple = await create_user_with_groups([GroupType.amap]) + loan_user_simple = await create_user_with_groups( + [GroupType.amap], + AccountType.student, + ) loaner_to_delete = models_loan.Loaner( id=str(uuid.uuid4()), diff --git a/tests/test_payment.py b/tests/test_payment.py index 3d3996c59..e19954472 100644 --- a/tests/test_payment.py +++ b/tests/test_payment.py @@ -14,7 +14,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core import schemas_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType from app.core.payment import cruds_payment, models_payment, schemas_payment from app.core.payment.payment_tool import PaymentTool from app.dependencies import get_payment_tool @@ -74,7 +74,10 @@ async def init_objects() -> None: await add_object_to_db(checkout) global user_schema - user = await create_user_with_groups(groups=[GroupType.student]) + user = await create_user_with_groups( + groups=[], + account_type=AccountType.student, + ) user_schema = schemas_core.CoreUser(**user.__dict__) diff --git a/tests/test_phonebook.py b/tests/test_phonebook.py index 94b6773fc..81a99abb5 100644 --- a/tests/test_phonebook.py +++ b/tests/test_phonebook.py @@ -4,7 +4,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.phonebook import models_phonebook from app.modules.phonebook.types_phonebook import Kinds, RoleTags from tests.commons import ( @@ -69,18 +69,31 @@ async def init_objects(): global association2_group2 phonebook_user_BDE = await create_user_with_groups( - [GroupType.student, GroupType.BDE], + [GroupType.BDE], + AccountType.student, ) token_BDE = create_api_access_token(phonebook_user_BDE) - phonebook_user_president = await create_user_with_groups([GroupType.student]) + phonebook_user_president = await create_user_with_groups( + [], + AccountType.student, + ) token_president = create_api_access_token(phonebook_user_president) - phonebook_user_simple = await create_user_with_groups([GroupType.student]) + phonebook_user_simple = await create_user_with_groups( + [], + AccountType.student, + ) token_simple = create_api_access_token(phonebook_user_simple) - phonebook_user_simple2 = await create_user_with_groups([GroupType.student]) - phonebook_user_simple3 = await create_user_with_groups([GroupType.student]) + phonebook_user_simple2 = await create_user_with_groups( + [], + AccountType.student, + ) + phonebook_user_simple3 = await create_user_with_groups( + [], + AccountType.student, + ) phonebook_user_admin = await create_user_with_groups([GroupType.admin]) token_admin = create_api_access_token(phonebook_user_admin) diff --git a/tests/test_raffle.py b/tests/test_raffle.py index f22b71907..852f0234a 100644 --- a/tests/test_raffle.py +++ b/tests/test_raffle.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.raffle import models_raffle from app.modules.raffle.types_raffle import RaffleStatusType from tests.commons import ( @@ -49,8 +49,11 @@ async def init_objects() -> None: raffle_to_delete, \ packticket_to_delete - BDE_user = await create_user_with_groups([GroupType.BDE]) - student_user = await create_user_with_groups([GroupType.student]) + BDE_user = await create_user_with_groups([GroupType.BDE], AccountType.student) + student_user = await create_user_with_groups( + [], + AccountType.student, + ) admin_user = await create_user_with_groups([GroupType.admin]) raffle_to_delete = models_raffle.Raffle( diff --git a/tests/test_raid.py b/tests/test_raid.py index 17be68f83..f036d458e 100644 --- a/tests/test_raid.py +++ b/tests/test_raid.py @@ -8,9 +8,14 @@ from PIL import Image from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.raid import coredata_raid, models_raid -from app.modules.raid.models_raid import Document, Participant, SecurityFile, Team +from app.modules.raid.models_raid import ( + Document, + RaidParticipant, + RaidTeam, + SecurityFile, +) from app.modules.raid.raid_type import DocumentType, DocumentValidation, Size from app.modules.raid.utils.pdf.pdf_writer import ( HTMLPDFWriter, @@ -23,9 +28,9 @@ create_user_with_groups, ) -participant: models_raid.Participant +participant: models_raid.RaidParticipant -team: models_raid.Team +team: models_raid.RaidTeam raid_admin_user: models_core.CoreUser simple_user: models_core.CoreUser @@ -45,17 +50,26 @@ async def init_objects() -> None: token_raid_admin = create_api_access_token(raid_admin_user) global simple_user, token_simple - simple_user = await create_user_with_groups([GroupType.student]) + simple_user = await create_user_with_groups( + [], + AccountType.student, + ) token_simple = create_api_access_token(simple_user) global simple_user_without_participant, token_simple_without_participant - simple_user_without_participant = await create_user_with_groups([GroupType.student]) + simple_user_without_participant = await create_user_with_groups( + [], + AccountType.student, + ) token_simple_without_participant = create_api_access_token( simple_user_without_participant, ) global simple_user_without_team, token_simple_without_team - simple_user_without_team = await create_user_with_groups([GroupType.student]) + simple_user_without_team = await create_user_with_groups( + [], + AccountType.student, + ) token_simple_without_team = create_api_access_token(simple_user_without_team) document = models_raid.Document( @@ -69,7 +83,7 @@ async def init_objects() -> None: await add_object_to_db(document) global participant - participant = models_raid.Participant( + participant = models_raid.RaidParticipant( id=simple_user.id, firstname="TestFirstname", name="TestName", @@ -82,18 +96,18 @@ async def init_objects() -> None: await add_object_to_db(participant) global team - team = models_raid.Team( + team = models_raid.RaidTeam( id=str(uuid.uuid4()), - name="TestTeam", + name="TestRaidTeam", captain_id=simple_user.id, difficulty=None, ) await add_object_to_db(team) - no_team_participant = models_raid.Participant( + no_team_participant = models_raid.RaidParticipant( id=simple_user_without_team.id, - firstname="NoTeam", - name="NoTeam", + firstname="NoRaidTeam", + name="NoRaidTeam", birthday=datetime.date(2001, 1, 1), phone="0606060606", email="test@no_team.fr", @@ -113,7 +127,7 @@ def test_get_participant_by_id(client: TestClient): def test_create_participant(client: TestClient): participant_data = { "firstname": "New", - "name": "Participant", + "name": "RaidParticipant", "birthday": "2000-01-01", "phone": "0123456789", "email": "new@participant.com", @@ -226,14 +240,14 @@ def test_update_participant_invalid_security_file_id(client: TestClient): def test_create_team(client: TestClient): - team_data = {"name": "New Team"} + team_data = {"name": "New RaidTeam"} response = client.post( "/raid/teams", json=team_data, headers={"Authorization": f"Bearer {token_simple_without_team}"}, ) assert response.status_code == 201 - assert response.json()["name"] == "New Team" + assert response.json()["name"] == "New RaidTeam" def test_generate_teams_pdf(client: TestClient): @@ -272,7 +286,7 @@ def test_get_team_by_id(client: TestClient): def test_update_team(client: TestClient): - update_data = {"name": "Updated Team"} + update_data = {"name": "Updated RaidTeam"} response = client.patch( f"/raid/teams/{team.id}", json=update_data, @@ -324,7 +338,7 @@ def test_read_document_participant_not_found(client: TestClient): headers={"Authorization": f"Bearer {token_simple}"}, ) assert response.status_code == 404 - assert response.json()["detail"] == "Participant owning the document not found." + assert response.json()["detail"] == "RaidParticipant owning the document not found." # requires a document to be added @@ -461,7 +475,7 @@ def test_merge_teams(client: TestClient): team2_response = client.post( "/raid/teams", - json={"name": "Team 2"}, + json={"name": "RaidTeam 2"}, headers={"Authorization": f"Bearer {token_simple_without_participant}"}, ) assert team2_response.status_code == 201 @@ -547,7 +561,7 @@ def test_delete_team(client: TestClient): assert response.status_code == 204 -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_payment_url_no_participant(client: TestClient, mocker): # Mock the necessary dependencies mocker.patch("app.modules.raid.cruds_raid.get_participant_by_id", return_value=None) @@ -567,7 +581,7 @@ async def test_get_payment_url_no_participant(client: TestClient, mocker): assert response.json()["url"] == "https://some.url.fr/checkout" -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_payment_url_participant_no_payment(client: TestClient, mocker): # Mock the necessary dependencies mocker.patch( @@ -595,7 +609,7 @@ async def test_get_payment_url_participant_no_payment(client: TestClient, mocker assert response.json()["url"] == "https://some.url.fr/checkout" -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_payment_url_participant_with_tshirt(client: TestClient, mocker): # Mock the necessary dependencies mocker.patch( @@ -649,7 +663,7 @@ async def test_get_payment_url_participant_with_tshirt(client: TestClient, mocke # assert response.json()["detail"] == "Prices not set." -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_payment_url_participant_already_paid(client: TestClient, mocker): # Mock the necessary dependencies mocker.patch( @@ -680,14 +694,14 @@ async def test_get_payment_url_participant_already_paid(client: TestClient, mock ## Test for pdf writer -@pytest.fixture() +@pytest.fixture def mock_team(): return Mock( - spec=Team, - name="Test Team", + spec=RaidTeam, + name="Test RaidTeam", number=1, captain=Mock( - spec=Participant, + spec=RaidParticipant, name="Doe", firstname="John", birthday=datetime.datetime(1990, 1, 1, tzinfo=datetime.UTC), @@ -719,15 +733,15 @@ def mock_team(): ) -@pytest.fixture() +@pytest.fixture def mock_security_file(): return Mock(spec=SecurityFile, allergy="None", asthma=False) -@pytest.fixture() +@pytest.fixture def mock_participant(): return Mock( - spec=Participant, + spec=RaidParticipant, name="Doe", firstname="John", birthday=datetime.datetime(1990, 1, 1, tzinfo=datetime.UTC), @@ -757,7 +771,7 @@ def mock_participant(): ) -@pytest.fixture() +@pytest.fixture def mock_document(): return Mock( spec=Document, diff --git a/tests/test_recommendation.py b/tests/test_recommendation.py index 7db90d80c..c1c550eb2 100644 --- a/tests/test_recommendation.py +++ b/tests/test_recommendation.py @@ -5,7 +5,7 @@ import pytest_asyncio from fastapi.testclient import TestClient -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.modules.recommendation import models_recommendation from tests.commons import ( add_object_to_db, @@ -20,7 +20,10 @@ @pytest_asyncio.fixture(scope="module", autouse=True) async def init_objects() -> None: - user_simple = await create_user_with_groups([GroupType.student]) + user_simple = await create_user_with_groups( + [], + AccountType.student, + ) global token_simple token_simple = create_api_access_token(user_simple) diff --git a/tests/test_school.py b/tests/test_school.py new file mode 100644 index 000000000..75d105ab3 --- /dev/null +++ b/tests/test_school.py @@ -0,0 +1,147 @@ +import pytest_asyncio +from fastapi.testclient import TestClient +from pytest_mock import MockerFixture + +from app.core import models_core +from app.core.groups.groups_type import GroupType +from tests.commons import ( + add_object_to_db, + create_api_access_token, + create_user_with_groups, +) + +admin_user: models_core.CoreUser + +UNIQUE_TOKEN = "my_unique_token" + +id_test_ens = "4d133de7-24c4-4dbc-be73-4705a2ddd315" + + +@pytest_asyncio.fixture(scope="module", autouse=True) +async def init_objects() -> None: + global admin_user + + ens = models_core.CoreSchool( + id=id_test_ens, + name="ENS", + email_regex="^.*@ens.fr$", + ) + await add_object_to_db(ens) + + admin_user = await create_user_with_groups([GroupType.admin]) + + +def test_read_schools(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.get( + "/schools/", + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + + +def test_read_school(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.get( + f"/schools/{id_test_ens}", + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "ENS" + + +def test_create_school(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.post( + "/schools/", + json={ + "name": "school", + "email_regex": "^.*@school.fr$", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 201 + + +def test_update_school(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.patch( + f"/schools/{id_test_ens}", + json={ + "name": "school ENS", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 204 + + response = client.get( + f"/schools/{id_test_ens}", + headers={"Authorization": f"Bearer {token}"}, + ) + data = response.json() + assert data["name"] == "school ENS" + + +def test_create_user_corresponding_to_school( + mocker: MockerFixture, + client: TestClient, +) -> None: + token = create_api_access_token(admin_user) + + response = client.post( + "/schools/", + json={ + "name": "ENS Lyon", + "email_regex": r"^[\w\-.]*@ens-lyon\.fr$", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 201 + school_id = response.json()["id"] + + mocker.patch( + "app.core.users.endpoints_users.security.generate_token", + return_value=UNIQUE_TOKEN, + ) + + response = client.post( + "/users/create", + json={ + "email": "new_user@ens-lyon.fr", + }, + ) + assert response.status_code == 201 + + response = client.post( + "/users/activate", + json={ + "activation_token": UNIQUE_TOKEN, + "password": "password", + "firstname": "new_user_firstname", + "name": "new_user_name", + }, + ) + + assert response.status_code == 201 + + users = client.get( + "/users/", + headers={"Authorization": f"Bearer {token}"}, + ) + user = next( + user + for user in users.json() + if user["firstname"] == "new_user_firstname" and user["name"] == "new_user_name" + ) + + user_detail = client.get( + f"/users/{user['id']}", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert user_detail.json()["school_id"] == school_id diff --git a/tests/test_user_fusion.py b/tests/test_user_fusion.py index 00a95599f..cda65ab6f 100644 --- a/tests/test_user_fusion.py +++ b/tests/test_user_fusion.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType from app.types.membership import AvailableAssociationMembership from tests.commons import ( add_object_to_db, @@ -35,11 +35,13 @@ async def init_objects() -> None: [GroupType.admin, GroupType.admin_cdr], ) student_user_to_keep = await create_user_with_groups( - [GroupType.student, GroupType.BDE], + [GroupType.BDE], + AccountType.student, email=FABRISTPP_EMAIL_1, ) student_user_to_delete = await create_user_with_groups( - [GroupType.student, GroupType.BDE, GroupType.CAA], + [GroupType.BDE, GroupType.CAA], + AccountType.student, email=FABRISTPP_EMAIL_2, ) diff --git a/tests/test_users.py b/tests/test_users.py index ad910cee0..6e472d5c8 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -6,7 +6,8 @@ from pytest_mock import MockerFixture from app.core import models_core -from app.core.groups.groups_type import GroupType +from app.core.groups.groups_type import AccountType, GroupType +from app.core.schools.schools_type import SchoolType from tests.commons import ( create_api_access_token, create_user_with_groups, @@ -38,17 +39,16 @@ async def init_objects() -> None: email=FABRISTPP_EMAIL_1, ) student_user = await create_user_with_groups( - [GroupType.student], + [], + AccountType.student, email=student_user_email, password=student_user_password, ) - external_user = await create_user_with_groups( - [GroupType.external], - external=True, - ) + external_user = await create_user_with_groups([], AccountType.external) student_user_with_old_email = await create_user_with_groups( - [GroupType.student], + [], + AccountType.student, email=FABRISTPP_EMAIL_2, ) @@ -69,31 +69,41 @@ def test_count_users(client: TestClient) -> None: def test_search_users(client: TestClient) -> None: - group = GroupType.student.value + account_type = AccountType.student.value response = client.get( - f"/groups/{group}", + f"/users/?query=&accountTypes={account_type}", headers={"Authorization": f"Bearer {token_admin_user}"}, ) - group_users = set(member["id"] for member in response.json()["members"]) + account_type_users = set(user["id"] for user in response.json()) response = client.get( - f"/users/search?query=&includedGroups={group}", + f"/users/search?query=&includedAccountTypes={account_type}", headers={"Authorization": f"Bearer {token_student_user}"}, ) assert response.status_code == 200 data_users = set(user["id"] for user in response.json()) assert ( - data_users <= group_users + data_users <= account_type_users ) # This endpoint is limited to 10 members, so we only need an inclusion between the two sets, in case there are more than 10 members in the group. response = client.get( - f"/users/search?query=&excludedGroups={group}", + f"/users/search?query=&excludedAccountTypes={account_type}", headers={"Authorization": f"Bearer {token_student_user}"}, ) assert response.status_code == 200 data = response.json() - assert all(user["id"] not in group_users for user in data) + assert all(user["id"] not in account_type_users for user in data) + + +def test_get_account_types(client: TestClient) -> None: + response = client.get( + "/users/account-types", + headers={"Authorization": f"Bearer {token_admin_user}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data == list(AccountType) def test_read_current_user(client: TestClient) -> None: @@ -128,8 +138,8 @@ def test_read_user(client: TestClient) -> None: ("fab@etu.ec-lyon.fr", 201), ("fab@ec-lyon.fr", 201), ("fab@centraliens-lyon.net", 201), - ("fab@test.fr", 400), - ("fab@ecl22.ec-lyon.fr", 400), + ("fab@test.fr", 201), + ("fab@ecl22.ec-lyon.fr", 201), ], ) def test_create_user_by_user_with_email( @@ -167,8 +177,8 @@ def test_create_and_activate_user(mocker: MockerFixture, client: TestClient) -> json={ "activation_token": UNIQUE_TOKEN, "password": "password", - "firstname": "firstname", - "name": "name", + "firstname": "new_user_firstname", + "name": "new_user_name", "nickname": "nickname", "floor": "X1", }, @@ -176,6 +186,23 @@ def test_create_and_activate_user(mocker: MockerFixture, client: TestClient) -> assert response.status_code == 201 + users = client.get( + "/users/", + headers={"Authorization": f"Bearer {token_admin_user}"}, + ) + user = next( + user + for user in users.json() + if user["firstname"] == "new_user_firstname" and user["name"] == "new_user_name" + ) + + user_detail = client.get( + f"/users/{user['id']}", + headers={"Authorization": f"Bearer {token_admin_user}"}, + ) + + assert user_detail.json()["school_id"] == SchoolType.centrale_lyon.value + @pytest.mark.parametrize( ("email", "expected_error"), @@ -208,13 +235,12 @@ def activate_user_with_invalid_phone_number( def test_update_batch_create_users(client: TestClient) -> None: - student = "39691052-2ae5-4e12-99d0-7a9f5f2b0136" response = client.post( "/users/batch-creation", json=[ - {"email": "1@1.fr", "account_type": student}, - {"email": "2@1.fr", "account_type": student}, - {"email": "3@b.fr", "account_type": student}, + {"email": "1@1.fr"}, + {"email": "2@1.fr"}, + {"email": "3@b.fr"}, ], headers={"Authorization": f"Bearer {token_admin_user}"}, ) @@ -404,15 +430,15 @@ def test_batch_internal_user(client: TestClient) -> None: assert response.status_code == 204 response = client.get( - f"/groups/{GroupType.external.value}", + f"/users/search?query=&includedAccountTypes={AccountType.external.value}", headers={"Authorization": f"Bearer {token}"}, ) - assert response.json()["members"] == [] + assert response.json() == [] response = client.get( - f"/groups/{GroupType.student.value}", + f"/users/search?query=&includedAccountTypes={AccountType.student.value}", headers={"Authorization": f"Bearer {token}"}, ) - members = response.json()["members"] + members = response.json() members_ids = [member["id"] for member in members] assert external_user.id in members_ids