Skip to content

Commit

Permalink
wip: Set/Get step implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 3, 2024
1 parent cde3f73 commit cb8ce3e
Show file tree
Hide file tree
Showing 22 changed files with 267 additions and 133 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .prompt_step import prompt_step
from .raise_complete_async import raise_complete_async
from .return_step import return_step
from .set_value_step import set_value_step
from .switch_step import switch_step
from .tool_call_step import tool_call_step
from .transition_step import transition_step
Expand Down
30 changes: 0 additions & 30 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,13 @@
from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import (
Content,
ContentModel,
)
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.utils.template import render_template


def _content_to_dict(
content: str | list[str] | list[Content | ContentModel], role: str
) -> str | list[dict]:
if isinstance(content, str):
return content

result = []
for s in content:
if isinstance(s, str):
result.append({"content": {"type": "text", "text": s, "role": role}})
elif isinstance(s, Content):
result.append({"content": {"type": s.type, "text": s.text, "role": role}})
elif isinstance(s, ContentModel):
result.append(
{
"content": {
"type": s.type,
"image_url": {"url": s.image_url.url},
"role": role,
}
}
)

return result


@activity.defn
@beartype
async def prompt_step(context: StepContext) -> StepOutcome:
Expand Down
37 changes: 37 additions & 0 deletions agents-api/agents_api/activities/task_steps/set_value_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


@beartype
async def set_value_step(
context: StepContext,
additional_values: dict[str, Any] = {},
override_expr: dict[str, str] | None = None,
) -> StepOutcome:
try:
expr = override_expr if override_expr is not None else context.current_step.set

values = context.model_dump() | additional_values
output = simple_eval_dict(expr, values)
result = StepOutcome(output=output)

return result

except BaseException as e:
activity.logger.error(f"Error in set_value_step: {e}")
return StepOutcome(error=str(e) or repr(e))


# Note: This is here just for clarity. We could have just imported set_value_step directly
# They do the same thing, so we dont need to mock the set_value_step function
mock_set_value_step = set_value_step

set_value_step = activity.defn(name="set_value_step")(
set_value_step if not testing else mock_set_value_step
)
16 changes: 1 addition & 15 deletions agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,20 +544,6 @@ class SearchStep(BaseModel):
"""


class SetKey(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
key: str
"""
The key to set
"""
value: str
"""
The value to set
"""


class SetStep(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
Expand All @@ -566,7 +552,7 @@ class SetStep(BaseModel):
"""
The kind of step
"""
set: SetKey
set: dict[str, str]
"""
The value to set
"""
Expand Down
55 changes: 47 additions & 8 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
IfElseWorkflowStep: task_steps.if_else_step,
ForeachStep: task_steps.for_each_step,
MapReduceStep: task_steps.map_reduce_step,
SetStep: task_steps.set_value_step,
}

# TODO: Avoid local activities for now (currently experimental)
Expand Down Expand Up @@ -143,6 +144,27 @@ async def transition(

@workflow.defn
class TaskExecutionWorkflow:
user_state: dict[str, Any] = {}

def __init__(self) -> None:
self.user_state = {}

@workflow.query
def get_user_state(self) -> dict[str, Any]:
return self.user_state

@workflow.query
def get_user_state_by_key(self, key: str) -> Any:
return self.user_state.get(key)

@workflow.signal
def set_user_state(self, key: str, value: Any) -> None:
self.user_state[key] = value

@workflow.signal
def update_user_state(self, values: dict[str, Any]) -> None:
self.user_state.update(values)

@workflow.run
async def run(
self,
Expand Down Expand Up @@ -485,15 +507,32 @@ async def run(
workflow.logger.debug("Prompt step: Received response")
state = PartialTransition(output=response)

case GetStep(), _:
# FIXME: Implement GetStep
workflow.logger.error("GetStep not yet implemented")
raise ApplicationError("Not implemented")
# FIXME: This is not working as expected
case SetStep(), StepOutcome(output=evaluated_output):
workflow.logger.info("Set step: Updating user state")
self.update_user_state(evaluated_output)

case SetStep(), _:
# FIXME: Implement SetStep
workflow.logger.error("SetStep not yet implemented")
raise ApplicationError("Not implemented")
print("-" * 100)
print("user_state", self.user_state)
print()
print("-" * 100)
print()
print("evaluated_output", evaluated_output)
print("-" * 100)

# Pass along the previous output unchanged
state = PartialTransition(output=context.current_input)

case GetStep(get=key), _:
workflow.logger.info(f"Get step: Fetching '{key}' from user state")
value = self.get_user_state_by_key(key)
workflow.logger.debug(f"Retrieved value: {value}")

print("-" * 100)
print("user_state", self.user_state)
print("-" * 100)

state = PartialTransition(output=value)

case EmbedStep(), _:
# FIXME: Implement EmbedStep
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# /usr/bin/env python3

MIGRATION_ID = "make_transition_output_optional"
CREATED_AT = 1725323734.591567


def run(client, queries):
joiner = "}\n\n{"

query = joiner.join(queries)
query = f"{{\n{query}\n}}"
client.run(query)


make_transition_output_optional_query = dict(
up="""
?[
execution_id,
transition_id,
output,
type,
current,
next,
task_token,
metadata,
created_at,
updated_at,
] :=
*transitions {
execution_id,
transition_id,
output,
type,
current,
next,
task_token,
metadata,
created_at,
updated_at,
}
:replace transitions {
execution_id: Uuid,
transition_id: Uuid,
=>
type: String,
current: (String, Int),
next: (String, Int)?,
output: Json?, # <--- this is the only change; output is now optional
task_token: String? default null,
metadata: Json default {},
created_at: Float default now(),
updated_at: Float default now(),
}
""",
down="""
?[
execution_id,
transition_id,
output,
type,
current,
next,
task_token,
metadata,
created_at,
updated_at,
] :=
*transitions {
execution_id,
transition_id,
output,
type,
current,
next,
task_token,
metadata,
created_at,
updated_at,
}
:replace transitions {
execution_id: Uuid,
transition_id: Uuid,
=>
type: String,
current: (String, Int),
next: (String, Int)?,
output: Json,
task_token: String? default null,
metadata: Json default {},
created_at: Float default now(),
updated_at: Float default now(),
}
""",
)


queries = [
make_transition_output_optional_query,
]


def up(client):
run(client, [q["up"] for q in queries])


def down(client):
run(client, [q["down"] for q in reversed(queries)])
6 changes: 3 additions & 3 deletions agents-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_agent(cozo_client=cozo_client, developer_id=test_developer_id):
agent = create_agent(
developer_id=developer_id,
data=CreateAgentRequest(
model="gpt-4o",
model="gpt-4o-mini",
name="test agent",
about="test agent about",
metadata={"test": "test"},
Expand Down
10 changes: 5 additions & 5 deletions agents-api/tests/test_agent_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _(client=cozo_client, developer_id=test_developer_id):
data=CreateAgentRequest(
name="test agent",
about="test agent about",
model="gpt-4o",
model="gpt-4o-mini",
),
client=client,
)
Expand All @@ -41,7 +41,7 @@ def _(client=cozo_client, developer_id=test_developer_id):
data=CreateAgentRequest(
name="test agent",
about="test agent about",
model="gpt-4o",
model="gpt-4o-mini",
instructions=["test instruction"],
),
client=client,
Expand All @@ -56,7 +56,7 @@ def _(client=cozo_client, developer_id=test_developer_id):
data=CreateOrUpdateAgentRequest(
name="test agent",
about="test agent about",
model="gpt-4o",
model="gpt-4o-mini",
instructions=["test instruction"],
),
client=client,
Expand Down Expand Up @@ -86,7 +86,7 @@ def _(client=cozo_client, developer_id=test_developer_id):
data=CreateAgentRequest(
name="test agent",
about="test agent about",
model="gpt-4o",
model="gpt-4o-mini",
instructions=["test instruction"],
),
client=client,
Expand All @@ -108,7 +108,7 @@ def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent):
data=UpdateAgentRequest(
name="updated agent",
about="updated agent about",
model="gpt-4o",
model="gpt-4o-mini",
default_settings={"temperature": 1.0},
metadata={"hello": "world"},
),
Expand Down
Loading

0 comments on commit cb8ce3e

Please sign in to comment.