From a46badb8bd2751a23c5b87b70771a593ae16a960 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Fri, 8 Sep 2023 07:17:32 +0800 Subject: [PATCH] Async agent delete function for while loop case (#1802) Signed-off-by: Future Outlier Signed-off-by: Kevin Su Co-authored-by: Future Outlier Co-authored-by: Kevin Su Signed-off-by: Future Outlier --- flytekit/extend/backend/agent_service.py | 14 +-- flytekit/extend/backend/base_agent.py | 105 ++++++++++-------- plugins/flytekit-bigquery/tests/test_agent.py | 2 +- tests/flytekit/unit/extend/test_agent.py | 29 ++++- 4 files changed, 90 insertions(+), 60 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 470bd01e2e..9c294e5fae 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -2,14 +2,12 @@ import grpc from flyteidl.admin.agent_pb2 import ( - PERMANENT_FAILURE, CreateTaskRequest, CreateTaskResponse, DeleteTaskRequest, DeleteTaskResponse, GetTaskRequest, GetTaskResponse, - Resource, ) from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer @@ -24,10 +22,8 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon try: tmp = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - agent = AgentRegistry.get_agent(context, tmp.type) + agent = AgentRegistry.get_agent(tmp.type) logger.info(f"{tmp.type} agent start creating the job") - if agent is None: - return CreateTaskResponse() if agent.asynchronous: try: return await agent.async_create( @@ -50,10 +46,8 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: try: - agent = AgentRegistry.get_agent(context, request.task_type) + agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.task_type} agent start checking the status of the job") - if agent is None: - return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) if agent.asynchronous: try: return await agent.async_get(context=context, resource_meta=request.resource_meta) @@ -72,10 +66,8 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: try: - agent = AgentRegistry.get_agent(context, request.task_type) + agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.task_type} agent start deleting the job") - if agent is None: - return DeleteTaskResponse() if agent.asynchronous: try: return await agent.async_delete(context=context, resource_meta=request.resource_meta) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 1bf34c029a..50574e67b1 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -122,12 +122,9 @@ def register(agent: AgentBase): logger.info(f"Registering an agent for task type {agent.task_type}") @staticmethod - def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]: + def get_agent(task_type: str) -> typing.Optional[AgentBase]: if task_type not in AgentRegistry._REGISTRY: - logger.error(f"Cannot find agent for task type [{task_type}]") - context.set_code(grpc.StatusCode.NOT_FOUND) - context.set_details(f"Cannot find the agent for task type [{task_type}]") - return None + raise ValueError(f"Unrecognized task type {task_type}") return AgentRegistry._REGISTRY[task_type] @@ -136,9 +133,9 @@ def convert_to_flyte_state(state: str) -> State: Convert the state from the agent to the state in flyte. """ state = state.lower() - if state in ["failed"]: + if state in ["failed", "timedout", "canceled"]: return RETRYABLE_FAILURE - elif state in ["done", "succeeded"]: + elif state in ["done", "succeeded", "success"]: return SUCCEEDED elif state in ["running"]: return RUNNING @@ -158,61 +155,79 @@ class AsyncAgentExecutorMixin: Task should inherit from this class if the task can be run in the agent. """ - def execute(self, **kwargs) -> typing.Any: - from unittest.mock import MagicMock + _is_canceled = None + _agent = None + _entity = None + def execute(self, **kwargs) -> typing.Any: from flytekit.tools.translator import get_serializable - entity = typing.cast(PythonTask, self) - m: OrderedDict = OrderedDict() - dummy_context = MagicMock(spec=grpc.ServicerContext) - cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity) - agent = AgentRegistry.get_agent(dummy_context, cp_entity.template.type) + self._entity = typing.cast(PythonTask, self) + task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template + self._agent = AgentRegistry.get_agent(task_template.type) - if agent is None: - raise Exception("Cannot find the agent for the task") - literals = {} + res = asyncio.run(self._create(task_template, kwargs)) + res = asyncio.run(self._get(resource_meta=res.resource_meta)) + + if res.resource.state != SUCCEEDED: + raise Exception(f"Failed to run the task {self._entity.name}") + + return LiteralMap.from_flyte_idl(res.resource.outputs) + + async def _create( + self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None + ) -> CreateTaskResponse: ctx = FlyteContext.current_context() - for k, v in kwargs.items(): - literals[k] = TypeEngine.to_literal(ctx, v, type(v), entity.interface.inputs[k].type) + grpc_ctx = _get_grpc_context() + + # Convert python inputs to literals + literals = {} + for k, v in inputs.items(): + literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type) inputs = LiteralMap(literals) if literals else None output_prefix = ctx.file_access.get_random_local_directory() - cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity) - if agent.asynchronous: - res = asyncio.run(agent.async_create(dummy_context, output_prefix, cp_entity.template, inputs)) + + if self._agent.asynchronous: + res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, inputs) else: - res = agent.create(dummy_context, output_prefix, cp_entity.template, inputs) - signal.signal(signal.SIGINT, partial(self.signal_handler, agent, dummy_context, res.resource_meta)) + res = self._agent.create(grpc_ctx, output_prefix, task_template, inputs) + + signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore + return res + + async def _get(self, resource_meta: bytes) -> GetTaskResponse: state = RUNNING - metadata = res.resource_meta + grpc_ctx = _get_grpc_context() + progress = Progress(transient=True) - task = progress.add_task(f"[cyan]Running Task {entity.name}...", total=None) + task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None) with progress: while not is_terminal_state(state): progress.start_task(task) time.sleep(1) - if agent.asynchronous: - res = asyncio.run(agent.async_get(dummy_context, metadata)) + if self._agent.asynchronous: + res = await self._agent.async_get(grpc_ctx, resource_meta) + if self._is_canceled: + await self._is_canceled + sys.exit(1) else: - res = agent.get(dummy_context, metadata) + res = self._agent.get(grpc_ctx, resource_meta) state = res.resource.state logger.info(f"Task state: {state}") + return res + + def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: + grpc_ctx = _get_grpc_context() + if self._agent.asynchronous: + if self._is_canceled is None: + self._is_canceled = asyncio.create_task(self._agent.async_delete(grpc_ctx, resource_meta)) + else: + self._agent.delete(grpc_ctx, resource_meta) + sys.exit(1) - if state != SUCCEEDED: - raise Exception(f"Failed to run the task {entity.name}") - return LiteralMap.from_flyte_idl(res.resource.outputs) +def _get_grpc_context(): + from unittest.mock import MagicMock - def signal_handler( - self, - agent: AgentBase, - context: grpc.ServicerContext, - resource_meta: bytes, - signum: int, - frame: FrameType, - ) -> typing.Any: - if agent.asynchronous: - asyncio.run(agent.async_delete(context, resource_meta)) - else: - agent.delete(context, resource_meta) - sys.exit(1) + grpc_ctx = MagicMock(spec=grpc.ServicerContext) + return grpc_ctx diff --git a/plugins/flytekit-bigquery/tests/test_agent.py b/plugins/flytekit-bigquery/tests/test_agent.py index 16b5b7af4d..af53f4031d 100644 --- a/plugins/flytekit-bigquery/tests/test_agent.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -44,7 +44,7 @@ def __init__(self): mock_instance.cancel_job.return_value = MockJob() ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent(ctx, "bigquery_query_job_task") + agent = AgentRegistry.get_agent("bigquery_query_job_task") task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index bf1db6e333..b763a7e402 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -9,6 +9,7 @@ import pytest from flyteidl.admin.agent_pb2 import ( PERMANENT_FAILURE, + RETRYABLE_FAILURE, RUNNING, SUCCEEDED, CreateTaskRequest, @@ -23,7 +24,13 @@ import flytekit.models.interface as interface_models from flytekit import PythonFunctionTask from flytekit.extend.backend.agent_service import AsyncAgentService -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, AsyncAgentExecutorMixin, is_terminal_state +from flytekit.extend.backend.base_agent import ( + AgentBase, + AgentRegistry, + AsyncAgentExecutorMixin, + convert_to_flyte_state, + is_terminal_state, +) from flytekit.models import literals, task, types from flytekit.models.core.identifier import Identifier, ResourceType from flytekit.models.literals import LiteralMap @@ -97,7 +104,7 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT def test_dummy_agent(): ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent(ctx, "dummy") + agent = AgentRegistry.get_agent("dummy") metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED @@ -114,7 +121,7 @@ def __init__(self, **kwargs): t.execute() t._task_type = "non-exist-type" - with pytest.raises(Exception, match="Cannot find the agent for the task"): + with pytest.raises(Exception, match="Unrecognized task type non-exist-type"): t.execute() @@ -147,3 +154,19 @@ def test_is_terminal_state(): assert is_terminal_state(PERMANENT_FAILURE) assert is_terminal_state(PERMANENT_FAILURE) assert not is_terminal_state(RUNNING) + + +def test_convert_to_flyte_state(): + assert convert_to_flyte_state("FAILED") == RETRYABLE_FAILURE + assert convert_to_flyte_state("TIMEDOUT") == RETRYABLE_FAILURE + assert convert_to_flyte_state("CANCELED") == RETRYABLE_FAILURE + + assert convert_to_flyte_state("DONE") == SUCCEEDED + assert convert_to_flyte_state("SUCCEEDED") == SUCCEEDED + assert convert_to_flyte_state("SUCCESS") == SUCCEEDED + + assert convert_to_flyte_state("RUNNING") == RUNNING + + invalid_state = "INVALID_STATE" + with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"): + convert_to_flyte_state(invalid_state)