Skip to content

Commit

Permalink
eliminate redundant literal conversion for Iterator[JSON] type (fly…
Browse files Browse the repository at this point in the history
…teorg#2602)

* eliminate redundant literal conversion for  type

Signed-off-by: Samhita Alla <[email protected]>

* add test

Signed-off-by: Samhita Alla <[email protected]>

* lint

Signed-off-by: Samhita Alla <[email protected]>

* add isclass check

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
samhita-alla authored and Mecoli1219 committed Jul 27, 2024
1 parent 4912e4c commit 2ae928e
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 22 deletions.
20 changes: 16 additions & 4 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import tempfile
import typing
from dataclasses import dataclass, field, fields
from typing import get_args
from typing import Iterator, get_args

import rich_click as click
from mashumaro.codecs.json import JSONEncoder
from rich.progress import Progress
from typing_extensions import get_origin

from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal
from flytekit.clis.sdk_in_container.helpers import patch_image_config
Expand Down Expand Up @@ -538,10 +539,21 @@ def _run(*args, **kwargs):
for input_name, v in entity.python_interface.inputs_with_defaults.items():
processed_click_value = kwargs.get(input_name)
optional_v = False

skip_default_value_selection = False
if processed_click_value is None and isinstance(v, typing.Tuple):
optional_v = is_optional(v[0])
if len(v) == 2:
processed_click_value = v[1]
if entity_type == "workflow" and hasattr(v[0], "__args__"):
origin_base_type = get_origin(v[0])
if inspect.isclass(origin_base_type) and issubclass(origin_base_type, Iterator): # Iterator
args = getattr(v[0], "__args__")
if isinstance(args, tuple) and get_origin(args[0]) is typing.Union: # Iterator[JSON]
logger.debug(f"Detected Iterator[JSON] in {entity.name} input annotations...")
skip_default_value_selection = True

if not skip_default_value_selection:
optional_v = is_optional(v[0])
if len(v) == 2:
processed_click_value = v[1]
if isinstance(processed_click_value, ArtifactQuery):
if run_level_params.is_remote:
click.secho(
Expand Down
214 changes: 196 additions & 18 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,34 @@
import pytest
import yaml
from click.testing import CliRunner
from flytekit.loggers import logging, logger

from flytekit.clis.sdk_in_container import pyflyte
from flytekit.clis.sdk_in_container.run import RunLevelParams, get_entities_in_file, run_command
from flytekit.clis.sdk_in_container.run import (
RunLevelParams,
get_entities_in_file,
run_command,
)
from flytekit.configuration import Config, Image, ImageConfig
from flytekit.core.task import task
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, calculate_hash_from_image_spec
from flytekit.image_spec.image_spec import (
ImageBuildEngine,
ImageSpec,
calculate_hash_from_image_spec,
)
from flytekit.interaction.click_types import DirParamType, FileParamType
from flytekit.remote import FlyteRemote
from typing import Iterator
from flytekit.types.iterator import JSON
from flytekit import workflow


pytest.importorskip("pandas")

REMOTE_WORKFLOW_FILE = "https://raw.githubusercontent.com/flyteorg/flytesnacks/8337b64b33df046b2f6e4cba03c74b7bdc0c4fb1/cookbook/core/flyte_basics/basic_workflow.py"
IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py")
IMPERATIVE_WORKFLOW_FILE = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py"
)
DIR_NAME = os.path.dirname(os.path.realpath(__file__))


Expand All @@ -46,7 +61,9 @@ def workflow_file(request, tmp_path_factory):
@pytest.fixture
def remote():
with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client:
flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote = FlyteRemote(
config=Config.auto(), default_project="p1", default_domain="d1"
)
flyte_remote._client = mock_client
return flyte_remote

Expand All @@ -70,7 +87,9 @@ def test_pyflyte_run_wf(remote, remote_flag, workflow_file):
with mock.patch("flytekit.configuration.plugin.FlyteRemote"):
runner = CliRunner()
result = runner.invoke(
pyflyte.main, ["run", remote_flag, workflow_file, "my_wf", "--help"], catch_exceptions=False
pyflyte.main,
["run", remote_flag, workflow_file, "my_wf", "--help"],
catch_exceptions=False,
)

assert result.exit_code == 0
Expand All @@ -81,7 +100,9 @@ def test_pyflyte_run_with_labels():
with mock.patch("flytekit.configuration.plugin.FlyteRemote"):
runner = CliRunner()
result = runner.invoke(
pyflyte.main, ["run", "--remote", str(workflow_file), "my_wf", "--help"], catch_exceptions=False
pyflyte.main,
["run", "--remote", str(workflow_file), "my_wf", "--help"],
catch_exceptions=False,
)
assert result.exit_code == 0

Expand All @@ -100,7 +121,16 @@ def test_copy_all_files():
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
["run", "--copy-all", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"],
[
"run",
"--copy-all",
IMPERATIVE_WORKFLOW_FILE,
"wf",
"--in1",
"hello",
"--in2",
"world",
],
catch_exceptions=False,
)
assert result.exit_code == 0
Expand Down Expand Up @@ -176,7 +206,13 @@ def test_pyflyte_run_cli(workflow_file):

@pytest.mark.parametrize(
"input",
["1", os.path.join(DIR_NAME, "testdata/df.parquet"), '{"x":1.0, "y":2.0}', "2020-05-01", "RED"],
[
"1",
os.path.join(DIR_NAME, "testdata/df.parquet"),
'{"x":1.0, "y":2.0}',
"2020-05-01",
"RED",
],
)
def test_union_type1(input):
runner = CliRunner()
Expand Down Expand Up @@ -300,7 +336,10 @@ def test_nested_workflow(working_dir, wf_path, monkeypatch: pytest.MonkeyPatch):
],
catch_exceptions=False,
)
assert result.stdout.strip() == "Running Execution on local.\nRunning Execution on local."
assert (
result.stdout.strip()
== "Running Execution on local.\nRunning Execution on local."
)
assert result.exit_code == 0


Expand All @@ -325,12 +364,18 @@ def test_list_default_arguments(wf_path):

# default case, what comes from click if no image is specified, the click param is configured to use the default.
ic_result_1 = ImageConfig(
default_image=Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest"),
images=[Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest")],
default_image=Image(
name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest"
),
images=[
Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest")
],
)
# test that command line args are merged with the file
ic_result_2 = ImageConfig(
default_image=Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"),
default_image=Image(
name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"
),
images=[
Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"),
Image(name="asdf", fqn="ghcr.io/asdf/asdf", tag="latest"),
Expand All @@ -345,7 +390,9 @@ def test_list_default_arguments(wf_path):
)
# test that command line args override the file
ic_result_3 = ImageConfig(
default_image=Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"),
default_image=Image(
name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"
),
images=[
Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"),
Image(name="xyz", fqn="ghcr.io/asdf/asdf", tag="latest"),
Expand Down Expand Up @@ -395,21 +442,29 @@ def test_list_default_arguments(wf_path):
reason="Github macos-latest image does not have docker installed as per https://github.com/orgs/community/discussions/25777",
)
def test_pyflyte_run_run(
mock_image, image_string, leaf_configuration_file_name, final_image_config, mock_image_spec_builder
mock_image,
image_string,
leaf_configuration_file_name,
final_image_config,
mock_image_spec_builder,
):
mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest"
ImageBuildEngine.register("test", mock_image_spec_builder)

@task
def tk():
...
def tk(): ...

mock_click_ctx = mock.MagicMock()
mock_remote = mock.MagicMock()
image_tuple = (image_string,)
image_config = ImageConfig.validate_image(None, "", image_tuple)

pp = pathlib.Path(__file__).parent.parent.parent / "configuration" / "configs" / leaf_configuration_file_name
pp = (
pathlib.Path(__file__).parent.parent.parent
/ "configuration"
/ "configs"
/ leaf_configuration_file_name
)

obj = RunLevelParams(
project="p",
Expand All @@ -429,6 +484,125 @@ def check_image(*args, **kwargs):
run_command(mock_click_ctx, tk)()


def jsons():
for x in [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"},
],
},
},
]:
yield x


@mock.patch("flytekit.configuration.default_images.DefaultImages.default_image")
def test_pyflyte_run_with_iterator_json_type(
mock_image, mock_image_spec_builder, caplog
):
mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest"
ImageBuildEngine.register(
"test",
mock_image_spec_builder,
)

@task
def t1(x: Iterator[JSON]) -> Iterator[JSON]:
return x

@workflow
def tk(x: Iterator[JSON] = jsons()) -> Iterator[JSON]:
return t1(x=x)

@task
def t2(x: list[int]) -> list[int]:
return x

@workflow
def tk_list(x: list[int] = [1, 2, 3]) -> list[int]:
return t2(x=x)

@task
def t3(x: Iterator[int]) -> Iterator[int]:
return x

@workflow
def tk_simple_iterator(x: Iterator[int] = iter([1, 2, 3])) -> Iterator[int]:
return t3(x=x)

mock_click_ctx = mock.MagicMock()
mock_remote = mock.MagicMock()
image_tuple = ("ghcr.io/flyteorg/mydefault:py3.9-latest",)
image_config = ImageConfig.validate_image(None, "", image_tuple)

pp = (
pathlib.Path(__file__).parent.parent.parent
/ "configuration"
/ "configs"
/ "no_images.yaml"
)

obj = RunLevelParams(
project="p",
domain="d",
image_config=image_config,
remote=True,
config_file=str(pp),
)
obj._remote = mock_remote
mock_click_ctx.obj = obj

def check_image(*args, **kwargs):
assert kwargs["image_config"] == ic_result_1

mock_remote.register_script.side_effect = check_image

logger.propagate = True
with caplog.at_level(logging.DEBUG, logger="flytekit"):
run_command(mock_click_ctx, tk)()
assert any(
"Detected Iterator[JSON] in pyflyte.test_run.tk input annotations..."
in message[2]
for message in caplog.record_tuples
)

caplog.clear()

with caplog.at_level(logging.DEBUG, logger="flytekit"):
run_command(mock_click_ctx, tk_list)()
assert not any(
"Detected Iterator[JSON] in pyflyte.test_run.tk_list input annotations..."
in message[2]
for message in caplog.record_tuples
)

caplog.clear()

with caplog.at_level(logging.DEBUG, logger="flytekit"):
run_command(mock_click_ctx, t1)()
assert not any(
"Detected Iterator[JSON] in pyflyte.test_run.t1 input annotations..."
in message[2]
for message in caplog.record_tuples
)

caplog.clear()

with caplog.at_level(logging.DEBUG, logger="flytekit"):
run_command(mock_click_ctx, tk_simple_iterator)()
assert not any(
"Detected Iterator[JSON] in pyflyte.test_run.tk_simple_iterator input annotations..."
in message[2]
for message in caplog.record_tuples
)


def test_file_param():
m = mock.MagicMock()
flyte_file = FileParamType().convert(__file__, m, m)
Expand Down Expand Up @@ -484,7 +658,11 @@ def test_pyflyte_run_with_none(a_val, workflow_file):
"envs, envs_argument, expected_output",
[
(["--env", "MY_ENV_VAR=hello"], '["MY_ENV_VAR"]', "hello"),
(["--env", "MY_ENV_VAR=hello", "--env", "ABC=42"], '["MY_ENV_VAR","ABC"]', "hello,42"),
(
["--env", "MY_ENV_VAR=hello", "--env", "ABC=42"],
'["MY_ENV_VAR","ABC"]',
"hello,42",
),
],
)
@pytest.mark.parametrize(
Expand Down

0 comments on commit 2ae928e

Please sign in to comment.