diff --git a/.github/workflows/test-build-release.yml b/.github/workflows/test-build-release.yml index 98f7c92..59971f5 100644 --- a/.github/workflows/test-build-release.yml +++ b/.github/workflows/test-build-release.yml @@ -17,6 +17,7 @@ jobs: - run: pytest env: LD_LIBRARY_PATH: /usr/local/lib + NPBC_DATABASE_DIR: data # build executable for linux build-linux: diff --git a/npbc_cli.py b/npbc_cli.py index 1ebdbed..61d09d4 100644 --- a/npbc_cli.py +++ b/npbc_cli.py @@ -8,12 +8,12 @@ """ -import sqlite3 +from sqlite3 import DatabaseError, connect, Connection from argparse import ArgumentParser from argparse import Namespace as ArgNamespace +from collections.abc import Generator from datetime import datetime from sys import argv -from typing import Generator from colorama import Fore, Style @@ -176,7 +176,7 @@ def status_print(status: bool, message: str) -> None: print(f"{Style.BRIGHT}{message}{Style.RESET_ALL}\n") -def calculate(parsed_arguments: ArgNamespace) -> None: +def calculate(parsed_arguments: ArgNamespace, connection: Connection) -> None: """calculate the cost for a given month and year - default to the previous month if no month and no year is given - default to the current month if no month is given and year is given @@ -210,12 +210,12 @@ def calculate(parsed_arguments: ArgNamespace) -> None: # prepare a dictionary for undelivered strings undelivered_strings = { int(paper_id): [] - for paper_id, _, _, _, _ in npbc_core.get_papers() + for paper_id, _, _, _, _ in npbc_core.get_papers(connection) } # get the undelivered strings from the database try: - raw_undelivered_strings = npbc_core.get_undelivered_strings(month=month, year=year) + raw_undelivered_strings = npbc_core.get_undelivered_strings(connection, month=month, year=year) # add them to the dictionary for _, paper_id, _, _, string in raw_undelivered_strings: @@ -226,33 +226,34 @@ def calculate(parsed_arguments: ArgNamespace) -> None: pass # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return try: # calculate the cost for each paper costs, total, undelivered_dates = npbc_core.calculate_cost_of_all_papers( + connection, undelivered_strings, month, year ) # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return # format the results - formatted = '\n'.join(npbc_core.format_output(costs, total, month, year)) + formatted = '\n'.join(npbc_core.format_output(connection, costs, total, month, year)) # unless the user specifies so, log the results to the database if not parsed_arguments.nolog: try: - npbc_core.save_results(costs, undelivered_dates, month, year) + npbc_core.save_results(connection, costs, undelivered_dates, month, year) # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return @@ -263,7 +264,7 @@ def calculate(parsed_arguments: ArgNamespace) -> None: print(f"SUMMARY:\n\n{formatted}") -def addudl(parsed_arguments: ArgNamespace) -> None: +def addudl(parsed_arguments: ArgNamespace, connection: Connection) -> None: """add undelivered strings to the database - default to the current month if no month and/or no year is given""" @@ -286,7 +287,7 @@ def addudl(parsed_arguments: ArgNamespace) -> None: # attempt to add the strings to the database try: - npbc_core.add_undelivered_string(month, year, parsed_arguments.paperid, *parsed_arguments.strings) + npbc_core.add_undelivered_string(connection, month, year, parsed_arguments.paperid, *parsed_arguments.strings) # if the paper doesn't exist, print an error message except npbc_exceptions.PaperNotExists: @@ -299,7 +300,7 @@ def addudl(parsed_arguments: ArgNamespace) -> None: return # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return @@ -311,7 +312,7 @@ def addudl(parsed_arguments: ArgNamespace) -> None: status_print(True, "Success!") -def deludl(parsed_arguments: ArgNamespace) -> None: +def deludl(parsed_arguments: ArgNamespace, connection: Connection) -> None: """delete undelivered strings from the database""" # validate the month and year @@ -326,6 +327,7 @@ def deludl(parsed_arguments: ArgNamespace) -> None: # attempt to delete the strings from the database try: npbc_core.delete_undelivered_string( + connection, month=parsed_arguments.month, year=parsed_arguments.year, paper_id=parsed_arguments.paperid, @@ -344,14 +346,14 @@ def deludl(parsed_arguments: ArgNamespace) -> None: return # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return status_print(True, "Success!") -def getudl(parsed_arguments: ArgNamespace) -> None: +def getudl(parsed_arguments: ArgNamespace, connection: Connection) -> None: """get undelivered strings from the database filter by whichever parameter the user provides. they as many as they want. available parameters: month, year, paper_id, string_id, string""" @@ -368,6 +370,7 @@ def getudl(parsed_arguments: ArgNamespace) -> None: # attempt to get the strings from the database try: undelivered_strings = npbc_core.get_undelivered_strings( + connection, month=parsed_arguments.month, year=parsed_arguments.year, paper_id=parsed_arguments.paperid, @@ -401,7 +404,7 @@ def extract_delivery_from_user_input(input_delivery: str) -> list[bool]: return list(map(lambda x: x == 'Y', input_delivery)) -def extract_costs_from_user_input(paper_id: int | None, delivery_data: list[bool] | None, *input_costs: float) -> Generator[float, None, None]: +def extract_costs_from_user_input(connection: Connection, paper_id: int | None, delivery_data: list[bool] | None, *input_costs: float) -> Generator[float, None, None]: """convert the user input to a float list""" # filter the data to remove zeros @@ -430,10 +433,10 @@ def extract_costs_from_user_input(paper_id: int | None, delivery_data: list[bool # get the delivery data from the database, and filter for the paper ID try: - raw_data = [paper for paper in npbc_core.get_papers() if paper[0] == int(paper_id)] + raw_data = [paper for paper in npbc_core.get_papers(connection) if paper[0] == int(paper_id)] # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return @@ -457,7 +460,7 @@ def extract_costs_from_user_input(paper_id: int | None, delivery_data: list[bool raise npbc_exceptions.InvalidInput("Neither delivery data nor paper ID given.") -def editpaper(parsed_arguments: ArgNamespace) -> None: +def editpaper(parsed_arguments: ArgNamespace, connection: Connection) -> None: """edit a paper's information""" @@ -468,10 +471,11 @@ def editpaper(parsed_arguments: ArgNamespace) -> None: # attempt to edit the paper. if costs are given, use them, else use None npbc_core.edit_existing_paper( + connection, paper_id=parsed_arguments.paperid, name=parsed_arguments.name, days_delivered=delivery_data, - days_cost=list(extract_costs_from_user_input(parsed_arguments.paperid, delivery_data, *parsed_arguments.costs)) if parsed_arguments.costs else None + days_cost=list(extract_costs_from_user_input(connection, parsed_arguments.paperid, delivery_data, *parsed_arguments.costs)) if parsed_arguments.costs else None ) # if the paper doesn't exist, print an error message @@ -485,14 +489,14 @@ def editpaper(parsed_arguments: ArgNamespace) -> None: return # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return status_print(True, "Success!") -def addpaper(parsed_arguments: ArgNamespace) -> None: +def addpaper(parsed_arguments: ArgNamespace, connection: Connection) -> None: """add a new paper to the database""" try: @@ -501,9 +505,10 @@ def addpaper(parsed_arguments: ArgNamespace) -> None: # attempt to add the paper. npbc_core.add_new_paper( + connection, name=parsed_arguments.name, days_delivered=delivery_data, - days_cost=list(extract_costs_from_user_input(None, delivery_data, *parsed_arguments.costs)) + days_cost=list(extract_costs_from_user_input(connection, None, delivery_data, *parsed_arguments.costs)) ) # if the paper already exists, print an error message @@ -517,19 +522,19 @@ def addpaper(parsed_arguments: ArgNamespace) -> None: return # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return status_print(True, "Success!") -def delpaper(parsed_arguments: ArgNamespace) -> None: +def delpaper(parsed_arguments: ArgNamespace, connection: Connection) -> None: """delete a paper from the database""" # attempt to delete the paper try: - npbc_core.delete_existing_paper(parsed_arguments.paperid) + npbc_core.delete_existing_paper(connection, parsed_arguments.paperid) # if the paper doesn't exist, print an error message except npbc_exceptions.PaperNotExists: @@ -537,14 +542,14 @@ def delpaper(parsed_arguments: ArgNamespace) -> None: return # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return status_print(True, "Success!") -def getpapers(parsed_arguments: ArgNamespace) -> None: +def getpapers(parsed_arguments: ArgNamespace, connection: Connection) -> None: """get a list of all papers in the database - filter by whichever parameter the user provides. they may use as many as they want (but keys are always printed) - available parameters: name, days, costs @@ -552,10 +557,10 @@ def getpapers(parsed_arguments: ArgNamespace) -> None: # get the papers from the database try: - raw_data = npbc_core.get_papers() + raw_data = npbc_core.get_papers(connection) # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return @@ -611,7 +616,7 @@ def getpapers(parsed_arguments: ArgNamespace) -> None: delivery = [ ''.join([ 'Y' if days[paper_id][day_id]['delivery'] else 'N' - for day_id, _ in enumerate(npbc_core.WEEKDAY_NAMES) + for day_id in range(len(npbc_core.WEEKDAY_NAMES)) ]) for paper_id in ids ] @@ -624,7 +629,7 @@ def getpapers(parsed_arguments: ArgNamespace) -> None: costs = [ ';'.join([ str(days[paper_id][day_id]['cost']) - for day_id, _ in enumerate(npbc_core.WEEKDAY_NAMES) + for day_id in range(len(npbc_core.WEEKDAY_NAMES)) if days[paper_id][day_id]['cost'] != 0 ]) for paper_id in ids @@ -652,7 +657,7 @@ def getpapers(parsed_arguments: ArgNamespace) -> None: print() -def getlogs(parsed_arguments: ArgNamespace) -> None: +def getlogs(parsed_arguments: ArgNamespace, connection: Connection) -> None: """get a list of all logs in the database - filter by whichever parameter the user provides. they may use as many as they want (but log IDs are always printed) - available parameters: log_id, paper_id, month, year, timestamp @@ -661,6 +666,7 @@ def getlogs(parsed_arguments: ArgNamespace) -> None: # attempt to get the logs from the database try: data = npbc_core.get_logged_data( + connection, query_log_id = parsed_arguments.logid, query_paper_id=parsed_arguments.paperid, query_month=parsed_arguments.month, @@ -669,7 +675,7 @@ def getlogs(parsed_arguments: ArgNamespace) -> None: ) # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error. Please report this to the developer.\n{e}") return @@ -705,18 +711,25 @@ def main(arguments: list[str]) -> None: # attempt to initialize the database try: - npbc_core.setup_and_connect_DB() + database_path = npbc_core.create_and_setup_DB() # if there is a database error, print an error message - except sqlite3.DatabaseError as e: + except DatabaseError as e: status_print(False, f"Database error: {e}\nPlease report this to the developer.") return - + # parse the command line arguments parsed_namespace = define_and_read_args(arguments) - # execute the appropriate function - parsed_namespace.func(parsed_namespace) + try: + + with connect(database_path) as connection: + # execute the appropriate function + parsed_namespace.func(parsed_namespace, connection) + + # close the database connection + finally: + connection.close() # type: ignore if __name__ == "__main__": diff --git a/npbc_core.py b/npbc_core.py index a6268a3..37eb7ed 100644 --- a/npbc_core.py +++ b/npbc_core.py @@ -8,12 +8,12 @@ from calendar import day_name as weekday_names_iterable from calendar import monthcalendar, monthrange -from datetime import date as date_type -from datetime import datetime, timedelta +from collections.abc import Generator +from datetime import date, datetime, timedelta from os import environ from pathlib import Path from sqlite3 import Connection, connect -from typing import Generator + import numpy import numpy.typing @@ -22,23 +22,19 @@ ## paths for the folder containing schema and database files # during normal use, the DB will be in ~/.npbc (where ~ is the user's home directory) and the schema will be bundled with the executable -# during development, the DB and schema will both be in "data" - -# default to PRODUCTION -DATABASE_DIR = Path.home() / '.npbc' -SCHEMA_PATH = Path(__file__).parent / 'schema.sql' +# during development, the DB and schema will both be in the folder provided by the environment (likely "data") -# if in a development environment, set the paths to the data folder -if environ.get('NPBC_DEVELOPMENT') or environ.get('CI'): - DATABASE_DIR = Path('data') - SCHEMA_PATH = Path('data') / 'schema.sql' +DATABASE_VARIABLE = environ.get("NPBC_DATABASE_DIR") +DATABASE_DIR = Path(DATABASE_VARIABLE) if DATABASE_VARIABLE is not None else Path.home() / ".npbc" +DATABASE_PATH = DATABASE_DIR / "npbc.sqlite" -DATABASE_PATH = DATABASE_DIR / 'npbc.db' +SCHEMA_DIR = Path(DATABASE_VARIABLE) if DATABASE_VARIABLE is not None else Path(__file__).parent +SCHEMA_PATH = SCHEMA_DIR / "schema.sql" -## list constant for names of weekdays -WEEKDAY_NAMES = list(weekday_names_iterable) +## constant for names of weekdays +WEEKDAY_NAMES = tuple(weekday_names_iterable) -def setup_and_connect_DB() -> None: +def create_and_setup_DB() -> Path: """ensure DB exists and it's set up with the schema""" DATABASE_DIR.mkdir(parents=True, exist_ok=True) @@ -49,6 +45,8 @@ def setup_and_connect_DB() -> None: connection.close() + return DATABASE_PATH + def get_number_of_each_weekday(month: int, year: int) -> Generator[int, None, None]: """generate a list of number of times each weekday occurs in a given month (return a generator) @@ -61,7 +59,7 @@ def get_number_of_each_weekday(month: int, year: int) -> Generator[int, None, No number_of_weeks = len(main_calendar) # iterate over each possible weekday - for i, _ in enumerate(WEEKDAY_NAMES): + for i in range(len(WEEKDAY_NAMES)): # assume that the weekday occurs once per week in the month number_of_weekday: int = number_of_weeks @@ -84,50 +82,52 @@ def validate_undelivered_string(*strings: str) -> None: # check that the string matches one of the acceptable patterns for string in strings: - if string and not ( - npbc_regex.NUMBER_MATCH_REGEX.match(string) or - npbc_regex.RANGE_MATCH_REGEX.match(string) or - npbc_regex.DAYS_MATCH_REGEX.match(string) or - npbc_regex.N_DAY_MATCH_REGEX.match(string) or - npbc_regex.ALL_MATCH_REGEX.match(string) + if string and not any ( + pattern.match(string) for pattern in ( + npbc_regex.NUMBER_MATCH_REGEX, + npbc_regex.RANGE_MATCH_REGEX, + npbc_regex.DAYS_MATCH_REGEX, + npbc_regex.N_DAY_MATCH_REGEX, + npbc_regex.ALL_MATCH_REGEX + ) ): raise npbc_exceptions.InvalidUndeliveredString(f'{string} is not a valid undelivered string.') # if we get here, all strings passed the regex check -def extract_number(string: str, month: int, year: int) -> date_type | None: +def extract_number(string: str, month: int, year: int) -> date | None: """if the date is simply a number, it's a single day. so we just identify that date""" - date = int(string) + day = int(string) # if the date is valid for the given month - if date > 0 and date <= monthrange(year, month)[1]: - return date_type(year, month, date) + if 0 < day <= monthrange(year, month)[1]: + return date(year, month, day) -def extract_range(string: str, month: int, year: int) -> Generator[date_type, None, None]: +def extract_range(string: str, month: int, year: int) -> Generator[date, None, None]: """if the date is a range of numbers, it's a range of days. we identify all the dates in that range, bounds inclusive""" start, end = map(int, npbc_regex.HYPHEN_SPLIT_REGEX.split(string)) # if the range is valid for the given month if 0 < start <= end <= monthrange(year, month)[1]: - for date in range(start, end + 1): - yield date_type(year, month, date) + for day in range(start, end + 1): + yield date(year, month, day) -def extract_weekday(string: str, month: int, year: int) -> Generator[date_type, None, None]: +def extract_weekday(string: str, month: int, year: int) -> Generator[date, None, None]: """if the date is the plural of a weekday name, we identify all dates in that month which are the given weekday""" weekday = WEEKDAY_NAMES.index(string.capitalize().rstrip('s')) for day in range(1, monthrange(year, month)[1] + 1): - if date_type(year, month, day).weekday() == weekday: - yield date_type(year, month, day) + if date(year, month, day).weekday() == weekday: + yield date(year, month, day) -def extract_nth_weekday(string: str, month: int, year: int) -> date_type | None: +def extract_nth_weekday(string: str, month: int, year: int) -> date | None: """if the date is a number and a weekday name (singular), we identify the date that is the nth occurrence of the given weekday in the month""" n, weekday_name = npbc_regex.HYPHEN_SPLIT_REGEX.split(string) @@ -135,30 +135,30 @@ def extract_nth_weekday(string: str, month: int, year: int) -> date_type | None: n = int(n) # if the day is valid for the given month - if n > 0 and n <= list(get_number_of_each_weekday(month, year))[WEEKDAY_NAMES.index(weekday_name.capitalize())]: + if 0 < n <= list(get_number_of_each_weekday(month, year))[WEEKDAY_NAMES.index(weekday_name.capitalize())]: # record the "day_id" corresponding to the given weekday name weekday = WEEKDAY_NAMES.index(weekday_name.capitalize()) # store all dates when the given weekday occurs in the given month valid_dates = [ - date_type(year, month, day) + date(year, month, day) for day in range(1, monthrange(year, month)[1] + 1) - if date_type(year, month, day).weekday() == weekday + if date(year, month, day).weekday() == weekday ] # return the date that is the nth occurrence of the given weekday in the month return valid_dates[n - 1] -def extract_all(month: int, year: int) -> Generator[date_type, None, None]: +def extract_all(month: int, year: int) -> Generator[date, None, None]: """if the text is "all", we identify all the dates in the month""" for day in range(1, monthrange(year, month)[1] + 1): - yield date_type(year, month, day) + yield date(year, month, day) -def parse_undelivered_string(month: int, year: int, string: str) -> set[date_type]: +def parse_undelivered_string(month: int, year: int, string: str) -> set[date]: """parse a section of the strings - each section is a string that specifies a set of dates - this function will return a set of dates that uniquely identifies each date mentioned across the string""" @@ -194,7 +194,7 @@ def parse_undelivered_string(month: int, year: int, string: str) -> set[date_typ return dates -def parse_undelivered_strings(month: int, year: int, *strings: str) -> set[date_type]: +def parse_undelivered_strings(month: int, year: int, *strings: str) -> set[date]: """parse a string that specifies when a given paper was not delivered - each section states some set of dates - this function will return a set of dates that uniquely identifies each date mentioned across all the strings""" @@ -219,7 +219,7 @@ def parse_undelivered_strings(month: int, year: int, *strings: str) -> set[date_ return dates -def get_cost_and_delivery_data(paper_id: int, connection: Connection) -> tuple[numpy.typing.NDArray[numpy.floating], numpy.typing.NDArray[numpy.integer]]: +def get_cost_and_delivery_data(paper_id: int, connection: Connection) -> tuple[numpy.typing.NDArray[numpy.floating], numpy.typing.NDArray[numpy.int8]]: """get the cost and delivery data for a given paper from the DB""" delivered_query = """ @@ -242,29 +242,29 @@ def get_cost_and_delivery_data(paper_id: int, connection: Connection) -> tuple[n def calculate_cost_of_one_paper( number_of_each_weekday: list[int], - undelivered_dates: set[date_type], + undelivered_dates: set[date], cost_data: numpy.typing.NDArray[numpy.floating], - delivery_data: numpy.typing.NDArray[numpy.integer] + delivery_data: numpy.typing.NDArray[numpy.int8] ) -> float: """calculate the cost of one paper for the full month - any dates when it was not delivered will be removed""" # initialize counters corresponding to each weekday when the paper was not delivered - number_of_days_per_weekday_not_received = numpy.zeros(len(number_of_each_weekday), dtype=numpy.integer) + number_of_days_per_weekday_not_received = numpy.zeros(len(number_of_each_weekday), dtype=numpy.int8) # for each date that the paper was not delivered, we increment the counter for the corresponding weekday - for date in undelivered_dates: - number_of_days_per_weekday_not_received[date.weekday()] += 1 + for day in undelivered_dates: + number_of_days_per_weekday_not_received[day.weekday()] += 1 return numpy.sum( delivery_data * cost_data * (number_of_each_weekday - number_of_days_per_weekday_not_received) ) -def calculate_cost_of_all_papers(undelivered_strings: dict[int, list[str]], month: int, year: int) -> tuple[ +def calculate_cost_of_all_papers(connection: Connection, undelivered_strings: dict[int, list[str]], month: int, year: int) -> tuple[ dict[int, float], float, - dict[int, set[date_type]] + dict[int, set[date]] ]: """calculate the cost of all papers for the full month - return data about the cost of each paper, the total cost, and dates when each paper was not delivered""" @@ -273,19 +273,17 @@ def calculate_cost_of_all_papers(undelivered_strings: dict[int, list[str]], mont cost_and_delivery_data = {} # get the IDs of papers that exist - with connect(DATABASE_PATH) as connection: - papers = connection.execute("SELECT paper_id FROM papers;").fetchall() + papers = connection.execute("SELECT paper_id FROM papers;").fetchall() - # get the data about cost and delivery for each paper - cost_and_delivery_data = [ - get_cost_and_delivery_data(paper_id, connection) - for paper_id, in papers # type: ignore - ] + # get the data about cost and delivery for each paper + cost_and_delivery_data = [ + get_cost_and_delivery_data(paper_id, connection) + for paper_id, in papers # type: ignore + ] - connection.close() # initialize a "blank" dictionary that will eventually contain any dates when a paper was not delivered - undelivered_dates: dict[int, set[date_type]] = { + undelivered_dates: dict[int, set[date]] = { int(paper_id): set() for paper_id, in papers # type: ignore } @@ -313,8 +311,9 @@ def calculate_cost_of_all_papers(undelivered_strings: dict[int, list[str]], mont def save_results( + connection: Connection, costs: dict[int, float], - undelivered_dates: dict[int, set[date_type]], + undelivered_dates: dict[int, set[date]], month: int, year: int, custom_timestamp: datetime | None = None @@ -323,96 +322,86 @@ def save_results( - save the dates any paper was not delivered - save the final cost of each paper""" - timestamp = (custom_timestamp if custom_timestamp else datetime.now()).strftime(r'%d/%m/%Y %I:%M:%S %p') - - with connect(DATABASE_PATH) as connection: + timestamp = (custom_timestamp or datetime.now()).strftime(r'%d/%m/%Y %I:%M:%S %p') + + # create log entries for each paper + log_ids = { + paper_id: connection.execute( + """ + INSERT INTO logs (paper_id, month, year, timestamp) + VALUES (?, ?, ?, ?) + RETURNING logs.log_id; + """, + (paper_id, month, year, timestamp) + ).fetchone()[0] + for paper_id in costs.keys() + } - # create log entries for each paper - log_ids = { - paper_id: connection.execute( - """ - INSERT INTO logs (paper_id, month, year, timestamp) - VALUES (?, ?, ?, ?) - RETURNING logs.log_id; - """, - (paper_id, month, year, timestamp) - ).fetchone()[0] - for paper_id in costs.keys() - } + # create cost entries for each paper + for paper_id, log_id in log_ids.items(): + connection.execute( + """ + INSERT INTO cost_logs (log_id, cost) + VALUES (?, ?); + """, + (log_id, costs[paper_id]) + ) - # create cost entries for each paper - for paper_id, log_id in log_ids.items(): + # create undelivered date entries for each paper + for paper_id, dates in undelivered_dates.items(): + for day in dates: connection.execute( """ - INSERT INTO cost_logs (log_id, cost) + INSERT INTO undelivered_dates_logs (log_id, date_not_delivered) VALUES (?, ?); """, - (log_id, costs[paper_id]) + (log_ids[paper_id], day.strftime("%Y-%m-%d")) ) - # create undelivered date entries for each paper - for paper_id, dates in undelivered_dates.items(): - for date in dates: - connection.execute( - """ - INSERT INTO undelivered_dates_logs (log_id, date_not_delivered) - VALUES (?, ?); - """, - (log_ids[paper_id], date.strftime("%Y-%m-%d")) - ) - - connection.close() - -def format_output(costs: dict[int, float], total: float, month: int, year: int) -> Generator[str, None, None]: +def format_output(connection: Connection, costs: dict[int, float], total: float, month: int, year: int) -> Generator[str, None, None]: """format the output of calculating the cost of all papers""" # output the name of the month for which the total cost was calculated - yield f"For {date_type(year=year, month=month, day=1).strftime(r'%B %Y')},\n" + yield f"For {date(year=year, month=month, day=1).strftime(r'%B %Y')},\n" # output the total cost of all papers yield f"*TOTAL*: {total:.2f}" # output the cost of each paper with its name - with connect(DATABASE_PATH) as connection: - papers = dict(connection.execute("SELECT paper_id, name FROM papers;").fetchall()) + papers = dict(connection.execute("SELECT paper_id, name FROM papers;").fetchall()) - for paper_id, cost in costs.items(): - yield f"{papers[paper_id]}: {cost:.2f}" + for paper_id, cost in costs.items(): + yield f"{papers[paper_id]}: {cost:.2f}" - connection.close() - -def add_new_paper(name: str, days_delivered: list[bool], days_cost: list[float]) -> None: +def add_new_paper(connection: Connection, name: str, days_delivered: list[bool], days_cost: list[float]) -> None: """add a new paper - do not allow if the paper already exists""" - with connect(DATABASE_PATH) as connection: - - # check if the paper already exists - if connection.execute( - "SELECT EXISTS (SELECT 1 FROM papers WHERE name = ?);", - (name,)).fetchone()[0]: - raise npbc_exceptions.PaperAlreadyExists(f"Paper \"{name}\" already exists." - ) - - # insert the paper - paper_id = connection.execute( - "INSERT INTO papers (name) VALUES (?) RETURNING papers.paper_id;", - (name,) - ).fetchone()[0] + # check if the paper already exists + if connection.execute( + "SELECT EXISTS (SELECT 1 FROM papers WHERE name = ?);", + (name,)).fetchone()[0]: + raise npbc_exceptions.PaperAlreadyExists(f"Paper \"{name}\" already exists." + ) - # create cost and delivered entries for each day - for day_id, (delivered, cost) in enumerate(zip(days_delivered, days_cost)): - connection.execute( - "INSERT INTO cost_and_delivery_data (paper_id, day_id, delivered, cost) VALUES (?, ?, ?, ?);", - (paper_id, day_id, delivered, cost) - ) + # insert the paper + paper_id = connection.execute( + "INSERT INTO papers (name) VALUES (?) RETURNING papers.paper_id;", + (name,) + ).fetchone()[0] - connection.close() + # create cost and delivered entries for each day + for day_id, (delivered, cost) in enumerate(zip(days_delivered, days_cost)): + connection.execute( + "INSERT INTO cost_and_delivery_data (paper_id, day_id, delivered, cost) VALUES (?, ?, ?, ?);", + (paper_id, day_id, delivered, cost) + ) def edit_existing_paper( + connection: Connection, paper_id: int, name: str | None = None, days_delivered: list[bool] | None = None, @@ -421,70 +410,62 @@ def edit_existing_paper( """edit an existing paper do not allow if the paper does not exist""" - with connect(DATABASE_PATH) as connection: - - # check if the paper exists - if not connection.execute( - "SELECT EXISTS (SELECT 1 FROM papers WHERE paper_id = ?);", - (paper_id,)).fetchone()[0]: - raise npbc_exceptions.PaperNotExists(f"Paper with ID {paper_id} does not exist." + # check if the paper exists + if not connection.execute( + "SELECT EXISTS (SELECT 1 FROM papers WHERE paper_id = ?);", + (paper_id,)).fetchone()[0]: + raise npbc_exceptions.PaperNotExists(f"Paper with ID {paper_id} does not exist." + ) + + # update the paper name + if name is not None: + connection.execute( + "UPDATE papers SET name = ? WHERE paper_id = ?;", + (name, paper_id) ) - # update the paper name - if name is not None: + # update the costs of each day + if days_cost is not None: + for day_id, cost in enumerate(days_cost): connection.execute( - "UPDATE papers SET name = ? WHERE paper_id = ?;", - (name, paper_id) + "UPDATE cost_and_delivery_data SET cost = ? WHERE paper_id = ? AND day_id = ?;", + (cost, paper_id, day_id) ) - # update the costs of each day - if days_cost is not None: - for day_id, cost in enumerate(days_cost): - connection.execute( - "UPDATE cost_and_delivery_data SET cost = ? WHERE paper_id = ? AND day_id = ?;", - (cost, paper_id, day_id) - ) - - # update the delivered status of each day - if days_delivered is not None: - for day_id, delivered in enumerate(days_delivered): - connection.execute( - "UPDATE cost_and_delivery_data SET delivered = ? WHERE paper_id = ? AND day_id = ?;", - (delivered, paper_id, day_id) - ) - - connection.close() + # update the delivered status of each day + if days_delivered is not None: + for day_id, delivered in enumerate(days_delivered): + connection.execute( + "UPDATE cost_and_delivery_data SET delivered = ? WHERE paper_id = ? AND day_id = ?;", + (delivered, paper_id, day_id) + ) -def delete_existing_paper(paper_id: int) -> None: +def delete_existing_paper(connection: Connection, paper_id: int) -> None: """delete an existing paper - do not allow if the paper does not exist""" - with connect(DATABASE_PATH) as connection: - - # check if the paper exists - if not connection.execute( - "SELECT EXISTS (SELECT 1 FROM papers WHERE paper_id = ?);", - (paper_id,)).fetchone()[0]: - raise npbc_exceptions.PaperNotExists(f"Paper with ID {paper_id} does not exist." - ) - - # delete the paper - connection.execute( - "DELETE FROM papers WHERE paper_id = ?;", - (paper_id,) - ) + # check if the paper exists + if not connection.execute( + "SELECT EXISTS (SELECT 1 FROM papers WHERE paper_id = ?);", + (paper_id,)).fetchone()[0]: + raise npbc_exceptions.PaperNotExists(f"Paper with ID {paper_id} does not exist." + ) - # delete the costs and delivery data for the paper - connection.execute( - "DELETE FROM cost_and_delivery_data WHERE paper_id = ?;", - (paper_id,) - ) + # delete the paper + connection.execute( + "DELETE FROM papers WHERE paper_id = ?;", + (paper_id,) + ) - connection.close() + # delete the costs and delivery data for the paper + connection.execute( + "DELETE FROM cost_and_delivery_data WHERE paper_id = ?;", + (paper_id,) + ) -def add_undelivered_string(month: int, year: int, paper_id: int | None = None, *undelivered_strings: str) -> None: +def add_undelivered_string(connection: Connection, month: int, year: int, paper_id: int | None = None, *undelivered_strings: str) -> None: """record strings for date(s) paper(s) were not delivered - if no paper ID is specified, all papers are assumed""" @@ -495,48 +476,42 @@ def add_undelivered_string(month: int, year: int, paper_id: int | None = None, * if paper_id: # check that specified paper exists in the database - with connect(DATABASE_PATH) as connection: - if not connection.execute( - "SELECT EXISTS (SELECT 1 FROM papers WHERE paper_id = ?);", - (paper_id,)).fetchone()[0]: - raise npbc_exceptions.PaperNotExists(f"Paper with ID {paper_id} does not exist." - ) - - # add the string(s) - params = [ - (month, year, paper_id, string) - for string in undelivered_strings - ] - - connection.executemany("INSERT INTO undelivered_strings (month, year, paper_id, string) VALUES (?, ?, ?, ?);", params) + if not connection.execute( + "SELECT EXISTS (SELECT 1 FROM papers WHERE paper_id = ?);", + (paper_id,)).fetchone()[0]: + raise npbc_exceptions.PaperNotExists(f"Paper with ID {paper_id} does not exist." + ) + + # add the string(s) + params = [ + (month, year, paper_id, string) + for string in undelivered_strings + ] - connection.close() + connection.executemany("INSERT INTO undelivered_strings (month, year, paper_id, string) VALUES (?, ?, ?, ?);", params) - # if no paper ID is given else: # get the IDs of all papers - with connect(DATABASE_PATH) as connection: - paper_ids = [ - row[0] - for row in connection.execute( - "SELECT paper_id FROM papers;" - ) - ] - - # add the string(s) - params = [ - (month, year, paper_id, string) - for paper_id in paper_ids - for string in undelivered_strings - ] + paper_ids = [ + row[0] + for row in connection.execute( + "SELECT paper_id FROM papers;" + ) + ] - connection.executemany("INSERT INTO undelivered_strings (month, year, paper_id, string) VALUES (?, ?, ?, ?);", params) + # add the string(s) + params = [ + (month, year, paper_id, string) + for paper_id in paper_ids + for string in undelivered_strings + ] - connection.close() + connection.executemany("INSERT INTO undelivered_strings (month, year, paper_id, string) VALUES (?, ?, ?, ?);", params) def delete_undelivered_string( + connection: Connection, string_id: int | None = None, string: str | None = None, paper_id: int | None = None, @@ -575,28 +550,24 @@ def delete_undelivered_string( if not parameters: raise npbc_exceptions.NoParameters("No parameters given.") - with connect(DATABASE_PATH) as connection: - - # check if the string exists - check_query = "SELECT EXISTS (SELECT 1 FROM undelivered_strings" + # check if the string exists + check_query = "SELECT EXISTS (SELECT 1 FROM undelivered_strings" - conditions = ' AND '.join( - f"{parameter} = ?" - for parameter in parameters - ) + conditions = ' AND '.join( + f"{parameter} = ?" + for parameter in parameters + ) - if (1,) not in connection.execute(f"{check_query} WHERE {conditions});", values).fetchall(): - raise npbc_exceptions.StringNotExists("String with given parameters does not exist.") + if (1,) not in connection.execute(f"{check_query} WHERE {conditions});", values).fetchall(): + raise npbc_exceptions.StringNotExists("String with given parameters does not exist.") - # if the string did exist, delete it - delete_query = "DELETE FROM undelivered_strings" + # if the string did exist, delete it + delete_query = "DELETE FROM undelivered_strings" - connection.execute(f"{delete_query} WHERE {conditions};", values) + connection.execute(f"{delete_query} WHERE {conditions};", values) - connection.close() - -def get_papers() -> list[tuple[int, str, int, int, float]]: +def get_papers(connection: Connection) -> list[tuple[int, str, int, int, float]]: """get all papers - returns a list of tuples containing the following fields: paper_id, paper_name, day_id, paper_delivered, paper_cost""" @@ -610,15 +581,13 @@ def get_papers() -> list[tuple[int, str, int, int, float]]: ORDER BY papers.paper_id, cost_and_delivery_data.day_id; """ - with connect(DATABASE_PATH) as connection: - raw_data = connection.execute(query).fetchall() - - connection.close() + raw_data = connection.execute(query).fetchall() return raw_data def get_undelivered_strings( + connection: Connection, string_id: int | None = None, month: int | None = None, year: int | None = None, @@ -658,24 +627,21 @@ def get_undelivered_strings( values.append(string) - with connect(DATABASE_PATH) as connection: - - # generate the SQL query - main_query = "SELECT string_id, paper_id, year, month, string FROM undelivered_strings" - - if not parameters: - query = f"{main_query};" + # generate the SQL query + main_query = "SELECT string_id, paper_id, year, month, string FROM undelivered_strings" + + if not parameters: + query = f"{main_query};" - else: - conditions = ' AND '.join( - f"{parameter} = ?" - for parameter in parameters - ) + else: + conditions = ' AND '.join( + f"{parameter} = ?" + for parameter in parameters + ) - query = f"{main_query} WHERE {conditions};" + query = f"{main_query} WHERE {conditions};" - data = connection.execute(query, values).fetchall() - connection.close() + data = connection.execute(query, values).fetchall() # if no data was found, raise an error if not data: @@ -685,11 +651,12 @@ def get_undelivered_strings( def get_logged_data( + connection: Connection, query_paper_id: int | None = None, query_log_id: int | None = None, query_month: int | None = None, query_year: int | None = None, - query_timestamp: date_type | None = None + query_timestamp: date | None = None ) -> Generator[tuple[int, int, int, int, str, str | float], None, None]: """get logged data - the user may specify as parameters many as they want @@ -743,26 +710,22 @@ def get_logged_data( dates_query = "SELECT log_id, date_not_delivered FROM undelivered_dates_logs;" costs_query = "SELECT log_id, cost FROM cost_logs;" - with connect(DATABASE_PATH) as connection: - logs = { - log_id: [paper_id, month, year, timestamp] - for log_id, paper_id, timestamp, month, year in connection.execute(logs_query, values).fetchall() - } - - dates = connection.execute(dates_query).fetchall() - costs = connection.execute(costs_query).fetchall() + logs = { + log_id: [paper_id, month, year, timestamp] + for log_id, paper_id, timestamp, month, year in connection.execute(logs_query, values).fetchall() + } - for log_id, date in dates: - yield tuple(logs[log_id] + [date]) + dates = connection.execute(dates_query).fetchall() + costs = connection.execute(costs_query).fetchall() - for log_id, cost in costs: - yield tuple(logs[log_id] + [float(cost)]) - - connection.close() + for log_id, date_undelivered in dates: + yield tuple(logs[log_id] + [date_undelivered]) + for log_id, cost in costs: + yield tuple(logs[log_id] + [float(cost)]) -def get_previous_month() -> date_type: +def get_previous_month() -> date: """get the previous month, by looking at 1 day before the first day of the current month (duh)""" return (datetime.today().replace(day=1) - timedelta(days=1)).replace(day=1) diff --git a/npbc_regex.py b/npbc_regex.py index 3b083e5..a2fd34d 100644 --- a/npbc_regex.py +++ b/npbc_regex.py @@ -5,9 +5,9 @@ """ from calendar import day_name as WEEKDAY_NAMES_ITERABLE +from re import IGNORECASE from re import compile as compile_regex - ## regex used to match against strings # match for a list of comma separated values. each value must be/contain digits, or letters, or hyphens. spaces are allowed between values and commas. any number of values are allowed, but at least one must be present. @@ -26,7 +26,7 @@ N_DAY_MATCH_REGEX = compile_regex(f"^\\d *- *({'|'.join(map(lambda x: x.lower(), WEEKDAY_NAMES_ITERABLE))})$") # match for the text "all" in any case. -ALL_MATCH_REGEX = compile_regex(r'^[aA][lL]{2}$') +ALL_MATCH_REGEX = compile_regex(r'^all$', IGNORECASE) # match for seven values, each of which must be a 'Y' or an 'N'. there are no delimiters. DELIVERY_MATCH_REGEX = compile_regex(r'^[YN]{7}$') @@ -35,4 +35,4 @@ ## regex used to split strings # split on hyphens. spaces are allowed between hyphens and values. -HYPHEN_SPLIT_REGEX = compile_regex(r' *- *') \ No newline at end of file +HYPHEN_SPLIT_REGEX = compile_regex(r' *- *') diff --git a/test_core.py b/test_core.py index cb1ce9a..f323e5d 100644 --- a/test_core.py +++ b/test_core.py @@ -3,7 +3,7 @@ - none of these depend on data in the database """ -from datetime import date as date_type +from datetime import date from numpy import array from pytest import raises @@ -73,114 +73,114 @@ def test_undelivered_string_parsing(): assert test_function(MONTH, YEAR, '') == set([]) assert test_function(MONTH, YEAR, '1') == set([ - date_type(year=YEAR, month=MONTH, day=1) + date(year=YEAR, month=MONTH, day=1) ]) assert test_function(MONTH, YEAR, '1-2') == set([ - date_type(year=YEAR, month=MONTH, day=1), - date_type(year=YEAR, month=MONTH, day=2) + date(year=YEAR, month=MONTH, day=1), + date(year=YEAR, month=MONTH, day=2) ]) assert test_function(MONTH, YEAR, '5-17') == set([ - date_type(year=YEAR, month=MONTH, day=5), - date_type(year=YEAR, month=MONTH, day=6), - date_type(year=YEAR, month=MONTH, day=7), - date_type(year=YEAR, month=MONTH, day=8), - date_type(year=YEAR, month=MONTH, day=9), - date_type(year=YEAR, month=MONTH, day=10), - date_type(year=YEAR, month=MONTH, day=11), - date_type(year=YEAR, month=MONTH, day=12), - date_type(year=YEAR, month=MONTH, day=13), - date_type(year=YEAR, month=MONTH, day=14), - date_type(year=YEAR, month=MONTH, day=15), - date_type(year=YEAR, month=MONTH, day=16), - date_type(year=YEAR, month=MONTH, day=17) + date(year=YEAR, month=MONTH, day=5), + date(year=YEAR, month=MONTH, day=6), + date(year=YEAR, month=MONTH, day=7), + date(year=YEAR, month=MONTH, day=8), + date(year=YEAR, month=MONTH, day=9), + date(year=YEAR, month=MONTH, day=10), + date(year=YEAR, month=MONTH, day=11), + date(year=YEAR, month=MONTH, day=12), + date(year=YEAR, month=MONTH, day=13), + date(year=YEAR, month=MONTH, day=14), + date(year=YEAR, month=MONTH, day=15), + date(year=YEAR, month=MONTH, day=16), + date(year=YEAR, month=MONTH, day=17) ]) assert test_function(MONTH, YEAR, '5-17', '19') == set([ - date_type(year=YEAR, month=MONTH, day=5), - date_type(year=YEAR, month=MONTH, day=6), - date_type(year=YEAR, month=MONTH, day=7), - date_type(year=YEAR, month=MONTH, day=8), - date_type(year=YEAR, month=MONTH, day=9), - date_type(year=YEAR, month=MONTH, day=10), - date_type(year=YEAR, month=MONTH, day=11), - date_type(year=YEAR, month=MONTH, day=12), - date_type(year=YEAR, month=MONTH, day=13), - date_type(year=YEAR, month=MONTH, day=14), - date_type(year=YEAR, month=MONTH, day=15), - date_type(year=YEAR, month=MONTH, day=16), - date_type(year=YEAR, month=MONTH, day=17), - date_type(year=YEAR, month=MONTH, day=19) + date(year=YEAR, month=MONTH, day=5), + date(year=YEAR, month=MONTH, day=6), + date(year=YEAR, month=MONTH, day=7), + date(year=YEAR, month=MONTH, day=8), + date(year=YEAR, month=MONTH, day=9), + date(year=YEAR, month=MONTH, day=10), + date(year=YEAR, month=MONTH, day=11), + date(year=YEAR, month=MONTH, day=12), + date(year=YEAR, month=MONTH, day=13), + date(year=YEAR, month=MONTH, day=14), + date(year=YEAR, month=MONTH, day=15), + date(year=YEAR, month=MONTH, day=16), + date(year=YEAR, month=MONTH, day=17), + date(year=YEAR, month=MONTH, day=19) ]) assert test_function(MONTH, YEAR, '5-17', '19-21') == set([ - date_type(year=YEAR, month=MONTH, day=5), - date_type(year=YEAR, month=MONTH, day=6), - date_type(year=YEAR, month=MONTH, day=7), - date_type(year=YEAR, month=MONTH, day=8), - date_type(year=YEAR, month=MONTH, day=9), - date_type(year=YEAR, month=MONTH, day=10), - date_type(year=YEAR, month=MONTH, day=11), - date_type(year=YEAR, month=MONTH, day=12), - date_type(year=YEAR, month=MONTH, day=13), - date_type(year=YEAR, month=MONTH, day=14), - date_type(year=YEAR, month=MONTH, day=15), - date_type(year=YEAR, month=MONTH, day=16), - date_type(year=YEAR, month=MONTH, day=17), - date_type(year=YEAR, month=MONTH, day=19), - date_type(year=YEAR, month=MONTH, day=20), - date_type(year=YEAR, month=MONTH, day=21) + date(year=YEAR, month=MONTH, day=5), + date(year=YEAR, month=MONTH, day=6), + date(year=YEAR, month=MONTH, day=7), + date(year=YEAR, month=MONTH, day=8), + date(year=YEAR, month=MONTH, day=9), + date(year=YEAR, month=MONTH, day=10), + date(year=YEAR, month=MONTH, day=11), + date(year=YEAR, month=MONTH, day=12), + date(year=YEAR, month=MONTH, day=13), + date(year=YEAR, month=MONTH, day=14), + date(year=YEAR, month=MONTH, day=15), + date(year=YEAR, month=MONTH, day=16), + date(year=YEAR, month=MONTH, day=17), + date(year=YEAR, month=MONTH, day=19), + date(year=YEAR, month=MONTH, day=20), + date(year=YEAR, month=MONTH, day=21) ]) assert test_function(MONTH, YEAR, '5-17', '19-21', '23') == set([ - date_type(year=YEAR, month=MONTH, day=5), - date_type(year=YEAR, month=MONTH, day=6), - date_type(year=YEAR, month=MONTH, day=7), - date_type(year=YEAR, month=MONTH, day=8), - date_type(year=YEAR, month=MONTH, day=9), - date_type(year=YEAR, month=MONTH, day=10), - date_type(year=YEAR, month=MONTH, day=11), - date_type(year=YEAR, month=MONTH, day=12), - date_type(year=YEAR, month=MONTH, day=13), - date_type(year=YEAR, month=MONTH, day=14), - date_type(year=YEAR, month=MONTH, day=15), - date_type(year=YEAR, month=MONTH, day=16), - date_type(year=YEAR, month=MONTH, day=17), - date_type(year=YEAR, month=MONTH, day=19), - date_type(year=YEAR, month=MONTH, day=20), - date_type(year=YEAR, month=MONTH, day=21), - date_type(year=YEAR, month=MONTH, day=23) + date(year=YEAR, month=MONTH, day=5), + date(year=YEAR, month=MONTH, day=6), + date(year=YEAR, month=MONTH, day=7), + date(year=YEAR, month=MONTH, day=8), + date(year=YEAR, month=MONTH, day=9), + date(year=YEAR, month=MONTH, day=10), + date(year=YEAR, month=MONTH, day=11), + date(year=YEAR, month=MONTH, day=12), + date(year=YEAR, month=MONTH, day=13), + date(year=YEAR, month=MONTH, day=14), + date(year=YEAR, month=MONTH, day=15), + date(year=YEAR, month=MONTH, day=16), + date(year=YEAR, month=MONTH, day=17), + date(year=YEAR, month=MONTH, day=19), + date(year=YEAR, month=MONTH, day=20), + date(year=YEAR, month=MONTH, day=21), + date(year=YEAR, month=MONTH, day=23) ]) assert test_function(MONTH, YEAR, 'mondays') == set([ - date_type(year=YEAR, month=MONTH, day=1), - date_type(year=YEAR, month=MONTH, day=8), - date_type(year=YEAR, month=MONTH, day=15), - date_type(year=YEAR, month=MONTH, day=22), - date_type(year=YEAR, month=MONTH, day=29) + date(year=YEAR, month=MONTH, day=1), + date(year=YEAR, month=MONTH, day=8), + date(year=YEAR, month=MONTH, day=15), + date(year=YEAR, month=MONTH, day=22), + date(year=YEAR, month=MONTH, day=29) ]) assert test_function(MONTH, YEAR, 'mondays', 'wednesdays') == set([ - date_type(year=YEAR, month=MONTH, day=1), - date_type(year=YEAR, month=MONTH, day=8), - date_type(year=YEAR, month=MONTH, day=15), - date_type(year=YEAR, month=MONTH, day=22), - date_type(year=YEAR, month=MONTH, day=29), - date_type(year=YEAR, month=MONTH, day=3), - date_type(year=YEAR, month=MONTH, day=10), - date_type(year=YEAR, month=MONTH, day=17), - date_type(year=YEAR, month=MONTH, day=24), - date_type(year=YEAR, month=MONTH, day=31) + date(year=YEAR, month=MONTH, day=1), + date(year=YEAR, month=MONTH, day=8), + date(year=YEAR, month=MONTH, day=15), + date(year=YEAR, month=MONTH, day=22), + date(year=YEAR, month=MONTH, day=29), + date(year=YEAR, month=MONTH, day=3), + date(year=YEAR, month=MONTH, day=10), + date(year=YEAR, month=MONTH, day=17), + date(year=YEAR, month=MONTH, day=24), + date(year=YEAR, month=MONTH, day=31) ]) assert test_function(MONTH, YEAR, '2-monday') == set([ - date_type(year=YEAR, month=MONTH, day=8) + date(year=YEAR, month=MONTH, day=8) ]) assert test_function(MONTH, YEAR, '2-monday', '3-wednesday') == set([ - date_type(year=YEAR, month=MONTH, day=8), - date_type(year=YEAR, month=MONTH, day=17) + date(year=YEAR, month=MONTH, day=8), + date(year=YEAR, month=MONTH, day=17) ]) @@ -210,7 +210,7 @@ def test_calculating_cost_of_one_paper(): assert test_function( DAYS_PER_WEEK, set([ - date_type(year=2022, month=1, day=8) + date(year=2022, month=1, day=8) ]), *COST_AND_DELIVERY_DATA ) == 41 @@ -218,8 +218,8 @@ def test_calculating_cost_of_one_paper(): assert test_function( DAYS_PER_WEEK, set([ - date_type(year=2022, month=1, day=8), - date_type(year=2022, month=1, day=8) + date(year=2022, month=1, day=8), + date(year=2022, month=1, day=8) ]), *COST_AND_DELIVERY_DATA ) == 41 @@ -227,8 +227,8 @@ def test_calculating_cost_of_one_paper(): assert test_function( DAYS_PER_WEEK, set([ - date_type(year=2022, month=1, day=8), - date_type(year=2022, month=1, day=17) + date(year=2022, month=1, day=8), + date(year=2022, month=1, day=17) ]), *COST_AND_DELIVERY_DATA ) == 41 @@ -236,7 +236,7 @@ def test_calculating_cost_of_one_paper(): assert test_function( DAYS_PER_WEEK, set([ - date_type(year=2022, month=1, day=2) + date(year=2022, month=1, day=2) ]), *COST_AND_DELIVERY_DATA ) == 40 @@ -244,8 +244,8 @@ def test_calculating_cost_of_one_paper(): assert test_function( DAYS_PER_WEEK, set([ - date_type(year=2022, month=1, day=2), - date_type(year=2022, month=1, day=2) + date(year=2022, month=1, day=2), + date(year=2022, month=1, day=2) ]), *COST_AND_DELIVERY_DATA ) == 40 @@ -253,8 +253,8 @@ def test_calculating_cost_of_one_paper(): assert test_function( DAYS_PER_WEEK, set([ - date_type(year=2022, month=1, day=6), - date_type(year=2022, month=1, day=7) + date(year=2022, month=1, day=6), + date(year=2022, month=1, day=7) ]), *COST_AND_DELIVERY_DATA ) == 34 @@ -262,9 +262,9 @@ def test_calculating_cost_of_one_paper(): assert test_function( DAYS_PER_WEEK, set([ - date_type(year=2022, month=1, day=6), - date_type(year=2022, month=1, day=7), - date_type(year=2022, month=1, day=8) + date(year=2022, month=1, day=6), + date(year=2022, month=1, day=7), + date(year=2022, month=1, day=8) ]), *COST_AND_DELIVERY_DATA ) == 34 @@ -272,13 +272,13 @@ def test_calculating_cost_of_one_paper(): assert test_function( DAYS_PER_WEEK, set([ - date_type(year=2022, month=1, day=6), - date_type(year=2022, month=1, day=7), - date_type(year=2022, month=1, day=7), - date_type(year=2022, month=1, day=7), - date_type(year=2022, month=1, day=8), - date_type(year=2022, month=1, day=8), - date_type(year=2022, month=1, day=8) + date(year=2022, month=1, day=6), + date(year=2022, month=1, day=7), + date(year=2022, month=1, day=7), + date(year=2022, month=1, day=7), + date(year=2022, month=1, day=8), + date(year=2022, month=1, day=8), + date(year=2022, month=1, day=8) ]), *COST_AND_DELIVERY_DATA ) == 34 diff --git a/test_db.py b/test_db.py index 7dc32c1..9b356d1 100644 --- a/test_db.py +++ b/test_db.py @@ -8,6 +8,7 @@ from datetime import date, datetime +from multiprocessing.connection import Connection from pathlib import Path from sqlite3 import connect from typing import Counter @@ -19,7 +20,7 @@ import npbc_exceptions ACTIVE_DIRECTORY = Path("data") -DATABASE_PATH = ACTIVE_DIRECTORY / "npbc.db" +DATABASE_PATH = ACTIVE_DIRECTORY / "npbc.sqlite" SCHEMA_PATH = ACTIVE_DIRECTORY / "schema.sql" TEST_SQL = ACTIVE_DIRECTORY / "test.sql" @@ -27,13 +28,13 @@ def setup_db(): DATABASE_PATH.unlink(missing_ok=True) - with connect(DATABASE_PATH) as connection: - connection.executescript(SCHEMA_PATH.read_text()) - connection.commit() - connection.executescript(TEST_SQL.read_text()) - - connection.close() + connection = connect(DATABASE_PATH) + connection.executescript(SCHEMA_PATH.read_text()) + connection.commit() + connection.executescript(TEST_SQL.read_text()) + connection.commit() + return connection def test_db_creation(): DATABASE_PATH.unlink(missing_ok=True) @@ -49,7 +50,7 @@ def test_db_creation(): def test_get_papers(): - setup_db() + connection = setup_db() known_data = [ (1, 'paper1', 0, 0, 0), @@ -75,11 +76,12 @@ def test_get_papers(): (3, 'paper3', 6, 1, 6) ] - assert Counter(npbc_core.get_papers()) == Counter(known_data) + assert Counter(npbc_core.get_papers(connection)) == Counter(known_data) + connection.close() def test_get_undelivered_strings(): - setup_db() + connection = setup_db() known_data = [ (1, 1, 2020, 11, '5'), @@ -89,20 +91,22 @@ def test_get_undelivered_strings(): (5, 3, 2020, 10, 'all') ] - assert Counter(npbc_core.get_undelivered_strings()) == Counter(known_data) - assert Counter(npbc_core.get_undelivered_strings(string_id=3)) == Counter([known_data[2]]) - assert Counter(npbc_core.get_undelivered_strings(month=11)) == Counter(known_data[:4]) - assert Counter(npbc_core.get_undelivered_strings(paper_id=1)) == Counter(known_data[:2]) - assert Counter(npbc_core.get_undelivered_strings(paper_id=1, string='6-12')) == Counter([known_data[1]]) + assert Counter(npbc_core.get_undelivered_strings(connection)) == Counter(known_data) + assert Counter(npbc_core.get_undelivered_strings(connection, string_id=3)) == Counter([known_data[2]]) + assert Counter(npbc_core.get_undelivered_strings(connection, month=11)) == Counter(known_data[:4]) + assert Counter(npbc_core.get_undelivered_strings(connection, paper_id=1)) == Counter(known_data[:2]) + assert Counter(npbc_core.get_undelivered_strings(connection, paper_id=1, string='6-12')) == Counter([known_data[1]]) with raises(npbc_exceptions.StringNotExists): - npbc_core.get_undelivered_strings(year=1986) + npbc_core.get_undelivered_strings(connection, year=1986) + + connection.close() def test_delete_paper(): - setup_db() + connection = setup_db() - npbc_core.delete_existing_paper(2) + npbc_core.delete_existing_paper(connection, 2) known_data = [ (1, 'paper1', 0, 0, 0), @@ -121,15 +125,17 @@ def test_delete_paper(): (3, 'paper3', 6, 1, 6) ] - assert Counter(npbc_core.get_papers()) == Counter(known_data) + assert Counter(npbc_core.get_papers(connection)) == Counter(known_data) with raises(npbc_exceptions.PaperNotExists): - npbc_core.delete_existing_paper(7) - npbc_core.delete_existing_paper(2) + npbc_core.delete_existing_paper(connection, 7) + npbc_core.delete_existing_paper(connection, 2) + + connection.close() def test_add_paper(): - setup_db() + connection = setup_db() known_data = [ (1, 'paper1', 0, 0, 0), @@ -163,23 +169,27 @@ def test_add_paper(): ] npbc_core.add_new_paper( + connection, 'paper4', [True, False, True, False, False, True, True], [4, 0, 2.6, 0, 0, 1, 7] ) - assert Counter(npbc_core.get_papers()) == Counter(known_data) + assert Counter(npbc_core.get_papers(connection)) == Counter(known_data) with raises(npbc_exceptions.PaperAlreadyExists): npbc_core.add_new_paper( + connection, 'paper4', [True, False, True, False, False, True, True], [4, 0, 2.6, 0, 0, 1, 7] ) + connection.close() + def test_edit_paper(): - setup_db() + connection = setup_db() known_data = [ (1, 'paper1', 0, 0, 0), @@ -206,6 +216,7 @@ def test_edit_paper(): ] npbc_core.edit_existing_paper( + connection, 1, days_delivered=[True, False, True, False, False, True, True], days_cost=[6.4, 0, 0, 0, 0, 7.9, 4] @@ -219,9 +230,10 @@ def test_edit_paper(): known_data[5] = (1, 'paper1', 5, 1, 7.9) known_data[6] = (1, 'paper1', 6, 1, 4) - assert Counter(npbc_core.get_papers()) == Counter(known_data) + assert Counter(npbc_core.get_papers(connection)) == Counter(known_data) npbc_core.edit_existing_paper( + connection, 3, name="New paper" ) @@ -234,10 +246,12 @@ def test_edit_paper(): known_data[19] = (3, 'New paper', 5, 1, 4.6) known_data[20] = (3, 'New paper', 6, 1, 6) - assert Counter(npbc_core.get_papers()) == Counter(known_data) + assert Counter(npbc_core.get_papers(connection)) == Counter(known_data) with raises(npbc_exceptions.PaperNotExists): - npbc_core.edit_existing_paper(7, name="New paper") + npbc_core.edit_existing_paper(connection, 7, name="New paper") + + connection.close() def test_delete_string(): @@ -249,29 +263,33 @@ def test_delete_string(): (5, 3, 2020, 10, 'all') ] - setup_db() - npbc_core.delete_undelivered_string(string='all') - assert Counter(npbc_core.get_undelivered_strings()) == Counter(known_data[:4]) - - setup_db() - npbc_core.delete_undelivered_string(month=11) - assert Counter(npbc_core.get_undelivered_strings()) == Counter([known_data[4]]) + connection = setup_db() + npbc_core.delete_undelivered_string(connection, string='all') + assert Counter(npbc_core.get_undelivered_strings(connection)) == Counter(known_data[:4]) + connection.close() - setup_db() - npbc_core.delete_undelivered_string(paper_id=1) - assert Counter(npbc_core.get_undelivered_strings()) == Counter(known_data[2:]) + connection = setup_db() + npbc_core.delete_undelivered_string(connection, month=11) + assert Counter(npbc_core.get_undelivered_strings(connection)) == Counter([known_data[4]]) + connection.close() - setup_db() + connection = setup_db() + npbc_core.delete_undelivered_string(connection, paper_id=1) + assert Counter(npbc_core.get_undelivered_strings(connection)) == Counter(known_data[2:]) + connection.close() + connection = setup_db() with raises(npbc_exceptions.StringNotExists): - npbc_core.delete_undelivered_string(string='not exists') + npbc_core.delete_undelivered_string(connection, string='not exists') with raises(npbc_exceptions.NoParameters): - npbc_core.delete_undelivered_string() + npbc_core.delete_undelivered_string(connection) + + connection.close() def test_add_string(): - setup_db() + connection = setup_db() known_data = [ (1, 1, 2020, 11, '5'), @@ -281,19 +299,21 @@ def test_add_string(): (5, 3, 2020, 10, 'all') ] - npbc_core.add_undelivered_string(4, 2017, 3, 'sundays') + npbc_core.add_undelivered_string(connection, 4, 2017, 3, 'sundays') known_data.append((6, 3, 2017, 4, 'sundays')) - assert Counter(npbc_core.get_undelivered_strings()) == Counter(known_data) + assert Counter(npbc_core.get_undelivered_strings(connection)) == Counter(known_data) - npbc_core.add_undelivered_string(9, 2017, None, '11') + npbc_core.add_undelivered_string(connection, 9, 2017, None, '11') known_data.append((7, 1, 2017, 9, '11')) known_data.append((8, 2, 2017, 9, '11')) known_data.append((9, 3, 2017, 9, '11')) - assert Counter(npbc_core.get_undelivered_strings()) == Counter(known_data) + assert Counter(npbc_core.get_undelivered_strings(connection)) == Counter(known_data) + + connection.close() def test_save_results(): - setup_db() + connection = setup_db() known_data = [ (1, 1, 2020, '04/01/2022 01:05:42 AM', '2020-01-01'), @@ -307,6 +327,7 @@ def test_save_results(): ] npbc_core.save_results( + connection, {1: 105, 2: 51, 3: 647}, { 1: set([date(month=1, day=1, year=2020), date(month=1, day=2, year=2020)]), @@ -318,4 +339,6 @@ def test_save_results(): datetime(year=2022, month=1, day=4, hour=1, minute=5, second=42) ) - assert Counter(npbc_core.get_logged_data()) == Counter(known_data) + assert Counter(npbc_core.get_logged_data(connection)) == Counter(known_data) + + connection.close()