-
Notifications
You must be signed in to change notification settings - Fork 301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Chatgpt Agent #1822
Chatgpt Agent #1822
Changes from all commits
540eb5d
984d44a
3d936fc
931533a
8365c53
25d6a5d
970bf3b
b5a74d8
6087c5f
f8680a1
b225310
09bc23a
bd5dbd7
aeb5ea1
0717e2c
54d2ddf
541edc6
7c9dcbc
18a9e5d
75573ab
61e0a76
8de3fa8
be7d22d
bca202b
c0139db
cf9ff07
3b13b48
24df7b3
563ca22
5fa4f18
5f1183b
203d2d4
9366dfb
66f5e60
3249963
558c6bd
9f5dd0a
c91ee4c
4adf029
402e1a9
2db4de7
420cbf5
1208268
a64f6ee
815dde4
9f3072e
713db18
0ecebdf
a058734
719ae32
5823cb1
c6058fc
e6482ae
c6343ec
ee0b829
a9b16c8
d2ebf6c
a660fc1
a3b0ecc
bc9fb27
6de876b
f785a19
ff44060
8892dee
1810a7b
8d5bb61
fd8dd5d
36b315e
e6b0ba0
dc87c81
8ea1a9d
42b8d90
14a2698
67582fd
c273689
232c80c
f842fdd
835c48a
d1e99be
b0670d3
6e8f30d
104ba6f
ec35714
c463992
9758ce6
868f81b
0068137
6d15fdc
b56c845
510711e
9d1e5bd
b09a627
5a0348b
ea91dc7
7786aca
ca34c61
1d9535d
68fd527
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import collections | ||
import inspect | ||
from abc import abstractmethod | ||
from typing import Any, Dict, Optional, TypeVar | ||
|
||
from flyteidl.admin.agent_pb2 import CreateTaskResponse | ||
from typing_extensions import get_type_hints | ||
|
||
from flytekit.configuration import SerializationSettings | ||
from flytekit.core.base_task import PythonTask | ||
from flytekit.core.interface import Interface | ||
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin | ||
|
||
T = TypeVar("T") | ||
TASK_MODULE = "task_module" | ||
TASK_NAME = "task_name" | ||
TASK_CONFIG = "task_config" | ||
TASK_TYPE = "api_task" | ||
|
||
|
||
class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): | ||
""" | ||
Base class for all external API tasks. External API tasks are tasks that are designed to run until they receive a | ||
response from an external service. When the response is received, the task will complete. External API tasks are | ||
designed to be run by the flyte agent. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
config: Optional[T] = None, | ||
task_type: str = TASK_TYPE, | ||
return_type: Optional[Any] = None, | ||
**kwargs, | ||
): | ||
type_hints = get_type_hints(self.do, include_extras=True) | ||
signature = inspect.signature(self.do) | ||
inputs = collections.OrderedDict() | ||
outputs = collections.OrderedDict({"o0": return_type}) if return_type else collections.OrderedDict() | ||
|
||
for k, _ in signature.parameters.items(): # type: ignore | ||
annotation = type_hints.get(k, None) | ||
inputs[k] = annotation | ||
|
||
super().__init__( | ||
task_type=task_type, | ||
name=name, | ||
task_config=config, | ||
interface=Interface(inputs=inputs, outputs=outputs), | ||
**kwargs, | ||
) | ||
|
||
self._task_config = config | ||
|
||
@abstractmethod | ||
async def do(self, **kwargs) -> CreateTaskResponse: | ||
""" | ||
Initiate an HTTP request to an external service such as OpenAI or Vertex AI and retrieve the response. | ||
""" | ||
raise NotImplementedError | ||
|
||
def get_custom(self, settings: SerializationSettings = None) -> Dict[str, Any]: | ||
cfg = { | ||
TASK_MODULE: type(self).__module__, | ||
TASK_NAME: type(self).__name__, | ||
} | ||
|
||
if self._task_config is not None: | ||
cfg[TASK_CONFIG] = self._task_config | ||
|
||
return cfg |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import importlib | ||
import typing | ||
from dataclasses import dataclass | ||
from typing import final | ||
|
||
import grpc | ||
from flyteidl.admin.agent_pb2 import CreateTaskResponse | ||
|
||
from flytekit import FlyteContextManager | ||
from flytekit.core.external_api_task import TASK_CONFIG, TASK_MODULE, TASK_NAME, TASK_TYPE | ||
from flytekit.core.type_engine import TypeEngine | ||
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry | ||
from flytekit.models.literals import LiteralMap | ||
from flytekit.models.task import TaskTemplate | ||
|
||
T = typing.TypeVar("T") | ||
|
||
|
||
@dataclass | ||
class IOContext: | ||
inputs: LiteralMap | ||
output_prefix: str | ||
|
||
|
||
class SyncAgentBase(AgentBase): | ||
""" | ||
SyncAgentBase is an agent responsible for syncrhounous tasks, which are fast and quick. | ||
|
||
This class is meant to be subclassed when implementing plugins that require | ||
an external API to perform the task execution. It provides a routing mechanism | ||
to direct the task to the appropriate handler based on the task's specifications. | ||
""" | ||
|
||
def __init__(self): | ||
super().__init__(task_type=TASK_TYPE, asynchronous=True) | ||
|
||
@final | ||
async def async_create( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i would change this to SyncAgentBase as follows class SyncAgentBase(AgentBase):
@final
async def async_create( self,
context: grpc.ServicerContext, io_ctx: IOContext, task_template: TaskTemplate,
) -> CreateTaskResponse:
do(context, output_prefix, task_template, inputs)
def async do(context, io_ctx, task_template):
python_interface_inputs = {
name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items()
}
ctx = FlyteContextManager.current_context()
native_inputs = {}
if inputs:
native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)
meta = task_template.custom
task_module = importlib.import_module(name=meta[TASK_MODULE])
task_def = getattr(task_module, meta[TASK_NAME])
config = jsonpickle.decode(meta[TASK_CONFIG_PKL]) if meta.get(TASK_CONFIG_PKL) else None
return task_def(TASK_TYPE, config=config).execute(**native_inputs)
async def execute(**kwargs):
raise NotImplementedError() Checkout the final decorator - https://docs.python.org/3.8/library/typing.html#typing.final Independently, lets are make the signature of all the get/create/delete methods simpler. Think, if we have to refactor the signature in the future how can you do it easily? @dataclass
class IOContext():
inputs: LiteralMap
output_prefix: str OR change the signature to have def __init__(self, *args, **kwargs):
...
async def create(ctx, inputs, outputs, task_template, **kwargs):
... |
||
self, | ||
context: grpc.ServicerContext, | ||
output_prefix: str, | ||
task_template: TaskTemplate, | ||
inputs: typing.Optional[LiteralMap] = None, | ||
) -> CreateTaskResponse: | ||
return await self.do(context, output_prefix, task_template, inputs) | ||
|
||
async def do( | ||
self, | ||
context: grpc.ServicerContext, | ||
output_prefix: str, | ||
task_template: TaskTemplate, | ||
inputs: typing.Optional[LiteralMap] = None, | ||
) -> CreateTaskResponse: | ||
python_interface_inputs = { | ||
name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() | ||
} | ||
ctx = FlyteContextManager.current_context() | ||
|
||
native_inputs = {} | ||
if inputs: | ||
native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) | ||
|
||
meta = task_template.custom | ||
|
||
task_module = importlib.import_module(name=meta[TASK_MODULE]) | ||
task_def = getattr(task_module, meta[TASK_NAME]) | ||
config = meta[TASK_CONFIG] if meta.get(TASK_CONFIG) else None | ||
return await task_def(TASK_TYPE, config=config).do(**native_inputs) | ||
|
||
|
||
AgentRegistry.register(SyncAgentBase()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pytest-asyncio |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# | ||
# This file is autogenerated by pip-compile with Python 3.9 | ||
# by the following command: | ||
# | ||
# pip-compile dev-requirements.in | ||
# | ||
exceptiongroup==1.1.3 | ||
# via pytest | ||
iniconfig==2.0.0 | ||
# via pytest | ||
packaging==23.2 | ||
# via pytest | ||
pluggy==1.3.0 | ||
# via pytest | ||
pytest==7.4.2 | ||
# via pytest-asyncio | ||
pytest-asyncio==0.21.1 | ||
# via -r dev-requirements.in | ||
tomli==2.0.1 | ||
# via pytest |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we have to use pickle?
cc @pingsutw do you know? is this to maintain simply python coversion? This is potentially dangerous as it may break across python versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pickle is faster, we've talked about this before and Kevin told me that it might be ok.
Will there be multiple python versions usecases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we cannot control backend and flytekit version
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you are right, I fix it!
Thank you very much for your advice!