Skip to content

Commit

Permalink
Update argument setting for fast-registered, dynamically generated…
Browse files Browse the repository at this point in the history
…, pod tasks (#835)

Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Jan 27, 2022
1 parent c20372e commit 170f5af
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 23 deletions.
4 changes: 4 additions & 0 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,12 @@ class FastSerializationSettings(object):
"""

enabled: bool = False
# This is the location that the code should be copied into.
destination_dir: Optional[str] = None

# This is the zip file where the new code was uploaded to.
distribution_location: Optional[str] = None


@dataclass(frozen=True)
class SerializationSettings(object):
Expand Down
20 changes: 7 additions & 13 deletions flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,6 @@ def compile_into_workflow(
"Compilation for a dynamic workflow called in fast execution mode but no additional code "
"distribution could be retrieved"
)
logger.warn(f"ctx.execution_state.additional_context {ctx.execution_state.additional_context}")
for task_template in tts:
sanitized_args = []
for arg in task_template.container.args:
if arg == "{{ .remote_package_path }}":
sanitized_args.append(ctx.execution_state.additional_context.get("dynamic_addl_distro"))
elif arg == "{{ .dest_dir }}":
sanitized_args.append(ctx.execution_state.additional_context.get("dynamic_dest_dir", "."))
else:
sanitized_args.append(arg)
del task_template.container.args[:]
task_template.container.args.extend(sanitized_args)

dj_spec = _dynamic_job.DynamicJobSpec(
min_successes=len(workflow_spec.template.nodes),
Expand Down Expand Up @@ -290,7 +278,13 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
if is_fast_execution:
ctx = ctx.with_serialization_settings(
ctx.serialization_settings.new_builder()
.with_fast_serialization_settings(FastSerializationSettings(enabled=True))
.with_fast_serialization_settings(
FastSerializationSettings(
enabled=True,
destination_dir=ctx.execution_state.additional_context.get("dynamic_dest_dir", "."),
distribution_location=ctx.execution_state.additional_context.get("dynamic_addl_distro"),
)
)
.build()
)

Expand Down
14 changes: 6 additions & 8 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,17 @@ def _fast_serialize_command_fn(
) -> Callable[[SerializationSettings], List[str]]:
default_command = task.get_default_command(settings)

dest_dir = (
settings.fast_serialization_settings.destination_dir if settings.fast_serialization_settings is not None else ""
)
if dest_dir is None or dest_dir == "":
dest_dir = "{{ .dest_dir }}"

def fn(settings: SerializationSettings) -> List[str]:
return [
"pyflyte-fast-execute",
"--additional-distribution",
"{{ .remote_package_path }}",
settings.fast_serialization_settings.distribution_location
if settings.fast_serialization_settings and settings.fast_serialization_settings.distribution_location
else "{{ .remote_package_path }}",
"--dest-dir",
dest_dir,
settings.fast_serialization_settings.destination_dir
if settings.fast_serialization_settings and settings.fast_serialization_settings.destination_dir
else "{{ .dest_dir }}",
"--",
*default_command,
]
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]
final_containers = []
for container in containers:
# In the case of the primary container, we overwrite specific container attributes with the default values
# used in an SDK runnable task.
# used in the regular Python task.
if container.name == self.task_config.primary_container_name:
sdk_default_container = super().get_container(settings)

Expand Down
63 changes: 63 additions & 0 deletions plugins/flytekit-k8s-pod/tests/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flytekit import Resources, TaskMetadata, dynamic, map_task, task
from flytekit.core import context_manager
from flytekit.core.context_manager import FastSerializationSettings
from flytekit.core.type_engine import TypeEngine
from flytekit.extend import ExecutionState, Image, ImageConfig, SerializationSettings
from flytekit.tools.translator import get_serializable

Expand Down Expand Up @@ -391,3 +392,65 @@ def simple_pod_task(i: int):
"task-name",
"simple_pod_task",
]


def test_fast():
    """Check fast-execute args and resources of a pod subtask compiled by a dynamic task.

    The test enables fast serialization, enters a TASK_EXECUTION state whose
    additional context carries the distribution location and destination
    directory, dispatches the dynamic task, and inspects the single task in
    the resulting dynamic job spec.
    """
    gpu_requests = Resources(cpu="123m", mem="234Mi", ephemeral_storage="123M", gpu="1")
    gpu_limits = Resources(cpu="124M", mem="235Mi", ephemeral_storage="124M", gpu="1")

    # Smallest possible pod config: a single container that is also the primary one.
    minimal_pod_config = Pod(
        pod_spec=V1PodSpec(containers=[V1Container(name="flytetask")]),
        primary_container_name="flytetask",
    )

    @task(task_config=minimal_pod_config, requests=gpu_requests, limits=gpu_limits)
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=gpu_requests, limits=gpu_limits)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )

    manager = context_manager.FlyteContextManager
    base_ctx = manager.current_context().with_serialization_settings(settings)
    with manager.with_context(base_ctx) as serialization_ctx:
        # Layer a task-execution state on top; the additional context supplies
        # the values expected in the subtask's container args below.
        execution_state = serialization_ctx.execution_state.with_params(
            mode=ExecutionState.Mode.TASK_EXECUTION,
            additional_context={
                "dynamic_addl_distro": "s3://my-s3-bucket/fast/123",
                "dynamic_dest_dir": "/User/flyte/workflows",
            },
        )
        with manager.with_context(serialization_ctx.with_execution_state(execution_state)) as execution_ctx:
            inputs = TypeEngine.dict_to_literal_map(execution_ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(execution_ctx, inputs)

            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1

            primary_container = dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]
            joined_args = " ".join(primary_container["args"])
            expected_prefix = (
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )
            assert joined_args.startswith(expected_prefix)
            assert primary_container["resources"]["limits"]["cpu"] == "124M"
            assert primary_container["resources"]["requests"]["gpu"] == "1"

    # Both pushed contexts must have been popped again.
    assert manager.size() == 1
4 changes: 3 additions & 1 deletion tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,9 @@ def my_wf(a: int) -> typing.List[str]:
)
)
) as ctx:
dynamic_job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5)
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})

dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
assert len(dynamic_job_spec._nodes) == 5
assert len(dynamic_job_spec.tasks) == 1
args = " ".join(dynamic_job_spec.tasks[0].container.args)
Expand Down

0 comments on commit 170f5af

Please sign in to comment.