Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Signal use #1398

Merged
merged 16 commits into from
Jan 5, 2023
17 changes: 17 additions & 0 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import grpc
import requests as _requests
from flyteidl.admin.project_pb2 import ProjectListRequest
from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse
from flyteidl.service import admin_pb2_grpc as _admin_service
from flyteidl.service import auth_pb2
from flyteidl.service import auth_pb2_grpc as auth_service
from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2
from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service
from flyteidl.service import signal_pb2_grpc as signal_service
from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub
from google.protobuf.json_format import MessageToJson as _MessageToJson

Expand Down Expand Up @@ -145,6 +147,7 @@ def __init__(self, cfg: PlatformConfig, **kwargs):
)
self._stub = _admin_service.AdminServiceStub(self._channel)
self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel)
self._signal = signal_service.SignalServiceStub(self._channel)
try:
resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest())
self._public_client_config = resp
Expand Down Expand Up @@ -406,6 +409,20 @@ def get_task(self, get_object_request):
"""
return self._stub.GetTask(get_object_request, metadata=self._metadata)

@_handle_rpc_error(retry=True)
def set_signal(self, signal_set_request: SignalSetRequest) -> SignalSetResponse:
"""
This sets a signal
"""
return self._signal.SetSignal(signal_set_request, metadata=self._metadata)

@_handle_rpc_error(retry=True)
def list_signals(self, signal_list_request: SignalListRequest) -> SignalList:
"""
This lists signals
"""
return self._signal.ListSignals(signal_list_request, metadata=self._metadata)

####################################################################################################################
#
# Workflow Endpoints
Expand Down
29 changes: 20 additions & 9 deletions flytekit/remote/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@ def promote_from_model(
return cls(new_if_else_block), converted_sub_workflows


class FlyteGateNode(_workflow_model.GateNode):
@classmethod
def promote_from_model(cls, model: _workflow_model.GateNode):
return cls(model.signal, model.sleep, model.approve)


class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node):
"""A class encapsulating a remote Flyte node."""

Expand All @@ -343,22 +349,23 @@ def __init__(
upstream_nodes,
bindings,
metadata,
task_node: FlyteTaskNode = None,
workflow_node: FlyteWorkflowNode = None,
branch_node: FlyteBranchNode = None,
task_node: Optional[FlyteTaskNode] = None,
workflow_node: Optional[FlyteWorkflowNode] = None,
branch_node: Optional[FlyteBranchNode] = None,
gate_node: Optional[FlyteGateNode] = None,
):
if not task_node and not workflow_node and not branch_node:
if not task_node and not workflow_node and not branch_node and not gate_node:
raise _user_exceptions.FlyteAssertion(
"An Flyte node must have one of task|workflow|branch entity specified at once"
"An Flyte node must have one of task|workflow|branch|gate entity specified at once"
)
# todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from
# the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it.
# TODO: Revisit flyte_branch_node and flyte_gate_node, should they be another type like Condition instead
# of a node?
if task_node:
self._flyte_entity = task_node.flyte_task
elif workflow_node:
self._flyte_entity = workflow_node.flyte_workflow or workflow_node.flyte_launch_plan
else:
self._flyte_entity = branch_node
self._flyte_entity = branch_node or gate_node

super(FlyteNode, self).__init__(
id=id,
Expand All @@ -369,6 +376,7 @@ def __init__(
task_node=task_node,
workflow_node=workflow_node,
branch_node=branch_node,
gate_node=gate_node,
)
self._upstream = upstream_nodes

Expand Down Expand Up @@ -412,7 +420,7 @@ def promote_from_model(
remote_logger.warning(f"Should not call promote from model on a start node or end node {model}")
return None, converted_sub_workflows

flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None
flyte_task_node, flyte_workflow_node, flyte_branch_node, flyte_gate_node = None, None, None, None
if model.task_node is not None:
if model.task_node.reference_id not in tasks:
raise RuntimeError(
Expand All @@ -435,6 +443,8 @@ def promote_from_model(
tasks,
converted_sub_workflows,
)
elif model.gate_node is not None:
flyte_gate_node = FlyteGateNode.promote_from_model(model.gate_node)
else:
raise _system_exceptions.FlyteSystemException(
f"Bad Node model, neither task nor workflow detected, node: {model}"
Expand All @@ -459,6 +469,7 @@ def promote_from_model(
task_node=flyte_task_node,
workflow_node=flyte_workflow_node,
branch_node=flyte_branch_node,
gate_node=flyte_gate_node,
),
converted_sub_workflows,
)
Expand Down
67 changes: 66 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta

from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest
from flyteidl.core import literals_pb2 as literals_pb2

from flytekit import Literal
Expand All @@ -40,11 +41,12 @@
from flytekit.models import launch_plan as launch_plan_models
from flytekit.models import literals as literal_models
from flytekit.models import task as task_models
from flytekit.models import types as type_models
from flytekit.models.admin import common as admin_common_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.admin.common import Sort
from flytekit.models.core import workflow as workflow_model
from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier
from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.execution import (
ExecutionMetadata,
Expand Down Expand Up @@ -350,6 +352,69 @@ def fetch_execution(self, project: str = None, domain: str = None, name: str = N
# Listing Entities #
######################

def list_signals(
self,
execution_name: str,
project: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
limit: int = 100,
filters: typing.Optional[typing.List[filter_models.Filter]] = None,
) -> typing.List[Signal]:
"""
:param execution_name: The name of the execution. This is the tailend of the URL when looking at the workflow execution.
:param project: The execution project, will default to the Remote's default project.
:param domain: The execution domain, will default to the Remote's default domain.
:param limit: The number of signals to fetch
:param filters: Optional list of filters
"""
wf_exec_id = WorkflowExecutionIdentifier(
project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
)
req = SignalListRequest(workflow_execution_id=wf_exec_id.to_flyte_idl(), limit=limit, filters=filters)
resp = self.client.list_signals(req)
s = resp.signals
return s

def set_signal(
self,
signal_id: str,
execution_name: str,
value: typing.Union[literal_models.Literal, typing.Any],
project: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
python_type: typing.Optional[typing.Type] = None,
literal_type: typing.Optional[type_models.LiteralType] = None,
):
"""
:param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call.
:param execution_name: The name of the execution. This is the tail-end of the URL when looking
at the workflow execution.
:param value: This is either a Literal or a Python value which FlyteRemote will invoke the TypeEngine to
convert into a Literal. This argument is only value for wait_for_input type signals.
:param project: The execution project, will default to the Remote's default project.
:param domain: The execution domain, will default to the Remote's default domain.
:param python_type: Provide a python type to help with conversion if the value you provided is not a Literal.
:param literal_type: Provide a Flyte literal type to help with conversion if the value you provided
is not a Literal
"""
wf_exec_id = WorkflowExecutionIdentifier(
project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
)
if isinstance(value, Literal):
remote_logger.debug(f"Using provided {value} as existing Literal value")
lit = value
else:
lt = literal_type or (
TypeEngine.to_literal_type(python_type) if python_type else TypeEngine.to_literal_type(type(value))
)
lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt)
remote_logger.debug(f"Converted {value} to literal {lit} using literal type {lt}")

req = SignalSetRequest(id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=lit.to_flyte_idl())

# Response is empty currently, nothing to give back to the user.
self.client.set_signal(req)

def recent_executions(
self,
project: typing.Optional[str] = None,
Expand Down
5 changes: 5 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,11 @@ def get_serializable(
elif isinstance(entity, BranchNode):
cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options)

elif isinstance(entity, GateNode):
import ipdb

ipdb.set_trace()

elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow):
if entity.should_register:
if isinstance(entity, FlyteTask):
Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def register_for_protocol(
if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
if h.python_type in cls.DEFAULT_FORMATS and not override:
if cls.DEFAULT_FORMATS[h.python_type] != h.supported_format:
logger.debug(
logger.info(
f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
)
else:
Expand Down
8 changes: 6 additions & 2 deletions tests/flytekit/unit/clients/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ def get_admin_stub_mock() -> mock.MagicMock:
return auth_stub_mock


@mock.patch("flytekit.clients.raw.signal_service")
@mock.patch("flytekit.clients.raw.dataproxy_service")
@mock.patch("flytekit.clients.raw.auth_service")
@mock.patch("flytekit.clients.raw._admin_service")
@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
@mock.patch("flytekit.clients.raw.grpc.secure_channel")
def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy):
def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal):
mock_secure_channel.return_value = True
mock_channel.return_value = True
mock_admin.AdminServiceStub.return_value = True
Expand Down Expand Up @@ -73,6 +74,7 @@ def test_refresh_credentials_from_command(mock_call_to_external_process, mock_ad
mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key)


@mock.patch("flytekit.clients.raw.signal_service")
@mock.patch("flytekit.clients.raw.dataproxy_service")
@mock.patch("flytekit.clients.raw.get_basic_authorization_header")
@mock.patch("flytekit.clients.raw.get_token")
Expand All @@ -88,6 +90,7 @@ def test_refresh_client_credentials_aka_basic(
mock_get_token,
mock_get_basic_header,
mock_dataproxy,
mock_signal,
):
mock_secure_channel.return_value = True
mock_channel.return_value = True
Expand All @@ -112,12 +115,13 @@ def test_refresh_client_credentials_aka_basic(
assert client._metadata[0][0] == "authorization"


@mock.patch("flytekit.clients.raw.signal_service")
@mock.patch("flytekit.clients.raw.dataproxy_service")
@mock.patch("flytekit.clients.raw.auth_service")
@mock.patch("flytekit.clients.raw._admin_service")
@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
@mock.patch("flytekit.clients.raw.grpc.secure_channel")
def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy):
def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal):
mock_secure_channel.return_value = True
mock_channel.return_value = True
mock_admin.AdminServiceStub.return_value = True
Expand Down
35 changes: 34 additions & 1 deletion tests/flytekit/unit/core/test_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.tools.translator import get_serializable
from flytekit.remote.entities import FlyteWorkflow
from flytekit.tools.translator import gather_dependent_entities, get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
Expand Down Expand Up @@ -290,3 +291,35 @@ def cond_wf(a: int) -> float:
x = cond_wf(a=3)
assert x == 6
assert stdin.read() == ""


def test_promote():
@task
def t1(a: int) -> int:
return a + 5

@task
def t2(a: int) -> int:
return a + 6

@workflow
def wf(a: int) -> typing.Tuple[int, int, int]:
zzz = sleep(timedelta(seconds=10))
x = t1(a=a)
s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: should the name of the gate node be unique?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this is actually an interesting question. signals ids are unique within a workflow execution. so if you had this same statement twice it would just wait twice for a single signal with the given name. nothing would break unless the expected types are different. presumably you could have two subworkflows that both wait on the same signal by using the same name. at least that's how it works in the backend. not sure if we want to explicitely restrict this from flytekit side?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, thanks for the explain

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes.

s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int)
z = t1(a=5)
y = t2(a=s2)
q = t2(a=approve(y, "approvalfory", timeout=timedelta(hours=2)))
zzz >> x
x >> s1
s1 >> z

return y, z, q

entries = OrderedDict()
wf_spec = get_serializable(entries, serialization_settings, wf)
tts, wf_specs, lp_specs = gather_dependent_entities(entries)

fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=tts)
assert fwf.template.nodes[2].gate_node is not None
42 changes: 42 additions & 0 deletions tests/flytekit/unit/core/test_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from flyteidl.admin.signal_pb2 import Signal, SignalList
from mock import MagicMock

from flytekit.configuration import Config
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.models.core.identifier import SignalIdentifier, WorkflowExecutionIdentifier
from flytekit.remote.remote import FlyteRemote


def test_remote_list_signals():
ctx = FlyteContextManager.current_context()
wfeid = WorkflowExecutionIdentifier("p", "d", "execid")
signal_id = SignalIdentifier(signal_id="sigid", execution_id=wfeid).to_flyte_idl()
lt = TypeEngine.to_literal_type(int)
signal = Signal(
id=signal_id,
type=lt.to_flyte_idl(),
value=TypeEngine.to_literal(ctx, 3, int, lt).to_flyte_idl(),
)

mock_client = MagicMock()
mock_client.list_signals.return_value = SignalList(signals=[signal], token="")

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
res = remote.list_signals("execid", "p", "d", limit=10)
assert len(res) == 1


def test_remote_set_signal():
mock_client = MagicMock()

def checker(request):
assert request.id.signal_id == "sigid"
assert request.value.scalar.primitive.integer == 3

mock_client.set_signal.side_effect = checker

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
remote.set_signal("sigid", "execid", 3)