Skip to content

Commit

Permalink
Agent Metadata Servicer (#2012)
Browse files Browse the repository at this point in the history
Signed-off-by: Future Outlier <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
3 people authored Jan 31, 2024
1 parent c63f011 commit 7179d4c
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 18 deletions.
8 changes: 6 additions & 2 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from concurrent import futures

import rich_click as click
from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server
from flyteidl.service.agent_pb2_grpc import (
add_AgentMetadataServiceServicer_to_server,
add_AsyncAgentServiceServicer_to_server,
)
from grpc import aio


Expand Down Expand Up @@ -49,7 +52,7 @@ def agent(_: click.Context, port, worker, timeout):

async def _start_grpc_server(port: int, worker: int, timeout: int):
click.secho("Starting up the server to expose the prometheus metrics...", fg="blue")
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService

try:
from prometheus_client import start_http_server
Expand All @@ -61,6 +64,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int):
server = aio.server(futures.ThreadPoolExecutor(max_workers=worker))

add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)
add_AgentMetadataServiceServicer_to_server(AgentMetadataService(), server)

server.add_insecure_port(f"[::]:{port}")
await server.start()
Expand Down
31 changes: 25 additions & 6 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
CreateTaskResponse,
DeleteTaskRequest,
DeleteTaskResponse,
GetAgentRequest,
GetAgentResponse,
GetTaskRequest,
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
)
from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer
from flyteidl.service.agent_pb2_grpc import AgentMetadataServiceServicer, AsyncAgentServiceServicer
from prometheus_client import Counter, Summary

from flytekit import logger
Expand All @@ -26,18 +30,20 @@

# Prometheus metrics exported by the agent service.
# Metric names follow the Prometheus naming convention: https://prometheus.io/docs/practices/naming/

# Successful requests, labelled by task type and operation.
request_success_count = Counter(
    f"{metric_prefix}requests_success_total",
    "Total number of successful requests",
    ["task_type", "operation"],
)

# Failed requests; carries an extra "error_code" label on top of the success labels.
request_failure_count = Counter(
    f"{metric_prefix}requests_failure_total",
    "Total number of failed requests",
    ["task_type", "operation", "error_code"],
)

# Time spent processing a request, labelled by task type and operation.
request_latency = Summary(
    f"{metric_prefix}request_latency_seconds",
    "Time spent processing agent request",
    ["task_type", "operation"],
)

# Size of the input literal, labelled by task type only.
input_literal_size = Summary(
    f"{metric_prefix}input_literal_bytes",
    "Size of input literal",
    ["task_type"],
)


Expand Down Expand Up @@ -96,8 +102,12 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon
logger.info(f"{tmp.type} agent start creating the job")
if agent.asynchronous:
return await agent.async_create(
context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp
context=context,
inputs=inputs,
output_prefix=request.output_prefix,
task_template=tmp,
)

return await asyncio.get_running_loop().run_in_executor(
None,
agent.create,
Expand All @@ -122,3 +132,12 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon
if agent.asynchronous:
return await agent.async_delete(context=context, resource_meta=request.resource_meta)
return await asyncio.get_running_loop().run_in_executor(None, agent.delete, context, request.resource_meta)


class AgentMetadataService(AgentMetadataServiceServicer):
    """gRPC servicer that serves the agent metadata held by ``AgentRegistry``."""

    async def GetAgent(self, request: GetAgentRequest, context: grpc.ServicerContext) -> GetAgentResponse:
        """Return the metadata of the agent named ``request.name``.

        The lookup goes through ``AgentRegistry.get_agent_metadata`` rather than
        indexing the registry's private dict, so an unknown name raises
        ``FlyteAgentNotFound`` with a descriptive message instead of a bare
        ``KeyError``.
        """
        return GetAgentResponse(agent=AgentRegistry.get_agent_metadata(request.name))

    async def ListAgents(self, request: ListAgentsRequest, context: grpc.ServicerContext) -> ListAgentsResponse:
        """Return the metadata of every agent currently present in the registry."""
        # The registry exposes no public "list all" accessor, so read the mapping directly.
        return ListAgentsResponse(agents=list(AgentRegistry._METADATA.values()))
27 changes: 23 additions & 4 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
Agent,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
Expand Down Expand Up @@ -45,7 +46,9 @@ class AgentBase(ABC):
will look up the agent based on the task type. Every task type can only have one agent.
"""

def __init__(self, task_type: str, asynchronous=True):
name = "Base Agent"

def __init__(self, task_type: str, asynchronous: bool = True):
self._task_type = task_type
self._asynchronous = asynchronous

Expand Down Expand Up @@ -113,25 +116,41 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes

class AgentRegistry(object):
    """
    This is the registry for all agents.
    The agent service looks up the agent based on the task type.
    The agent metadata service looks up the agent metadata based on the agent name.
    """

    # Maps a task type to the agent instance registered for it.
    _REGISTRY: typing.Dict[str, AgentBase] = {}
    # Maps an agent name to its metadata message, which lists the task types it supports.
    _METADATA: typing.Dict[str, Agent] = {}

    @staticmethod
    def register(agent: AgentBase):
        """Register ``agent`` under its task type and record it in the metadata map.

        Agents that share a ``name`` share one metadata entry; each additional
        task type is appended to that entry's ``supported_task_types``.

        :raises ValueError: if an agent is already registered for ``agent.task_type``.
        """
        if agent.task_type in AgentRegistry._REGISTRY:
            raise ValueError(f"Duplicate agent for task type {agent.task_type}")
        AgentRegistry._REGISTRY[agent.task_type] = agent

        if agent.name in AgentRegistry._METADATA:
            agent_metadata = AgentRegistry._METADATA[agent.name]
            agent_metadata.supported_task_types.append(agent.task_type)
        else:
            agent_metadata = Agent(name=agent.name, supported_task_types=[agent.task_type])
            AgentRegistry._METADATA[agent.name] = agent_metadata

        logger.info(f"Registering an agent for task type: {agent.task_type}, name: {agent.name}")

    @staticmethod
    def get_agent(task_type: str) -> typing.Optional[AgentBase]:
        """Return the agent registered for ``task_type``.

        :raises FlyteAgentNotFound: if no agent is registered for ``task_type``
            (the method raises rather than returning ``None``).
        """
        if task_type not in AgentRegistry._REGISTRY:
            raise FlyteAgentNotFound(f"Cannot find agent for task type: {task_type}.")
        return AgentRegistry._REGISTRY[task_type]

    @staticmethod
    def get_agent_metadata(name: str) -> Agent:
        """Return the metadata recorded for the agent called ``name``.

        :raises FlyteAgentNotFound: if no agent with that name has been registered.
        """
        if name not in AgentRegistry._METADATA:
            raise FlyteAgentNotFound(f"Cannot find agent for name: {name}.")
        return AgentRegistry._METADATA[name]


def convert_to_flyte_state(state: str) -> State:
"""
Expand Down
2 changes: 2 additions & 0 deletions flytekit/sensor/sensor_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@


class SensorEngine(AgentBase):
name = "Sensor"

def __init__(self):
super().__init__(task_type="sensor", asynchronous=True)

Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-airflow/flytekitplugins/airflow/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class AirflowAgent(AgentBase):
In this case, those operators will be converted to AirflowContainerTask and executed in the pod.
"""

name = "Airflow Agent"

def __init__(self):
super().__init__(task_type="airflow", asynchronous=True)

Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class Metadata:


class BigQueryAgent(AgentBase):
name = "Bigquery Agent"

def __init__(self):
super().__init__(task_type="bigquery_query_job_task", asynchronous=False)

Expand Down
4 changes: 3 additions & 1 deletion plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ class Metadata:


class MMCloudAgent(AgentBase):
name = "MMCloud Agent"

def __init__(self):
super().__init__(task_type="mmcloud_task")
super().__init__(task_type="mmcloud_task", asynchronous=True)
self._response_format = ["--format", "json"]

async def async_login(self):
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-papermill/dev-requirements.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
flyteidl>=1.10.0
flyteidl>=1.10.7b0
-e file:../../.#egg=flytekitplugins-pod&subdirectory=plugins/flytekit-k8s-pod
-e file:../../.#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark
-e file:../../.#egg=flytekitplugins-awsbatch&subdirectory=plugins/flytekit-aws-batch
6 changes: 4 additions & 2 deletions plugins/flytekit-papermill/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
flyteidl==1.10.6
flyteidl==1.10.7b0
# via
# -r dev-requirements.in
# flytekit
Expand Down Expand Up @@ -235,7 +235,9 @@ packaging==23.2
# docker
# marshmallow
pandas==1.5.3
# via flytekit
# via
# flytekit
# flytekitplugins-spark
portalocker==2.8.2
# via msal-extensions
protobuf==4.24.4
Expand Down
4 changes: 3 additions & 1 deletion plugins/flytekit-spark/flytekitplugins/spark/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ class Metadata:


class DatabricksAgent(AgentBase):
name = "Databricks Agent"

def __init__(self):
super().__init__(task_type="spark")
super().__init__(task_type="spark", asynchronous=True)

async def async_create(
self,
Expand Down
26 changes: 25 additions & 1 deletion tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
DeleteTaskResponse,
GetTaskRequest,
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
Resource,
)

from flytekit import PythonFunctionTask, task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.extend.backend.agent_service import AsyncAgentService
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService
from flytekit.extend.backend.base_agent import (
AgentBase,
AgentRegistry,
Expand All @@ -49,6 +51,8 @@ class Metadata:


class DummyAgent(AgentBase):
name = "Dummy Agent"

def __init__(self):
super().__init__(task_type="dummy", asynchronous=False)

Expand All @@ -71,6 +75,8 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT


class AsyncDummyAgent(AgentBase):
name = "Async Dummy Agent"

def __init__(self):
super().__init__(task_type="async_dummy", asynchronous=True)

Expand All @@ -91,6 +97,8 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes


class SyncDummyAgent(AgentBase):
name = "Sync Dummy Agent"

def __init__(self):
super().__init__(task_type="sync_dummy", asynchronous=True)

Expand Down Expand Up @@ -161,6 +169,10 @@ def __init__(self, **kwargs):
with pytest.raises(Exception, match="Cannot find agent for task type: non-exist-type."):
t.execute()

agent_metadata = AgentRegistry.get_agent_metadata("Dummy Agent")
assert agent_metadata.name == "Dummy Agent"
assert agent_metadata.supported_task_types == ["dummy"]


@pytest.mark.asyncio
async def test_async_dummy_agent():
Expand All @@ -175,6 +187,10 @@ async def test_async_dummy_agent():
res = await agent.async_delete(ctx, metadata_bytes)
assert res == DeleteTaskResponse()

agent_metadata = AgentRegistry.get_agent_metadata("Async Dummy Agent")
assert agent_metadata.name == "Async Dummy Agent"
assert agent_metadata.supported_task_types == ["async_dummy"]


@pytest.mark.asyncio
async def test_sync_dummy_agent():
Expand All @@ -185,6 +201,10 @@ async def test_sync_dummy_agent():
assert res.resource.state == SUCCEEDED
assert res.resource.outputs == LiteralMap({}).to_flyte_idl()

agent_metadata = AgentRegistry.get_agent_metadata("Sync Dummy Agent")
assert agent_metadata.name == "Sync Dummy Agent"
assert agent_metadata.supported_task_types == ["sync_dummy"]


@pytest.mark.asyncio
async def run_agent_server():
Expand Down Expand Up @@ -223,6 +243,10 @@ async def run_agent_server():
res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx)
assert res is None

metadata_service = AgentMetadataService()
res = await metadata_service.ListAgent(ListAgentsRequest(), ctx)
assert isinstance(res, ListAgentsResponse)


def test_agent_server():
loop.run_in_executor(None, run_agent_server)
Expand Down

0 comments on commit 7179d4c

Please sign in to comment.