Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
Signed-off-by: noahjax <[email protected]>
  • Loading branch information
noahjax committed Mar 21, 2024
1 parent 8188db1 commit 2f790da
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def from_flyte_idl(cls, pb2_object):
:rtype: TaskExecutionMetadata
"""
return cls(
task_execution_id=_identifier.TaskExecutionIdentifier.from_flyte_idl(pb2_object.id),
task_execution_id=_identifier.TaskExecutionIdentifier.from_flyte_idl(pb2_object.task_execution_id),
namespace=pb2_object.namespace,
labels={k: v for k, v in pb2_object.labels.items()} if pb2_object.labels is not None else None,
annotations={k: v for k, v in pb2_object.annotations.items()}
Expand Down
29 changes: 20 additions & 9 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
ListAgentsResponse,
TaskCategory,
)
from flyteidl.core.identifier_pb2 import ResourceType
from flyteidl.core.execution_pb2 import TaskExecution, TaskLog
from flyteidl.core.identifier_pb2 import ResourceType

from flytekit import PythonFunctionTask, task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
Expand All @@ -38,7 +38,12 @@
)
from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret
from flytekit.models import literals
from flytekit.models.core.identifier import Identifier, NodeExecutionIdentifier, TaskExecutionIdentifier, WorkflowExecutionIdentifier
from flytekit.models.core.identifier import (
Identifier,
NodeExecutionIdentifier,
TaskExecutionIdentifier,
WorkflowExecutionIdentifier,
)
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -179,20 +184,26 @@ def __init__(self, **kwargs):
t.execute()


@pytest.mark.parametrize("agent", [DummyAgent(), AsyncDummyAgent()], ids=["sync", "async"])
@pytest.mark.parametrize(
"agent,consume_metadata", [(DummyAgent(), False), (AsyncDummyAgent(), True)], ids=["sync", "async"]
)
@pytest.mark.asyncio
async def test_async_agent_service(agent):
async def test_async_agent_service(agent, consume_metadata):
AgentRegistry.register(agent, override=True)
service = AsyncAgentService()
ctx = MagicMock(spec=grpc.ServicerContext)

inputs_proto = task_inputs.to_flyte_idl()
output_prefix = "/tmp"
metadata_bytes = DummyMetadata(
job_id=dummy_id,
output_path=f"{output_prefix}/{dummy_id}",
task_name=task_execution_metadata.task_execution_id.task_id.name,
).encode()
metadata_bytes = (
DummyMetadata(
job_id=dummy_id,
output_path=f"{output_prefix}/{dummy_id}",
task_name=task_execution_metadata.task_execution_id.task_id.name,
).encode()
if consume_metadata
else DummyMetadata(job_id=dummy_id).encode()
)

tmp = get_task_template(agent.task_category.name).to_flyte_idl()
task_category = TaskCategory(name=agent.task_category.name, version=0)
Expand Down

0 comments on commit 2f790da

Please sign in to comment.