Skip to content

Commit

Permalink
feat(task-processor): validate arguments passed to task processor fun…
Browse files Browse the repository at this point in the history
…ctions (#2747)
  • Loading branch information
matthewelwell authored Sep 6, 2023
1 parent c1a62ce commit d947474
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 6 deletions.
31 changes: 27 additions & 4 deletions api/task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.conf import settings
from django.utils import timezone

from task_processor.exceptions import InvalidArgumentsError
from task_processor.models import RecurringTask, Task
from task_processor.task_registry import register_task
from task_processor.task_run_method import TaskRunMethod
Expand Down Expand Up @@ -41,6 +42,7 @@ def delay(
return

if settings.TASK_RUN_METHOD == TaskRunMethod.SYNCHRONOUSLY:
_validate_inputs(*args, **kwargs)
f(*args, **kwargs)
elif settings.TASK_RUN_METHOD == TaskRunMethod.SEPARATE_THREAD:
logger.debug("Running task '%s' in separate thread", task_identifier)
Expand All @@ -58,13 +60,26 @@ def delay(

def run_in_thread(*, args: typing.Tuple = (), kwargs: typing.Dict = None):
logger.info("Running function %s in unmanaged thread.", f.__name__)
_validate_inputs(*args, **kwargs)
Thread(target=f, args=args, kwargs=kwargs, daemon=True).start()

f.delay = delay
f.run_in_thread = run_in_thread
f.task_identifier = task_identifier
def _wrapper(*args, **kwargs):
"""
Execute the function after validating the arguments. Ensures that, in unit testing,
the arguments are validated to prevent issues with serialization in an environment
that utilises the task processor.
"""
_validate_inputs(*args, **kwargs)
return f(*args, **kwargs)

return f
_wrapper.delay = delay
_wrapper.run_in_thread = run_in_thread
_wrapper.task_identifier = task_identifier

# patch the original unwrapped function onto the wrapped version for testing
_wrapper.unwrapped = f

return _wrapper

return decorator

Expand Down Expand Up @@ -101,3 +116,11 @@ def decorator(f: typing.Callable):
return task

return decorator


def _validate_inputs(*args, **kwargs):
try:
Task.serialize_data(args or tuple())
Task.serialize_data(kwargs or dict())
except TypeError as e:
raise InvalidArgumentsError("Inputs are not serializable.") from e
4 changes: 4 additions & 0 deletions api/task_processor/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
class TaskProcessingError(Exception):
pass


class InvalidArgumentsError(TaskProcessingError):
pass
2 changes: 1 addition & 1 deletion api/tests/unit/audit/test_unit_audit_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_create_segment_priorities_changed_audit_log(
create_segment_priorities_changed_audit_log(
previous_id_priority_pairs=[
(feature_segment.id, 0),
(another_feature_segment, 1),
(another_feature_segment.id, 1),
],
feature_segment_ids=[feature_segment.id, another_feature_segment.id],
user_id=admin_user.id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
from datetime import timedelta

import pytest
from pytest_django.fixtures import SettingsWrapper

from task_processor.decorators import (
register_recurring_task,
register_task_handler,
)
from task_processor.exceptions import InvalidArgumentsError
from task_processor.models import RecurringTask
from task_processor.task_registry import get_task
from task_processor.task_run_method import TaskRunMethod


def test_register_task_handler_run_in_thread(mocker, caplog):
Expand Down Expand Up @@ -41,7 +44,7 @@ def my_function(*args, **kwargs):

# Then
mock_thread_class.assert_called_once_with(
target=my_function, args=args, kwargs=kwargs, daemon=True
target=my_function.unwrapped, args=args, kwargs=kwargs, daemon=True
)
mock_thread.start.assert_called_once()

Expand Down Expand Up @@ -93,3 +96,38 @@ def some_function(first_arg, second_arg):
assert not RecurringTask.objects.filter(task_identifier=task_identifier).exists()
with pytest.raises(KeyError):
assert get_task(task_identifier)


def test_register_task_handler_validates_inputs() -> None:
# Given
@register_task_handler()
def my_function(*args, **kwargs):
pass

class NonSerializableObj:
pass

# When
with pytest.raises(InvalidArgumentsError):
my_function(NonSerializableObj())


@pytest.mark.parametrize(
"task_run_method", (TaskRunMethod.SEPARATE_THREAD, TaskRunMethod.SYNCHRONOUSLY)
)
def test_inputs_are_validated_when_run_without_task_processor(
settings: SettingsWrapper, task_run_method: TaskRunMethod
) -> None:
# Given
settings.TASK_RUN_METHOD = task_run_method

@register_task_handler()
def my_function(*args, **kwargs):
pass

class NonSerializableObj:
pass

# When
with pytest.raises(InvalidArgumentsError):
my_function.delay(args=(NonSerializableObj(),))

3 comments on commit d947474

@vercel
Copy link

@vercel vercel bot commented on d947474 Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

docs – ./docs

docs-flagsmith.vercel.app
docs.bullet-train.io
docs-git-main-flagsmith.vercel.app
docs.flagsmith.com

@vercel
Copy link

@vercel vercel bot commented on d947474 Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vercel
Copy link

@vercel vercel bot commented on d947474 Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.