Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add StreamActivatedJobs #507

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pyzeebe/errors/job_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,19 @@ def __init__(self, task_type: str, worker: str, timeout: int, max_jobs_to_activa
super().__init__(msg)


class StreamActivateJobsRequestInvalidError(PyZeebeError):
    """Raised when the gateway rejects a StreamActivatedJobs request as invalid (INVALID_ARGUMENT).

    The message lists every detectable reason: empty task type, empty worker
    name, or a non-positive job timeout.
    """

    def __init__(self, task_type: str, worker: str, timeout: int) -> None:
        msg = "Failed to activate jobs. Reasons:"
        if task_type == "" or task_type is None:
            msg = msg + "task_type is empty, "
        # Fixed copy-paste bug: this previously re-checked `task_type is None`
        # instead of `worker is None`, so a None worker was never reported.
        if worker == "" or worker is None:
            msg = msg + "worker is empty, "
        if timeout < 1:
            msg = msg + "job timeout is smaller than 0ms, "

        super().__init__(msg)


class JobAlreadyDeactivatedError(PyZeebeError):
    """Raised when acting on a job that was already completed, failed, or errored."""

    def __init__(self, job_key: int) -> None:
        message = f"Job {job_key} was already stopped (Completed/Failed/Error)"
        super().__init__(message)
Expand Down
30 changes: 30 additions & 0 deletions pyzeebe/grpc_internals/zeebe_job_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
ActivateJobsRequest,
CompleteJobRequest,
FailJobRequest,
StreamActivatedJobsRequest,
ThrowErrorRequest,
)

from pyzeebe.errors import (
ActivateJobsRequestInvalidError,
JobAlreadyDeactivatedError,
JobNotFoundError,
StreamActivateJobsRequestInvalidError,
)
from pyzeebe.grpc_internals.grpc_utils import is_error_status
from pyzeebe.grpc_internals.zeebe_adapter_base import ZeebeAdapterBase
Expand Down Expand Up @@ -63,6 +65,34 @@ async def activate_jobs(
raise ActivateJobsRequestInvalidError(task_type, worker, timeout, max_jobs_to_activate) from grpc_error
await self._handle_grpc_error(grpc_error)

async def stream_activate_jobs(
    self,
    task_type: str,
    worker: str,
    timeout: int,
    variables_to_fetch: Iterable[str],
    request_timeout: int,
    tenant_ids: Optional[Iterable[str]] = None,
) -> AsyncGenerator[Job, None]:
    """Open a StreamActivatedJobs stream and yield each activated job as a Job.

    Raises:
        StreamActivateJobsRequestInvalidError: when the gateway rejects the
            request with INVALID_ARGUMENT.
    """
    stream_request = StreamActivatedJobsRequest(
        type=task_type,
        worker=worker,
        timeout=timeout,
        fetchVariable=variables_to_fetch,
        tenantIds=tenant_ids or [],
    )
    try:
        async for activated_job in self._gateway_stub.StreamActivatedJobs(
            stream_request, timeout=request_timeout
        ):
            job = self._create_job_from_raw_job(activated_job)
            logger.debug("Got job: %s from zeebe", job)
            yield job
    except grpc.aio.AioRpcError as grpc_error:
        if is_error_status(grpc_error, grpc.StatusCode.INVALID_ARGUMENT):
            raise StreamActivateJobsRequestInvalidError(task_type, worker, timeout) from grpc_error
        await self._handle_grpc_error(grpc_error)

def _create_job_from_raw_job(self, response: ActivatedJob) -> Job:
return Job(
key=response.key,
Expand Down
91 changes: 80 additions & 11 deletions pyzeebe/worker/job_poller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import asyncio
import logging
from typing import List, Optional
from typing import Optional, final

from pyzeebe.errors import (
ActivateJobsRequestInvalidError,
Expand All @@ -17,17 +18,17 @@
logger = logging.getLogger(__name__)


class JobPoller:
class JobPollerABC(abc.ABC):
def __init__(
self,
zeebe_adapter: ZeebeJobAdapter,
task: Task,
queue: "asyncio.Queue[Job]",
queue: asyncio.Queue[Job],
worker_name: str,
request_timeout: int,
task_state: TaskState,
poll_retry_delay: int,
tenant_ids: Optional[List[str]],
tenant_ids: Optional[list[str]],
) -> None:
self.zeebe_adapter = zeebe_adapter
self.task = task
Expand All @@ -43,6 +44,9 @@ async def poll(self) -> None:
while self.should_poll():
await self.activate_max_jobs()

@abc.abstractmethod
async def poll_once(self) -> None:
    """Fetch job(s) once and enqueue them; concrete behavior is subclass-defined."""
    ...

async def activate_max_jobs(self) -> None:
if self.calculate_max_jobs_to_activate() > 0:
await self.poll_once()
Expand All @@ -54,6 +58,20 @@ async def activate_max_jobs(self) -> None:
)
await asyncio.sleep(self.poll_retry_delay)

def should_poll(self) -> bool:
    # Keep polling until stop() is requested, as long as the adapter is either
    # connected or still attempting to reconnect.
    return not self.stop_event.is_set() and (self.zeebe_adapter.connected or self.zeebe_adapter.retrying_connection)

def calculate_max_jobs_to_activate(self) -> int:
    # Remaining capacity for this task (configured max minus currently active),
    # capped by the configured per-request activation maximum.
    worker_max_jobs = self.task.config.max_running_jobs - self.task_state.count_active()
    return min(worker_max_jobs, self.task.config.max_jobs_to_activate)

async def stop(self) -> None:
    # Signal the poll loop to exit, then wait until every queued job has been
    # processed (queue.join() returns once all task_done() calls are in).
    self.stop_event.set()
    await self.queue.join()


@final
class JobPoller(JobPollerABC):
async def poll_once(self) -> None:
try:
jobs = self.zeebe_adapter.activate_jobs(
Expand Down Expand Up @@ -83,13 +101,64 @@ async def poll_once(self) -> None:
)
await asyncio.sleep(5)

def should_poll(self) -> bool:
return not self.stop_event.is_set() and (self.zeebe_adapter.connected or self.zeebe_adapter.retrying_connection)

def calculate_max_jobs_to_activate(self) -> int:
worker_max_jobs = self.task.config.max_running_jobs - self.task_state.count_active()
return min(worker_max_jobs, self.task.config.max_jobs_to_activate)
@final
class JobStreamer(JobPollerABC):
    """Receives jobs from a gateway-side job stream instead of polling.

    Reuses the JobPollerABC loop: each ``poll_once`` awaits the next job from
    the open stream and re-opens the stream when the server ends it.
    """

    def __init__(
        self,
        zeebe_adapter: ZeebeJobAdapter,
        task: Task,
        queue: asyncio.Queue[Job],
        worker_name: str,
        request_timeout: int,
        task_state: TaskState,
        poll_retry_delay: int,
        tenant_ids: Optional[list[str]],
    ) -> None:
        super().__init__(
            zeebe_adapter,
            task,
            queue,
            worker_name,
            request_timeout,
            task_state,
            poll_retry_delay,
            tenant_ids,
        )
        self._create_stream()

    def _create_stream(self) -> None:
        # Open (or re-open) the server-side job stream for this task type.
        self._stream = self.zeebe_adapter.stream_activate_jobs(
            task_type=self.task.type,
            worker=self.worker_name,
            timeout=self.task.config.timeout_ms,
            variables_to_fetch=self.task.config.variables_to_fetch or [],
            request_timeout=self.request_timeout,
            tenant_ids=self.tenant_ids,
        )

    async def poll_once(self) -> None:
        """Await the next streamed job, register it as active, and enqueue it."""
        try:
            job = await self._stream.__anext__()
            self.task_state.add(job)
            await self.queue.put(job)
        except StopAsyncIteration:
            # Server closed the stream; re-open it and let the poll loop retry.
            self._create_stream()
        except ActivateJobsRequestInvalidError:
            # NOTE(review): the adapter raises StreamActivateJobsRequestInvalidError
            # for invalid stream requests, which this clause does not catch —
            # confirm whether that error type should be handled here as well.
            logger.warning("Activate job requests was invalid for task %s", self.task.type)
            raise
        except (
            ZeebeBackPressureError,
            ZeebeGatewayUnavailableError,
            ZeebeInternalError,
            ZeebeDeadlineExceeded,
        ) as error:
            logger.warning(
                "Failed to activate jobs from the gateway. Exception: %s. Retrying in 5 seconds...",
                repr(error),
            )
            await asyncio.sleep(5)

    async def stop(self) -> None:
        # super().stop() sets the stop event and joins the queue; the previous
        # implementation duplicated both steps here before delegating.
        await super().stop()
        await self._stream.aclose()
24 changes: 22 additions & 2 deletions pyzeebe/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pyzeebe.task import task_builder
from pyzeebe.task.exception_handler import ExceptionHandler
from pyzeebe.worker.job_executor import JobExecutor
from pyzeebe.worker.job_poller import JobPoller
from pyzeebe.worker.job_poller import JobPoller, JobStreamer
from pyzeebe.worker.task_router import ZeebeTaskRouter
from pyzeebe.worker.task_state import TaskState

Expand All @@ -34,6 +34,7 @@ def __init__(
poll_retry_delay: int = 5,
tenant_ids: Optional[List[str]] = None,
exception_handler: Optional[ExceptionHandler] = None,
stream_enabled: bool = False,
):
"""
Args:
Expand All @@ -47,6 +48,7 @@ def __init__(
watcher_max_errors_factor (int): Number of consecutive errors for a task watcher will accept before raising MaxConsecutiveTaskThreadError
poll_retry_delay (int): The number of seconds to wait before attempting to poll again when reaching max amount of running jobs
tenant_ids (List[str]): A list of tenant IDs for which to activate jobs. New in Zeebe 8.3.
stream_enabled (bool): Enables the job worker to stream jobs. It will still poll for older jobs, but streaming is favored. New in Zeebe 8.4.
"""
super().__init__(before, after, exception_handler)
self.zeebe_adapter = ZeebeAdapter(grpc_channel, max_connection_retries)
Expand All @@ -57,11 +59,13 @@ def __init__(
self.poll_retry_delay = poll_retry_delay
self.tenant_ids = tenant_ids
self._job_pollers: List[JobPoller] = []
self._job_streamers: List[JobStreamer] = []
self._job_executors: List[JobExecutor] = []
self._stop_event = anyio.Event()
self._stream_enabled = stream_enabled

def _init_tasks(self) -> None:
self._job_executors, self._job_pollers = [], []
self._job_executors, self._job_pollers, self._job_streamers = [], [], []

for task in self.tasks:
jobs_queue: "asyncio.Queue[Job]" = asyncio.Queue()
Expand All @@ -82,6 +86,19 @@ def _init_tasks(self) -> None:
self._job_pollers.append(poller)
self._job_executors.append(executor)

if self._stream_enabled:
streamer = JobStreamer(
self.zeebe_adapter,
task,
jobs_queue,
self.name,
self.request_timeout,
task_state,
self.poll_retry_delay,
self.tenant_ids,
)
self._job_streamers.append(streamer)

async def work(self) -> None:
"""
Start the worker. The worker will poll zeebe for jobs of each task in a different thread.
Expand All @@ -100,6 +117,9 @@ async def work(self) -> None:
for poller in self._job_pollers:
tg.start_soon(poller.poll)

for streamer in self._job_streamers:
tg.start_soon(streamer.poll)

for executor in self._job_executors:
tg.start_soon(executor.execute)

Expand Down
31 changes: 31 additions & 0 deletions tests/unit/utils/gateway_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,37 @@ def ActivateJobs(self, request, context):
)
yield ActivateJobsResponse(jobs=jobs)

def StreamActivatedJobs(self, request, context):
    """Mock of the gateway StreamActivatedJobs RPC.

    Rejects requests with an empty type, empty worker, or non-positive timeout
    via INVALID_ARGUMENT; otherwise streams every active job of the requested
    type.
    """
    # Consolidated validation. Use a bare `return` rather than
    # `return ActivatedJob()`: returning a value from a generator only sets
    # StopIteration.value, which gRPC ignores — the old form was misleading.
    if not request.type or not request.worker or request.timeout <= 0:
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        return

    for active_job in self.active_jobs.values():
        if active_job.type == request.type:
            yield ActivatedJob(
                key=active_job.key,
                type=active_job.type,
                processInstanceKey=active_job.process_instance_key,
                bpmnProcessId=active_job.bpmn_process_id,
                processDefinitionVersion=active_job.process_definition_version,
                processDefinitionKey=active_job.process_definition_key,
                elementId=active_job.element_id,
                elementInstanceKey=active_job.element_instance_key,
                customHeaders=json.dumps(active_job.custom_headers),
                worker=active_job.worker,
                retries=active_job.retries,
                deadline=active_job.deadline,
                variables=json.dumps(active_job.variables),
            )

def CompleteJob(self, request, context):
if request.jobKey in self.active_jobs.keys():
active_job = self.active_jobs.get(request.jobKey)
Expand Down
36 changes: 35 additions & 1 deletion tests/unit/worker/job_poller_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pyzeebe.grpc_internals.zeebe_adapter import ZeebeAdapter
from pyzeebe.job.job import Job
from pyzeebe.task.task import Task
from pyzeebe.worker.job_poller import JobPoller
from pyzeebe.worker.job_poller import JobPoller, JobStreamer
from pyzeebe.worker.task_state import TaskState
from tests.unit.utils.gateway_mock import GatewayMock
from tests.unit.utils.random_utils import random_job
Expand All @@ -17,6 +17,13 @@ def job_poller(zeebe_adapter: ZeebeAdapter, task: Task, queue: asyncio.Queue, ta
return JobPoller(zeebe_adapter, task, queue, "test_worker", 100, task_state, 0, None)


@pytest.fixture
def job_stream_poller(
    zeebe_adapter: ZeebeAdapter, task: Task, queue: asyncio.Queue, task_state: TaskState
) -> JobStreamer:
    """A JobStreamer wired to the mocked gateway with no poll retry delay."""
    streamer = JobStreamer(zeebe_adapter, task, queue, "test_worker", 100, task_state, 0, [])
    return streamer


@pytest.mark.asyncio
class TestPollOnce:
async def test_one_job_is_polled(
Expand Down Expand Up @@ -44,6 +51,33 @@ async def test_job_is_added_to_task_state(
assert job_poller.task_state.count_active() == 1


@pytest.mark.asyncio
class TestStreamPollOnce:
    """Covers JobStreamer.poll_once against the mocked gateway stream."""

    async def test_one_job_is_polled(
        self, job_stream_poller: JobStreamer, queue: asyncio.Queue, job_from_task: Job, grpc_servicer: GatewayMock
    ):
        grpc_servicer.active_jobs[job_from_task.key] = job_from_task

        await job_stream_poller.poll_once()

        streamed_job: Job = queue.get_nowait()
        assert streamed_job.key == job_from_task.key

    async def test_no_job_is_polled(self, job_stream_poller: JobStreamer, queue: asyncio.Queue):
        # No active jobs registered on the mock gateway, so nothing is enqueued.
        await job_stream_poller.poll_once()

        assert queue.empty()

    async def test_job_is_added_to_task_state(
        self, job_stream_poller: JobStreamer, job_from_task: Job, grpc_servicer: GatewayMock
    ):
        grpc_servicer.active_jobs[job_from_task.key] = job_from_task

        await job_stream_poller.poll_once()

        active_count = job_stream_poller.task_state.count_active()
        assert active_count == 1


class TestShouldPoll:
def test_should_poll_returns_expected_result_when_disconnected(self, job_poller: JobPoller):
job_poller.zeebe_adapter.connected = False
Expand Down