From 835584c27fb91b93b708f427f9f2202e990cbe62 Mon Sep 17 00:00:00 2001 From: Lawrence Adu-Gyamfi Date: Fri, 25 Oct 2024 00:24:18 +0200 Subject: [PATCH 1/2] [FEAT]: basic cleanup of package --- .flake8 | 26 +++++ Makefile | 46 ++++++++ khaya/__init__.py | 2 + khaya/asr_api.py | 2 +- khaya/base_api.py | 3 +- khaya/khaya_interface.py | 8 +- khaya/translation_api.py | 2 +- khaya/tts_api.py | 3 +- poetry.lock | 166 +++++++++++++++++++++++++++- pyproject.toml | 3 + tests/khaya/conftest.py | 4 +- tests/khaya/test_khaya_interface.py | 2 - 12 files changed, 255 insertions(+), 12 deletions(-) create mode 100644 .flake8 create mode 100644 Makefile diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..ec1b4d0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,26 @@ +[flake8] +max-line-length = 120 +max-complexity = 10 +select = C,E,F,W,B,B950 +count = True +statistics = True +ignore = + E203, + E302, + E501, + W503, + D100, + D101, + D102, + D103, + D104, + D105, + D106, + D107 +exclude = + .git, + __pycache__, + *.egg-info, + .pytest_cache, + .mypy_cache, + notebooks/* \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5e269a4 --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +.DEFAULT_GOAL := help +.PHONY: help lint test + +SRD_DIR = khaya +TEST_DIR = tests + +ci: lint typecheck test ## Run all CI checks +.PHONY: ci + +lint: ## Run linting + poetry run isort $(SRC_DIR) $(TEST_DIR) + poetry run flake8 --max-line-length 120 $(SRC_DIR) $(TEST_DIR) +.PHONY: lint + + +typecheck: ## Run type checking + poetry run mypy $(SRD_DIR) +.PHONY: typecheck + +test: ## Run tests + poetry run coverage run --source=$(SRD_DIR) -m pytest -v $(TEST_DIR) && poetry run coverage report -m +.PHONY: test + +clean-py: ## Remove python cache files + find . -name '__pycache__' -type d -exec rm -r {} + + find . -name '*.pyc' -type f -exec rm {} + + rm -rf ./*.egg-info +.PHONY: clean-py + +.PHONY: help +help: ## Show help message + @IFS=$$'\n' ; \ + help_lines=(`fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##/:/'`); \ + printf "%s\n\n" "Usage: make [task]"; \ + printf "%-20s %s\n" "task" "help" ; \ + printf "%-20s %s\n" "------" "----" ; \ + for help_line in $${help_lines[@]}; do \ + IFS=$$':' ; \ + help_split=($$help_line) ; \ + help_command=`echo $${help_split[0]} | sed -e 's/^ *//' -e 's/ *$$//'` ; \ + help_info=`echo $${help_split[2]} | sed -e 's/^ *//' -e 's/ *$$//'` ; \ + printf '\033[36m'; \ + printf "%-20s %s" $$help_command ; \ + printf '\033[0m'; \ + printf "%s\n" $$help_info; \ + done diff --git a/khaya/__init__.py b/khaya/__init__.py index b3dc2ea..b491637 100644 --- a/khaya/__init__.py +++ b/khaya/__init__.py @@ -1 +1,3 @@ from .khaya_interface import KhayaInterface as khayaAPI + +__all__ = ["khayaAPI"] \ No newline at end of file diff --git a/khaya/asr_api.py b/khaya/asr_api.py index 3aad0b8..fb83103 100644 --- a/khaya/asr_api.py +++ b/khaya/asr_api.py @@ -4,7 +4,7 @@ class AsrApi(BaseApi): - def transcribe(self, audio_file_path: str, language="tw") -> Response: + def transcribe(self, audio_file_path: str, language="tw") -> Response | dict[str, str]: """ Convert speech to text from audio binary data in an African language using the GhanaNLP STT API. diff --git a/khaya/base_api.py b/khaya/base_api.py index ce051e2..75084c1 100644 --- a/khaya/base_api.py +++ b/khaya/base_api.py @@ -1,5 +1,6 @@ from abc import ABC from typing import Optional + import requests @@ -13,7 +14,7 @@ def __init__(self, api_key: str, base_url: Optional[str] = None): "Cache-Control": "no-cache", } - def _make_request(self, method: str, url: str, **kwargs) -> requests.Response: + def _make_request(self, method: str, url: str, **kwargs) -> requests.Response | dict[str, str]: """ Make an HTTP request. diff --git a/khaya/khaya_interface.py b/khaya/khaya_interface.py index 807f4ce..c0a0831 100644 --- a/khaya/khaya_interface.py +++ b/khaya/khaya_interface.py @@ -6,6 +6,8 @@ from khaya.translation_api import TranslationApi from khaya.tts_api import TtsApi +# custom type hint for Response or dict[str, str] +ResponseOrDict = Response | dict[str, str] class KhayaInterface: """ @@ -58,7 +60,7 @@ def __init__(self, api_key: str, base_url: Optional[str] = "https://translation- self.asr_api = AsrApi(api_key, base_url) self.tts_api = TtsApi(api_key, base_url) - def translate(self, text: str, language_pair: str = "en-tw") -> Response: + def translate(self, text: str, language_pair: str = "en-tw") -> ResponseOrDict: """ Translate text from one language to another. @@ -72,7 +74,7 @@ def translate(self, text: str, language_pair: str = "en-tw") -> Response: return self.translation_api.translate(text, language_pair) - def asr(self, audio_file_path: str, language: str = "tw") -> Response: + def asr(self, audio_file_path: str, language: str = "tw") -> ResponseOrDict: """ Get the transcription of an audio file from a given language. @@ -85,7 +87,7 @@ def asr(self, audio_file_path: str, language: str = "tw") -> Response: """ return self.asr_api.transcribe(audio_file_path, language) - def tts(self, text: str, lang: str) -> Response: + def tts(self, text: str, lang: str) -> ResponseOrDict: """ Synthesize speech from text. diff --git a/khaya/translation_api.py b/khaya/translation_api.py index b425e06..1105791 100644 --- a/khaya/translation_api.py +++ b/khaya/translation_api.py @@ -4,7 +4,7 @@ class TranslationApi(BaseApi): - def translate(self, text: str, language_pair: str = "en-tw") -> Response: + def translate(self, text: str, language_pair: str = "en-tw") -> Response | dict[str, str]: """ Translate text from one language to another using the GhanaNLP translation API. diff --git a/khaya/tts_api.py b/khaya/tts_api.py index 0338fa1..e8a88a3 100644 --- a/khaya/tts_api.py +++ b/khaya/tts_api.py @@ -1,11 +1,12 @@ import json + from requests.models import Response from khaya.base_api import BaseApi class TtsApi(BaseApi): - def synthesize(self, text: str, lang: str) -> Response: + def synthesize(self, text: str, lang: str) -> Response | dict[str, str]: """ Convert text to speech in a specified African language using the GhanaNLP TTS API. diff --git a/poetry.lock b/poetry.lock index c7b3725..281ba03 100644 --- a/poetry.lock +++ b/poetry.lock @@ -201,6 +201,17 @@ files = [ [package.dependencies] pycparser = "*" +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "charset-normalizer" version = "3.4.0" @@ -488,6 +499,17 @@ files = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] +[[package]] +name = "distlib" +version = "0.3.9" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87"}, + {file = "distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403"}, +] + [[package]] name = "executing" version = "2.1.0" @@ -516,6 +538,38 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "filelock" +version = "3.16.1" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, + {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] +typing = ["typing-extensions (>=4.12.2)"] + +[[package]] +name = "flake8" +version = "7.1.1" +description = "the modular source code checker: pep8 pyflakes and co" +optional = false +python-versions = ">=3.8.1" +files = [ + {file = "flake8-7.1.1-py2.py3-none-any.whl", hash = "sha256:597477df7860daa5aa0fdd84bf5208a043ab96b8e96ab708770ae0364dd03213"}, + {file = "flake8-7.1.1.tar.gz", hash = "sha256:049d058491e228e03e67b390f311bbf88fce2dbaa8fa673e7aea87b7198b8d38"}, +] + +[package.dependencies] +mccabe = ">=0.7.0,<0.8.0" +pycodestyle = ">=2.12.0,<2.13.0" +pyflakes = ">=3.2.0,<3.3.0" + [[package]] name = "ghp-import" version = "2.1.0" @@ -547,6 +601,20 @@ files = [ [package.dependencies] colorama = ">=0.4" +[[package]] +name = "identify" +version = "2.6.1" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.6.1-py2.py3-none-any.whl", hash = "sha256:53863bcac7caf8d2ed85bd20312ea5dcfc22226800f6d6881f232d861db5a8f0"}, + {file = "identify-2.6.1.tar.gz", hash = "sha256:91478c5fb7c3aac5ff7bf9b4344f803843dc586832d5f110d672b19aa1984c98"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.10" @@ -642,6 +710,20 @@ qtconsole = ["qtconsole"] test = ["packaging", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath"] test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"] +[[package]] +name = "isort" +version = "5.13.2" +description = "A Python utility / library to sort Python imports." +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, + {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, +] + +[package.extras] +colors = ["colorama (>=0.4.6)"] + [[package]] name = "jedi" version = "0.19.1" @@ -934,6 +1016,17 @@ files = [ [package.dependencies] traitlets = "*" +[[package]] +name = "mccabe" +version = "0.7.0" +description = "McCabe checker, plugin for flake8" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, +] + [[package]] name = "mdit-py-plugins" version = "0.4.2" @@ -1318,6 +1411,17 @@ files = [ {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "packaging" version = "24.1" @@ -1426,6 +1530,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pre-commit" +version = "4.0.1" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pre_commit-4.0.1-py2.py3-none-any.whl", hash = "sha256:efde913840816312445dc98787724647c65473daefe420785f885e8ed9a06878"}, + {file = "pre_commit-4.0.1.tar.gz", hash = "sha256:80905ac375958c0444c65e9cebebd948b3cdb518f335a091a670a89d652139d2"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "prompt-toolkit" version = "3.0.48" @@ -1506,6 +1628,17 @@ files = [ {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, ] +[[package]] +name = "pycodestyle" +version = "2.12.1" +description = "Python style guide checker" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pycodestyle-2.12.1-py2.py3-none-any.whl", hash = "sha256:46f0fb92069a7c28ab7bb558f05bfc0110dac69a0cd23c61ea0040283a9d78b3"}, + {file = "pycodestyle-2.12.1.tar.gz", hash = "sha256:6838eae08bbce4f6accd5d5572075c63626a15ee3e6f842df996bf62f6d73521"}, +] + [[package]] name = "pycparser" version = "2.22" @@ -1517,6 +1650,17 @@ files = [ {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] +[[package]] +name = "pyflakes" +version = "3.2.0" +description = "passive checker of Python programs" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"}, + {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, +] + [[package]] name = "pygments" version = "2.18.0" @@ -2255,6 +2399,26 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "virtualenv" +version = "20.27.0" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.8" +files = [ + {file = "virtualenv-20.27.0-py3-none-any.whl", hash = "sha256:44a72c29cceb0ee08f300b314848c86e57bf8d1f13107a5e671fb9274138d655"}, + {file = "virtualenv-20.27.0.tar.gz", hash = "sha256:2ca56a68ed615b8fe4326d11a0dca5dfbe8fd68510fb6c6349163bed3c15f2b2"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "watchdog" version = "5.0.3" @@ -2322,4 +2486,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "2222a9aea2a07aadb0c8f655697b4e4e78ffd2903b0c2eb7f65c54afb8b54147" +content-hash = "676c54d0661f0e4b93ed9394501ec32e25b84522a6f6d25ec463cb83b3311eea" diff --git a/pyproject.toml b/pyproject.toml index 8055d88..0352cf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,9 @@ python-dotenv = "^1.0.1" [tool.poetry.group.dev.dependencies] mypy = "^1.12.0" +pre-commit = "^4.0.1" +isort = "^5.13.2" +flake8 = "^7.1.1" [tool.poetry.group.docs.dependencies] diff --git a/tests/khaya/conftest.py b/tests/khaya/conftest.py index 7028eb3..37cfd57 100644 --- a/tests/khaya/conftest.py +++ b/tests/khaya/conftest.py @@ -3,7 +3,7 @@ import pytest from dotenv import load_dotenv -from khaya.khaya_interface import KhayaInterface +from khaya import khayaAPI # os.environ.pop("khaya_api_key", None) # Remove the key if it exists @@ -14,4 +14,4 @@ @pytest.fixture def khaya_interface(): api_key = khaya_api_key - return KhayaInterface(api_key) + return khayaAPI(api_key) diff --git a/tests/khaya/test_khaya_interface.py b/tests/khaya/test_khaya_interface.py index a0759fb..b0715ee 100644 --- a/tests/khaya/test_khaya_interface.py +++ b/tests/khaya/test_khaya_interface.py @@ -1,5 +1,3 @@ -import pytest - class TestTranslate: From dfb2ca981daeb09574f8614186e7941aa8329e22 Mon Sep 17 00:00:00 2001 From: Lawrence Adu-Gyamfi Date: Fri, 25 Oct 2024 00:53:52 +0200 Subject: [PATCH 2/2] [FEAT] add some more tests --- tests/khaya/test_khaya_interface.py | 45 ++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/khaya/test_khaya_interface.py b/tests/khaya/test_khaya_interface.py index b0715ee..e7b47d7 100644 --- a/tests/khaya/test_khaya_interface.py +++ b/tests/khaya/test_khaya_interface.py @@ -1,3 +1,24 @@ +import pytest + +from khaya import khayaAPI + + +@pytest.mark.parametrize( + "task, input, lang,", [ + ("translate", "Hello", "en-tw"), + ("asr", "tests/khaya/me_ho_ye.wav", "tw"), + ("tts", "Hello", "tw"), + ] +) +def test_invalid_api_key(task, input, lang): + invalid_api_key = "invalid_api_key" + khaya_interface = khayaAPI(invalid_api_key) + + # execute the task + result = getattr(khaya_interface, task)(input, lang) + + assert "401 Client Error" in result["message"] + class TestTranslate: @@ -19,6 +40,14 @@ def test_translate_error(self, khaya_interface): assert "error" in result.text.lower() + def test_translate_empty_text(self, khaya_interface): + text = "" + translation_pair = "en-tw" + + result = khaya_interface.translate(text, translation_pair) + + assert "error" in result.text.lower() + class TestASR: @@ -31,7 +60,7 @@ def test_asr_valid(self, khaya_interface): assert "error" not in result.text.lower() assert result.json() == "me ho yÉ›" - def test_asr_error(self, khaya_interface): + def test_asr_error_invalid_language(self, khaya_interface): audio_file_path = "tests/khaya/me_ho_ye.wav" wrong_lang = "fw" @@ -39,6 +68,12 @@ def test_asr_error(self, khaya_interface): assert "error" in result['message'].lower() + def test_asr_error_nonexistent_file(self, khaya_interface): + audio_file_path = "tests/khaya/nonexistent.wav" + + with pytest.raises(FileNotFoundError): + khaya_interface.asr(audio_file_path, "tw") + class TestTTS: @@ -59,3 +94,11 @@ def test_tts_error(self, khaya_interface): result = khaya_interface.tts(text, wrong_lang) assert "error" in result.text.lower() + + def test_tts_empty_text(self, khaya_interface): + text = "" + lang = "tw" + + result = khaya_interface.tts(text, lang) + + assert "error" in result.text.lower()