Skip to content

Commit

Permalink
feat(agents-api): Set/get steps based on workflow state
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 3, 2024
1 parent cb8ce3e commit fdc7bfc
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 25 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ async def run_task_execution_workflow(
task_queue=temporal_task_queue,
id=str(job_id),
run_timeout=timedelta(days=31),
# TODO: Should add search_attributes for queryability
)
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
simple_jinja_regex = re.compile(r"{{|{%.+}}|%}", re.DOTALL)


# FIXME: This does not work for some reason
# TODO: This does not work for some reason
def is_simple_jinja(template_string: str) -> bool:
return simple_jinja_regex.search(template_string) is None

Expand Down
68 changes: 45 additions & 23 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,24 @@ async def transition(
raise ApplicationError(f"Error in transition: {e}") from e


async def continue_as_child(
execution_input: ExecutionInput,
start: TransitionTarget,
previous_inputs: list[Any],
user_state: dict[str, Any] = {},
) -> Any:
return await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=[
execution_input,
start,
previous_inputs,
user_state,
],
# TODO: Should add search_attributes for queryability
)


@workflow.defn
class TaskExecutionWorkflow:
user_state: dict[str, Any] = {}
Expand Down Expand Up @@ -171,7 +189,11 @@ async def run(
execution_input: ExecutionInput,
start: TransitionTarget = TransitionTarget(workflow="main", step=0),
previous_inputs: list[Any] = [],
user_state: dict[str, Any] = {},
) -> Any:
# Set the initial user state
self.user_state = user_state

workflow.logger.info(
f"TaskExecutionWorkflow for task {execution_input.task.id}"
f" [LOC {start.workflow}.{start.step}]"
Expand Down Expand Up @@ -297,9 +319,9 @@ async def run(
]

# Execute the chosen branch and come back here
result = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=case_args,
result = await continue_as_child(
*case_args,
user_state=self.user_state,
)

state = PartialTransition(output=result)
Expand Down Expand Up @@ -342,9 +364,9 @@ async def run(
]

# Execute the chosen branch and come back here
result = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=if_else_args,
result = await continue_as_child(
*if_else_args,
user_state=self.user_state,
)

state = PartialTransition(output=result)
Expand Down Expand Up @@ -376,9 +398,9 @@ async def run(
]

# Execute the chosen branch and come back here
result = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=foreach_args,
result = await continue_as_child(
*foreach_args,
user_state=self.user_state,
)

state = PartialTransition(output=result)
Expand Down Expand Up @@ -417,9 +439,9 @@ async def run(

# TODO: We should parallelize this
# Execute the chosen branch and come back here
output = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=map_reduce_args,
output = await continue_as_child(
*map_reduce_args,
user_state=self.user_state,
)

# Reduce the result with the initial value
Expand Down Expand Up @@ -483,9 +505,11 @@ async def run(
next=yield_next_target,
)

result = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=[execution_input, yield_next_target, [output]],
result = await continue_as_child(
execution_input=execution_input,
start=yield_next_target,
previous_inputs=[output],
user_state=self.user_state,
)

state = PartialTransition(output=result)
Expand Down Expand Up @@ -555,7 +579,7 @@ async def run(
raise ApplicationError("Not implemented")

case _:
# FIXME: Add steps that are not yet supported
# TODO: Add steps that are not yet supported
workflow.logger.error(
f"Unhandled step type: {type(context.current_step).__name__}"
)
Expand Down Expand Up @@ -585,11 +609,9 @@ async def run(
)

# TODO: Should use a continue_as_new workflow if history grows too large
return await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=[
execution_input,
final_state.next,
previous_inputs + [final_state.output],
],
return await continue_as_child(
execution_input=execution_input,
start=final_state.next,
previous_inputs=previous_inputs + [final_state.output],
user_state=self.user_state,
)
2 changes: 1 addition & 1 deletion agents-api/tests/test_entry_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
content="test entry content",
)

# FIXME: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep
# TODO: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep
time.sleep(1)

create_entries(
Expand Down

0 comments on commit fdc7bfc

Please sign in to comment.