diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 856563f0db..c81098d6db 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -23,6 +23,7 @@ from typing import Dict import click +import cloudpickle import fsspec import requests from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest @@ -915,6 +916,7 @@ def upload_file( def _version_from_hash( md5_bytes: bytes, serialization_settings: SerializationSettings, + default_inputs: typing.Optional[Dict[str, typing.Any]] = None, *additional_context: str, ) -> str: """ @@ -939,6 +941,9 @@ def _version_from_hash( for s in additional_context: h.update(bytes(s, "utf-8")) + if default_inputs: + h.update(cloudpickle.dumps(default_inputs)) + # Omit the character '=' from the version as that's essentially padding used by the base64 encoding # and does not increase entropy of the hash while making it very inconvenient to copy-and-paste. return base64.urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=") @@ -1013,10 +1018,16 @@ def _get_image_names(entity: typing.Union[PythonAutoContainerTask, WorkflowBase] return image_names return [] + default_inputs = None + if isinstance(entity, WorkflowBase): + default_inputs = entity.python_interface.default_inputs_as_kwargs + # The md5 version that we send to S3/GCS has to match the file contents exactly, # but we don't have to use it when registering with the Flyte backend. 
# For that add the hash of the compilation settings to hash of file - version = self._version_from_hash(md5_bytes, serialization_settings, *_get_image_names(entity)) + version = self._version_from_hash( + md5_bytes, serialization_settings, default_inputs, *_get_image_names(entity) + ) if isinstance(entity, PythonTask): return self.register_task(entity, serialization_settings, version) diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index f23fc061d9..78b30c8276 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -73,10 +73,14 @@ def run(file_name, wf_name, *args): assert out.returncode == 0 -# test child_workflow.parent_wf asynchronously register a parent wf1 with child lp from another wf2. def test_remote_run(): + # child_workflow.parent_wf asynchronously registers a parent wf1 with a child lp from another wf2. run("child_workflow.py", "parent_wf", "--a", "3") + # Run twice: my_wf's default input is datetime.now(), so each run must register a new workflow version. 
+ run("default_lp.py", "my_wf") + run("default_lp.py", "my_wf") + def test_fetch_execute_launch_plan(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) diff --git a/tests/flytekit/integration/remote/workflows/basic/default_lp.py b/tests/flytekit/integration/remote/workflows/basic/default_lp.py new file mode 100644 index 0000000000..3ff61e965d --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/default_lp.py @@ -0,0 +1,17 @@ +import datetime + +from flytekit import task, workflow + + +@task +def print_datetime(time: datetime.datetime): + print(time) + + +@workflow +def my_wf(time: datetime.datetime = datetime.datetime.now()): + print_datetime(time=time) + + +if __name__ == "__main__": + print(f"Running my_wf() {my_wf()}") diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index a35f5efb5d..d6b0cc711c 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -502,7 +502,7 @@ def wf(name: str = "union"): flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") flyte_remote.register_script(wf) - version_from_hash_mock.assert_called_once_with(md5_bytes, mock.ANY, image_spec.image_name()) + version_from_hash_mock.assert_called_once_with(md5_bytes, mock.ANY, mock.ANY, image_spec.image_name()) register_workflow_mock.assert_called_once()