Skip to content

Commit

Permalink
chore: fix @step usage in the code (#15588)
Browse files Browse the repository at this point in the history
  • Loading branch information
masci authored Aug 23, 2024
1 parent 849d9e5 commit 93fc38b
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 33 deletions.
2 changes: 1 addition & 1 deletion llama-index-core/llama_index/core/workflow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def decorator(func: Callable) -> Callable:
# If this is a free function, call add_step() explicitly.
if is_free_function(func.__qualname__):
if workflow is None:
msg = f"To decorate {func.__name__} please pass a workflow class to the @step() decorator."
msg = f"To decorate {func.__name__} please pass a workflow class to the @step decorator."
raise WorkflowValidationError(msg)
workflow.add_step(func)

Expand Down
6 changes: 3 additions & 3 deletions llama-index-core/llama_index/core/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _start(self, stepwise: bool = False) -> WorkflowSession:
step_func, "__step_config", None
)
if not step_config:
raise ValueError(f"Step {name} is missing `@step()` decorator.")
raise ValueError(f"Step {name} is missing `@step` decorator.")

async def _task(
name: str,
Expand Down Expand Up @@ -295,7 +295,7 @@ def is_done(self) -> bool:
"""Checks if the workflow is done."""
return self._step_session is None

@step()
@step
async def _done(self, ctx: Context, ev: StopEvent) -> None:
"""Tears down the whole workflow and stop execution."""
ctx.session._retval = ev.result or None
Expand All @@ -316,7 +316,7 @@ def _validate(self) -> None:
step_func, "__step_config", None
)
if not step_config:
raise ValueError(f"Step {name} is missing `@step()` decorator.")
raise ValueError(f"Step {name} is missing `@step` decorator.")

for event_type in step_config.accepted_events:
consumed_events.add(event_type)
Expand Down
6 changes: 3 additions & 3 deletions llama-index-core/tests/workflow/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ class LastEvent(Event):


class DummyWorkflow(Workflow):
@step()
@step
async def start_step(self, ev: StartEvent) -> OneTestEvent:
return OneTestEvent()

@step()
@step
async def middle_step(self, ev: OneTestEvent) -> LastEvent:
return LastEvent()

@step()
@step
async def end_step(self, ev: LastEvent) -> StopEvent:
return StopEvent(result="Workflow completed")

Expand Down
6 changes: 3 additions & 3 deletions llama-index-core/tests/workflow/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ async def test_collect_events():
ev2 = AnotherTestEvent()

class TestWorkflow(Workflow):
@step()
@step
async def step1(self, _: StartEvent) -> OneTestEvent:
return ev1

@step()
@step
async def step2(self, _: StartEvent) -> AnotherTestEvent:
return ev2

@step(pass_context=True)
@step
async def step3(
self, ctx: Context, ev: Union[OneTestEvent, AnotherTestEvent]
) -> Optional[StopEvent]:
Expand Down
6 changes: 3 additions & 3 deletions llama-index-core/tests/workflow/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TestWorkflow(Workflow):
def f1(self, ev: Event) -> Event:
return ev

@step()
@step
def f2(self, ev: Event) -> Event:
return ev

Expand Down Expand Up @@ -57,11 +57,11 @@ def test_decorate_free_function_wrong_decorator():
with pytest.raises(
WorkflowValidationError,
match=re.escape(
"To decorate f please pass a workflow class to the @step() decorator."
"To decorate f please pass a workflow class to the @step decorator."
),
):

@step()
@step
def f(ev: Event) -> Event:
return Event()

Expand Down
6 changes: 3 additions & 3 deletions llama-index-core/tests/workflow/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class ServiceWorkflow(Workflow):
"""This wokflow is only responsible to generate a number, it knows nothing about the caller."""

@step()
@step
async def generate(self, ev: StartEvent) -> StopEvent:
return StopEvent(result=42)

Expand All @@ -30,14 +30,14 @@ class DummyWorkflow(Workflow):
and it only knows it has to call `run` on that instance.
"""

@step()
@step
async def get_a_number(
self, service_workflow: ServiceWorkflow, ev: StartEvent, ctx: Context
) -> NumGenerated:
res = await service_workflow.run()
return NumGenerated(num=int(res))

@step()
@step
async def multiply(self, ev: NumGenerated) -> StopEvent:
return StopEvent(ev.num * 2)

Expand Down
4 changes: 2 additions & 2 deletions llama-index-core/tests/workflow/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ def f2(ev: OneTestEvent, foo: OneTestEvent):

def test_get_steps_from():
class Test:
@step()
@step
def start(self, start: StartEvent) -> OneTestEvent:
return OneTestEvent()

@step()
@step
def my_method(self, event: OneTestEvent) -> StopEvent:
return StopEvent()

Expand Down
30 changes: 15 additions & 15 deletions llama-index-core/tests/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def test_workflow_run_step(workflow):
@pytest.mark.asyncio()
async def test_workflow_timeout():
class SlowWorkflow(Workflow):
@step()
@step
async def slow_step(self, ev: StartEvent) -> StopEvent:
await asyncio.sleep(5.0)
return StopEvent(result="Done")
Expand All @@ -71,7 +71,7 @@ async def slow_step(self, ev: StartEvent) -> StopEvent:
@pytest.mark.asyncio()
async def test_workflow_validation():
class InvalidWorkflow(Workflow):
@step()
@step
async def invalid_step(self, ev: StartEvent) -> None:
pass

Expand All @@ -85,12 +85,12 @@ async def test_workflow_event_propagation():
events = []

class EventTrackingWorkflow(Workflow):
@step()
@step
async def step1(self, ev: StartEvent) -> OneTestEvent:
events.append("step1")
return OneTestEvent()

@step()
@step
async def step2(self, ev: OneTestEvent) -> StopEvent:
events.append("step2")
return StopEvent(result="Done")
Expand All @@ -103,11 +103,11 @@ async def step2(self, ev: OneTestEvent) -> StopEvent:
@pytest.mark.asyncio()
async def test_sync_async_steps():
class SyncAsyncWorkflow(Workflow):
@step()
@step
async def async_step(self, ev: StartEvent) -> OneTestEvent:
return OneTestEvent()

@step()
@step
def sync_step(self, ev: OneTestEvent) -> StopEvent:
return StopEvent(result="Done")

Expand All @@ -119,7 +119,7 @@ def sync_step(self, ev: OneTestEvent) -> StopEvent:
@pytest.mark.asyncio()
async def test_workflow_num_workers():
class NumWorkersWorkflow(Workflow):
@step(pass_context=True)
@step
async def original_step(
self, ctx: Context, ev: StartEvent
) -> OneTestEvent | LastEvent:
Expand All @@ -135,7 +135,7 @@ async def test_step(self, ev: OneTestEvent) -> AnotherTestEvent:
await asyncio.sleep(1.0)
return AnotherTestEvent(another_test_param=ev.test_param)

@step(pass_context=True)
@step
async def final_step(
self, ctx: Context, ev: AnotherTestEvent | LastEvent
) -> StopEvent:
Expand Down Expand Up @@ -165,16 +165,16 @@ async def final_step(
@pytest.mark.asyncio()
async def test_workflow_step_send_event():
class StepSendEventWorkflow(Workflow):
@step()
@step
async def step1(self, ctx: Context, ev: StartEvent) -> OneTestEvent:
ctx.session.send_event(OneTestEvent(), step="step2")
return None

@step()
@step
async def step2(self, ev: OneTestEvent) -> StopEvent:
return StopEvent(result="step2")

@step()
@step
async def step3(self, ev: OneTestEvent) -> StopEvent:
return StopEvent(result="step3")

Expand All @@ -190,12 +190,12 @@ async def step3(self, ev: OneTestEvent) -> StopEvent:
@pytest.mark.asyncio()
async def test_workflow_step_send_event_to_None():
class StepSendEventToNoneWorkflow(Workflow):
@step()
@step
async def step1(self, ctx: Context, ev: StartEvent) -> OneTestEvent:
ctx.session.send_event(OneTestEvent(), step=None)
return None

@step()
@step
async def step2(self, ev: OneTestEvent) -> StopEvent:
return StopEvent(result="step2")

Expand All @@ -208,7 +208,7 @@ async def step2(self, ev: OneTestEvent) -> StopEvent:
@pytest.mark.asyncio()
async def test_workflow_missing_service():
class DummyWorkflow(Workflow):
@step()
@step
async def step(self, ev: StartEvent, my_service: Workflow) -> StopEvent:
return StopEvent(result=42)

Expand All @@ -224,7 +224,7 @@ async def step(self, ev: StartEvent, my_service: Workflow) -> StopEvent:
@pytest.mark.asyncio()
async def test_workflow_multiple_runs():
class DummyWorkflow(Workflow):
@step()
@step
async def step(self, ev: StartEvent) -> StopEvent:
return StopEvent(result=ev.number * 2)

Expand Down

0 comments on commit 93fc38b

Please sign in to comment.