Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(task-processor): validate arguments passed to task processor functions #2747

Merged
merged 5 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions api/task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.conf import settings
from django.utils import timezone

from task_processor.exceptions import InvalidArgumentsError
from task_processor.models import RecurringTask, Task
from task_processor.task_registry import register_task
from task_processor.task_run_method import TaskRunMethod
Expand Down Expand Up @@ -41,6 +42,7 @@ def delay(
return

if settings.TASK_RUN_METHOD == TaskRunMethod.SYNCHRONOUSLY:
_validate_inputs(*args, **kwargs)
f(*args, **kwargs)
elif settings.TASK_RUN_METHOD == TaskRunMethod.SEPARATE_THREAD:
logger.debug("Running task '%s' in separate thread", task_identifier)
Expand All @@ -58,13 +60,26 @@ def delay(

def run_in_thread(*, args: typing.Tuple = (), kwargs: typing.Dict = None):
logger.info("Running function %s in unmanaged thread.", f.__name__)
_validate_inputs(*args, **kwargs)
Thread(target=f, args=args, kwargs=kwargs, daemon=True).start()

f.delay = delay
f.run_in_thread = run_in_thread
f.task_identifier = task_identifier
def _wrapper(*args, **kwargs):
"""
Execute the function after validating the arguments. Ensures that, in unit testing,
the arguments are validated to prevent issues with serialization in an environment
that utilises the task processor.
"""
_validate_inputs(*args, **kwargs)
return f(*args, **kwargs)

return f
_wrapper.delay = delay
_wrapper.run_in_thread = run_in_thread
_wrapper.task_identifier = task_identifier

# patch the original unwrapped function onto the wrapped version for testing
_wrapper.unwrapped = f

return _wrapper

return decorator

Expand Down Expand Up @@ -101,3 +116,11 @@ def decorator(f: typing.Callable):
return task

return decorator


def _validate_inputs(*args, **kwargs):
try:
Task.serialize_data(args or tuple())
Task.serialize_data(kwargs or dict())
except TypeError as e:
raise InvalidArgumentsError("Inputs are not serializable.") from e
4 changes: 4 additions & 0 deletions api/task_processor/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
class TaskProcessingError(Exception):
pass


class InvalidArgumentsError(TaskProcessingError):
pass
2 changes: 1 addition & 1 deletion api/tests/unit/audit/test_unit_audit_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_create_segment_priorities_changed_audit_log(
create_segment_priorities_changed_audit_log(
previous_id_priority_pairs=[
(feature_segment.id, 0),
(another_feature_segment, 1),
(another_feature_segment.id, 1),
],
feature_segment_ids=[feature_segment.id, another_feature_segment.id],
user_id=admin_user.id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
from datetime import timedelta

import pytest
from pytest_django.fixtures import SettingsWrapper

from task_processor.decorators import (
register_recurring_task,
register_task_handler,
)
from task_processor.exceptions import InvalidArgumentsError
from task_processor.models import RecurringTask
from task_processor.task_registry import get_task
from task_processor.task_run_method import TaskRunMethod


def test_register_task_handler_run_in_thread(mocker, caplog):
Expand Down Expand Up @@ -41,7 +44,7 @@ def my_function(*args, **kwargs):

# Then
mock_thread_class.assert_called_once_with(
target=my_function, args=args, kwargs=kwargs, daemon=True
target=my_function.unwrapped, args=args, kwargs=kwargs, daemon=True
)
mock_thread.start.assert_called_once()

Expand Down Expand Up @@ -93,3 +96,38 @@ def some_function(first_arg, second_arg):
assert not RecurringTask.objects.filter(task_identifier=task_identifier).exists()
with pytest.raises(KeyError):
assert get_task(task_identifier)


def test_register_task_handler_validates_inputs() -> None:
# Given
@register_task_handler()
def my_function(*args, **kwargs):
pass

class NonSerializableObj:
pass

# When
with pytest.raises(InvalidArgumentsError):
my_function(NonSerializableObj())


@pytest.mark.parametrize(
"task_run_method", (TaskRunMethod.SEPARATE_THREAD, TaskRunMethod.SYNCHRONOUSLY)
)
def test_inputs_are_validated_when_run_without_task_processor(
settings: SettingsWrapper, task_run_method: TaskRunMethod
) -> None:
# Given
settings.TASK_RUN_METHOD = task_run_method

@register_task_handler()
def my_function(*args, **kwargs):
pass

class NonSerializableObj:
pass

# When
with pytest.raises(InvalidArgumentsError):
my_function.delay(args=(NonSerializableObj(),))