diff --git a/api/task_processor/management/commands/runprocessor.py b/api/task_processor/management/commands/runprocessor.py index 7830ffa52cd8..f8e1a579f9aa 100644 --- a/api/task_processor/management/commands/runprocessor.py +++ b/api/task_processor/management/commands/runprocessor.py @@ -10,7 +10,8 @@ from django.core.management import BaseCommand from django.utils import timezone -from task_processor import tasks +from sse import tasks as sse_tasks +from task_processor import tasks as processor_tasks from task_processor.task_registry import registered_tasks from task_processor.thread_monitoring import ( clear_unhealthy_threads, @@ -20,6 +21,8 @@ logger = logging.getLogger(__name__) +TASKS_MODULES_TO_RELOAD = [processor_tasks, sse_tasks] + class Command(BaseCommand): def __init__(self, *args, **kwargs): @@ -33,10 +36,14 @@ def __init__(self, *args, **kwargs): # environment variable. os.environ["RUN_BY_PROCESSOR"] = "True" - # Since the tasks module is loaded by the ready method in TaskProcessorConfig - # which is run before the command is initialised, we need to reload the internal - # tasks module here to make sure recurring tasks are registered correctly. - reload(tasks) + # Since all the apps are loaded before the command is initialised, + # we need to reload some of those modules(that contains recurring tasks) + # to ensure the tasks are registered correctly + # e.g the tasks module is loaded by the ready method in TaskProcessorConfig + # which is run before the command is initialised + + for module in TASKS_MODULES_TO_RELOAD: + reload(module) signal.signal(signal.SIGINT, self._exit_gracefully) signal.signal(signal.SIGTERM, self._exit_gracefully) diff --git a/api/tests/unit/sse/test_tasks.py b/api/tests/unit/sse/test_tasks.py index b0e1d76a541b..7510e35aa9c4 100644 --- a/api/tests/unit/sse/test_tasks.py +++ b/api/tests/unit/sse/test_tasks.py @@ -1,5 +1,7 @@ +import os from datetime import datetime from typing import Callable +from unittest import mock from unittest.mock import call import pytest @@ -15,6 +17,10 @@ send_environment_update_message_for_project, update_sse_usage, ) +from task_processor.management.commands.runprocessor import ( + Command as RunProcessorCommand, +) +from task_processor.models import RecurringTask def test_send_environment_update_message_for_project_make_correct_request( @@ -138,3 +144,24 @@ def test_track_sse_usage( bucket=influxdb_bucket, record=mocked_influx_point().field().tag().tag().tag().time(), ) + + +@mock.patch.dict(os.environ, {}) +@pytest.mark.django_db +def test_track_sse_usage_is_installed_correctly( + settings: SettingsWrapper, +): + # Given + settings.AWS_SSE_LOGS_BUCKET_NAME = "test_bucket" + + # When + # Initialising the command should save the task to the database + RunProcessorCommand() + + # Then + assert ( + RecurringTask.objects.filter( + task_identifier=f"tasks.{update_sse_usage.__name__}" + ).exists() + is True + )