Skip to content

Commit

Permalink
Auto download for fetch (flyteorg#1918)
Browse files Browse the repository at this point in the history
* Auto download for fetch

Signed-off-by: Ketan Umare <[email protected]>

* Adds new fetch with download command

` pyflyte fetch --download-to /tmp/. flyte://v1/flytesnacks/development/f5dc81c5a8c6441d4a0a/rotaterotateimage/o/o0`

Signed-off-by: Ketan Umare <[email protected]>

* Auto download and fixed some bugs

Signed-off-by: Ketan Umare <[email protected]>

* Updated fetch and hitl

Signed-off-by: Ketan Umare <[email protected]>

---------

Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 authored and ringohoffman committed Nov 24, 2023
1 parent 31cc37a commit c91ffba
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 18 deletions.
73 changes: 70 additions & 3 deletions flytekit/clis/sdk_in_container/fetch.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,77 @@
import os
import pathlib
import typing

import rich_click as click
from google.protobuf.json_format import MessageToJson
from rich import print
from rich.panel import Panel
from rich.pretty import Pretty

from flytekit import Literal
from flytekit import BlobType, FlyteContext, Literal
from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context
from flytekit.core.type_engine import LiteralsResolver
from flytekit.interaction.rich_utils import RichCallback
from flytekit.interaction.string_literals import literal_map_string_repr, literal_string_repr
from flytekit.remote import FlyteRemote


def download_literal(var: str, data: Literal, download_to: typing.Optional[pathlib.Path] = None):
    """
    Download a single literal to local disk, if it carries downloadable data.

    Blobs and structured datasets are fetched through the current context's
    ``file_access`` into ``<download_to>/<var>/``; generic (struct) scalars are
    written as ``<download_to>/<var>.json``; unions are unwrapped; collections
    and maps recurse into ``<download_to>/<var>/`` using the element index or
    key as the nested variable name. Any other scalar is skipped with a message.

    :param var: Name of the variable; used as the file / directory name.
    :param data: The literal to download. ``None`` is skipped.
    :param download_to: Destination directory (``str`` or ``pathlib.Path``).
        Defaults to the current working directory when omitted.
    """
    if data is None:
        print(f"Skipping {var} as it is None.")
        return

    # Normalise once so every branch can rely on the ``/`` operator, whether the
    # caller passed a str, a Path, or nothing at all.
    download_to = pathlib.Path.cwd() if download_to is None else pathlib.Path(download_to)

    scalar = data.scalar
    if scalar:
        if scalar.blob or scalar.structured_dataset:
            uri = scalar.blob.uri if scalar.blob else scalar.structured_dataset.uri
            if uri is None:
                print("No data to download.")
                return
            if scalar.blob:
                is_multipart = scalar.blob.metadata.type.dimensionality == BlobType.BlobDimensionality.MULTIPART
            else:
                # Structured datasets are always fetched as multipart.
                is_multipart = True
            FlyteContext.current_context().file_access.get_data(
                uri, str(download_to / var) + os.sep, is_multipart=is_multipart, callback=RichCallback()
            )
        elif scalar.union is not None:
            download_literal(var, scalar.union.value, download_to)
        elif scalar.generic is not None:
            target = download_to / f"{var}.json"
            # Nested collections/maps recurse into sub-directories nobody has
            # created yet, so make sure the parent exists before writing.
            target.parent.mkdir(parents=True, exist_ok=True)
            with open(target, "w") as f:
                f.write(MessageToJson(scalar.generic))
        else:
            print(
                f"[dim]Skipping {var} val {literal_string_repr(data)} as it is not a blob, structured dataset,"
                f" or generic type.[/dim]"
            )
            return
    elif data.collection:
        for i, v in enumerate(data.collection.literals):
            download_literal(f"{i}", v, download_to / var)
    elif data.map:
        for k, v in data.map.literals.items():
            download_literal(f"{k}", v, download_to / var)
    print(f"Downloaded {var} to {download_to}")


@click.command("fetch")
@click.argument("flyte_data_uri", type=str, required=True, metavar="FLYTE-DATA-URI (of the form flyte://...)")
@click.option(
"--recursive",
"-r",
is_flag=True,
help="Fetch recursively, all variables in the URI. This is not needed for directrories as they"
" are automatically recursively downloaded.",
)
@click.argument("flyte-data-uri", type=str, required=True, metavar="FLYTE-DATA-URI (format flyte://...)")
@click.argument(
"download-to", type=click.Path(), required=False, default=None, metavar="DOWNLOAD-TO Local path (optional)"
)
@click.pass_context
def fetch(ctx: click.Context, flyte_data_uri: str):
def fetch(ctx: click.Context, recursive: bool, flyte_data_uri: str, download_to: typing.Optional[str] = None):
"""
Retrieve Inputs/Outputs for a Flyte Execution or any of the inner node executions from the remote server.
Expand All @@ -32,3 +90,12 @@ def fetch(ctx: click.Context, flyte_data_uri: str):
pretty = Pretty(p)
panel = Panel(pretty)
print(panel)
if download_to:
download_to = pathlib.Path(download_to)
if isinstance(data, Literal):
download_literal("data", data, download_to)
else:
if not recursive:
raise click.UsageError("Please specify --recursive to download all variables in a literal map.")
for var, literal in data.literals.items():
download_literal(var, literal, download_to)
15 changes: 8 additions & 7 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
return shutil.copytree(
self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True
)
print(f"Getting {from_path} to {to_path}")
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
except OSError as oe:
logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}")
Expand Down Expand Up @@ -427,31 +428,31 @@ def get_random_remote_directory(self) -> str:
self.get_random_string(),
)

def download_directory(self, remote_path: str, local_path: str):
def download_directory(self, remote_path: str, local_path: str, **kwargs):
"""
Downloads directory from given remote to local path
"""
return self.get_data(remote_path, local_path, is_multipart=True)

def download(self, remote_path: str, local_path: str):
def download(self, remote_path: str, local_path: str, **kwargs):
"""
Downloads from remote to local
"""
return self.get_data(remote_path, local_path)
return self.get_data(remote_path, local_path, **kwargs)

def upload(self, file_path: str, to_path: str):
def upload(self, file_path: str, to_path: str, **kwargs):
"""
:param Text file_path:
:param Text to_path:
"""
return self.put_data(file_path, to_path)
return self.put_data(file_path, to_path, **kwargs)

def upload_directory(self, local_path: str, remote_path: str):
def upload_directory(self, local_path: str, remote_path: str, **kwargs):
"""
:param Text local_path:
:param Text remote_path:
"""
return self.put_data(local_path, remote_path, is_multipart=True)
return self.put_data(local_path, remote_path, is_multipart=True, **kwargs)

def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs):
"""
Expand Down
7 changes: 6 additions & 1 deletion flytekit/core/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.user import FlyteDisapprovalException
from flytekit.interaction.parse_stdin import parse_stdin_to_literal
from flytekit.interaction.string_literals import scalar_to_string
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Scalar
from flytekit.models.types import LiteralType

DEFAULT_TIMEOUT = datetime.timedelta(hours=1)
Expand Down Expand Up @@ -111,8 +113,11 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
return p

# Assume this is an approval operation since that's the only remaining option.
v = typing.cast(Promise, self._upstream_item).val.value
if isinstance(v, Scalar):
v = scalar_to_string(v)
msg = click.style("[Approval Gate] ", fg="yellow") + click.style(
f"@{self.name} Approve {typing.cast(Promise, self._upstream_item).val.value}?", fg="cyan"
f"@{self.name} Approve {click.style(v, fg='green')}?", fg="cyan"
)
proceed = click.confirm(msg, default=True)
if proceed:
Expand Down
9 changes: 8 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from flytekit.core import constants as _common_constants
from flytekit.core.base_task import PythonTask
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.condition import ConditionalSection
from flytekit.core.condition import ConditionalSection, conditional
from flytekit.core.context_manager import (
CompilationState,
ExecutionState,
Expand Down Expand Up @@ -506,6 +506,13 @@ def execute(self, **kwargs):
return get_promise(self.output_bindings[0].binding, intermediate_node_outputs)
return tuple([get_promise(b.binding, intermediate_node_outputs) for b in self.output_bindings])

def create_conditional(self, name: str) -> ConditionalSection:
    """
    Start a named conditional section for this workflow.

    :param name: Name given to the conditional.
    :raises Exception: If the current Flyte context already has a compilation state.
    :return: The section produced by ``conditional(name=name)``.
    """
    current = FlyteContext.current_context()
    if current.compilation_state is not None:
        raise Exception("Can't already be compiling")
    builder = current.with_compilation_state(self.compilation_state)
    # NOTE(review): the value returned by with_context() is discarded. If it is a
    # context manager that only takes effect inside a ``with`` block, this call
    # does not install the compilation state - confirm the intended behaviour.
    FlyteContextManager.with_context(builder)
    return conditional(name=name)

def add_entity(self, entity: Union[PythonTask, LaunchPlan, WorkflowBase], **kwargs) -> Node:
"""
Anytime you add an entity, all the inputs to the entity must be bound.
Expand Down
2 changes: 2 additions & 0 deletions flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def __init__(self, enum_type: typing.Type[enum.Enum]):
def convert(
    self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> enum.Enum:
    """
    Turn a command-line value into a member of the wrapped enum type.

    Values that are already members of the enum are returned untouched;
    anything else is first converted by the parent class and the result is
    then used to look up the enum member.
    """
    enum_cls = self._enum_type
    if not isinstance(value, enum_cls):
        return enum_cls(super().convert(value, param, ctx))
    return value


Expand Down
22 changes: 22 additions & 0 deletions flytekit/interaction/rich_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import typing

from fsspec import Callback
from rich.progress import Progress


class RichCallback(Callback):
    """
    fsspec transfer callback that renders progress with a ``rich`` progress bar.

    The bar is started on construction and stopped when the callback object is
    garbage collected.

    :param rich_kwargs: Optional keyword arguments forwarded to ``rich.progress.Progress``.
    :param kwargs: Forwarded unchanged to ``fsspec.Callback``.
    """

    def __init__(self, rich_kwargs: typing.Optional[typing.Dict] = None, **kwargs):
        super().__init__(**kwargs)
        # Assign the attributes before anything that can raise, so __del__ always
        # finds them even when Progress construction fails.
        self._pb = None
        self._task = None
        self._pb = Progress(**(rich_kwargs or {}))
        self._pb.start()

    def set_size(self, size):
        """Register a progress task whose total is ``size``."""
        self._task = self._pb.add_task("Downloading...", total=size)

    def relative_update(self, inc=1):
        """Advance the progress task by ``inc``; no-op until ``set_size`` has created one."""
        if self._task is None:
            # No task exists yet, so there is no valid task id to advance.
            return
        self._pb.update(self._task, advance=inc)

    def __del__(self):
        # getattr guards against a partially constructed instance (e.g. the base
        # class initialiser raised before _pb was assigned).
        pb = getattr(self, "_pb", None)
        if pb is not None:
            pb.stop()
12 changes: 6 additions & 6 deletions flytekit/interaction/string_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ def primitive_to_string(primitive: Primitive) -> typing.Any:
"""
This method is used to convert a primitive to a string representation.
"""
if primitive.integer:
if primitive.integer is not None:
return primitive.integer
if primitive.float_value:
if primitive.float_value is not None:
return primitive.float_value
if primitive.boolean:
if primitive.boolean is not None:
return primitive.boolean
if primitive.string_value:
if primitive.string_value is not None:
return primitive.string_value
if primitive.datetime:
if primitive.datetime is not None:
return primitive.datetime.isoformat()
if primitive.duration:
if primitive.duration is not None:
return primitive.duration.total_seconds()
raise ValueError(f"Unknown primitive type {primitive}")

Expand Down

0 comments on commit c91ffba

Please sign in to comment.