diff --git a/flytekit/__init__.py b/flytekit/__init__.py index f081381ee6..047bbf2674 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -197,6 +197,8 @@ from rich import traceback +from flytekit.lazy_import.lazy_module import lazy_module + if sys.version_info < (3, 10): from importlib_metadata import entry_points else: @@ -224,7 +226,6 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck -from flytekit.extras import pytorch, sklearn, tensorflow from flytekit.image_spec import ImageSpec from flytekit.loggers import logger from flytekit.models.common import Annotations, AuthRole, Labels @@ -233,7 +234,7 @@ from flytekit.models.documentation import Description, Documentation, SourceCode from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types import directory, file, numpy, schema +from flytekit.types import directory, file from flytekit.types.structured.structured_dataset import ( StructuredDataset, StructuredDatasetFormat, diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index eba35cfb25..4785bd228c 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -145,7 +145,6 @@ import yaml from dataclasses_json import dataclass_json -from docker_image import reference from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages @@ -208,6 +207,8 @@ def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image :param Text tag: e.g. somedocker.com/myimage:someversion123 :rtype: Text """ + from docker_image import reference + if pathlib.Path(tag).is_file(): with open(tag, "r") as f: image_spec_dict = yaml.safe_load(f) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 63c5518bfd..1f9e27f735 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -45,7 +45,6 @@ from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError from flytekit.core.utils import timeit -from flytekit.deck.deck import Deck from flytekit.loggers import logger from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import interface as _interface_models @@ -590,6 +589,8 @@ def dispatch_execute( raise TypeError(msg) from e if self._disable_deck is False: + from flytekit.deck.deck import Deck + INPUT = "input" OUTPUT = "output" diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index d51f71d837..fd604004d6 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -49,7 +49,7 @@ def __init__( metadata_format: MetadataFormat = MetadataFormat.JSON, io_strategy: Optional[IOStrategy] = None, secret_requests: Optional[List[Secret]] = None, - pod_template: Optional[PodTemplate] = None, + pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, **kwargs, ): diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index de3b339251..e2923bfc7f 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -27,7 +27,6 @@ from enum import Enum from typing import Generator, List, Optional, Union -from flytekit.clients import friendly as friendly_client # noqa from flytekit.configuration import Config, SecretsConfig, SerializationSettings from flytekit.core import mock_stats, utils from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint @@ -39,7 +38,8 @@ from flytekit.models.core import identifier as _identifier if typing.TYPE_CHECKING: - from flytekit.deck.deck import Deck + from flytekit import Deck + from flytekit.clients import friendly as friendly_client # noqa # TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin @@ -262,7 +262,7 @@ def decks(self) -> typing.List: @property def default_deck(self) -> Deck: - from flytekit.deck.deck import Deck + from flytekit import Deck return Deck("default") @@ -551,7 +551,7 @@ class FlyteContext(object): file_access: FileAccessProvider level: int = 0 - flyte_client: Optional[friendly_client.SynchronousFlyteClient] = None + flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None compilation_state: Optional[CompilationState] = None execution_state: Optional[ExecutionState] = None serialization_settings: Optional[SerializationSettings] = None @@ -660,7 +660,7 @@ class Builder(object): level: int = 0 compilation_state: Optional[CompilationState] = None execution_state: Optional[ExecutionState] = None - flyte_client: Optional[friendly_client.SynchronousFlyteClient] = None + flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None serialization_settings: Optional[SerializationSettings] = None in_a_condition: bool = False diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 11cb3b926c..e0b205ca5b 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -1,10 +1,12 @@ from typing import Optional -import joblib from diskcache import Cache +from flytekit import lazy_module from flytekit.models.literals import Literal, LiteralCollection, LiteralMap +joblib = lazy_module("joblib") + # Location on the filesystem where serialized objects will be stored # TODO: read from config CACHE_LOCATION = "~/.flyte/local-cache" diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 73f951d721..4f7838d2b6 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -84,7 +84,9 @@ def metadata(self) -> _workflow_model.NodeMetadata: def with_overrides(self, *args, **kwargs): if "node_name" in kwargs: - self._id = kwargs["node_name"] + # Convert the node name into a DNS-compliant. + # https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names + self._id = _dnsify(kwargs["node_name"]) if "aliases" in kwargs: alias_dict = kwargs["aliases"] if not isinstance(alias_dict, dict): diff --git a/flytekit/core/pod_template.py b/flytekit/core/pod_template.py index af211e55d7..98ba92af36 100644 --- a/flytekit/core/pod_template.py +++ b/flytekit/core/pod_template.py @@ -1,22 +1,27 @@ -from dataclasses import dataclass, field -from typing import Dict, Optional - -from kubernetes.client.models import V1PodSpec +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Optional from flytekit.exceptions import user as _user_exceptions +if TYPE_CHECKING: + from kubernetes.client import V1PodSpec + PRIMARY_CONTAINER_DEFAULT_NAME = "primary" -@dataclass(init=True, repr=True, eq=True, frozen=True) +@dataclass(init=True, repr=True, eq=True, frozen=False) class PodTemplate(object): """Custom PodTemplate specification for a Task.""" - pod_spec: V1PodSpec = field(default_factory=lambda: V1PodSpec(containers=[])) + pod_spec: Optional["V1PodSpec"] = None primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME labels: Optional[Dict[str, str]] = None annotations: Optional[Dict[str, str]] = None def __post_init__(self): + if self.pod_spec is None: + from kubernetes.client import V1PodSpec + + self.pod_spec = V1PodSpec(containers=[]) if not self.primary_container_name: raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined") diff --git a/flytekit/core/task.py b/flytekit/core/task.py index afd435e625..2cdaa50365 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -94,7 +94,7 @@ def task( task_resolver: Optional[TaskResolverMixin] = None, docs: Optional[Documentation] = None, disable_deck: bool = True, - pod_template: Optional[PodTemplate] = None, + pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, ) -> Union[Callable, PythonFunctionTask]: """ diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 3957a609d8..bd355bd529 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -22,7 +22,6 @@ from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct from marshmallow_enum import EnumField, LoadDumpOptions -from marshmallow_jsonschema import JSONSchema from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation @@ -31,6 +30,7 @@ from flytekit.core.type_helpers import load_type_from_tag from flytekit.core.utils import timeit from flytekit.exceptions import user as user_exceptions +from flytekit.lazy_import.lazy_module import is_imported from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import types as _type_models @@ -329,6 +329,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 if isinstance(v, EnumField): v.load_by = LoadDumpOptions.name + from marshmallow_jsonschema import JSONSchema + schema = JSONSchema().dump(s) except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 @@ -376,7 +378,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T: # In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, # so here we convert it back to the Structured Dataset. - from flytekit import StructuredDataset + from flytekit.types.structured import StructuredDataset if python_type == StructuredDataset and type(python_val) == dict: return StructuredDataset(**python_val) @@ -672,6 +674,7 @@ class TypeEngine(typing.Generic[T]): _REGISTRY: typing.Dict[type, TypeTransformer[T]] = {} _RESTRICTED_TYPES: typing.List[type] = [] _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore + has_lazy_import = False @classmethod def register( @@ -729,7 +732,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: Step 4: if v is of type data class, use the dataclass transformer """ - + cls.lazy_import_transformers() # Step 1 if get_origin(python_type) is Annotated: python_type = get_args(python_type)[0] @@ -771,6 +774,39 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: raise ValueError(f"Type {python_type} not supported currently in Flytekit. Please register a new transformer") + @classmethod + def lazy_import_transformers(cls): + """ + Only load the transformers if needed. + """ + if cls.has_lazy_import: + return + cls.has_lazy_import = True + from flytekit.types.structured import ( + register_arrow_handlers, + register_bigquery_handlers, + register_pandas_handlers, + ) + + if is_imported("tensorflow"): + from flytekit.extras import tensorflow # noqa: F401 + if is_imported("torch"): + from flytekit.extras import pytorch # noqa: F401 + if is_imported("sklearn"): + from flytekit.extras import sklearn # noqa: F401 + if is_imported("pandas"): + try: + from flytekit.types import schema # noqa: F401 + except ValueError: + logger.debug("Transformer for pandas is already registered.") + register_pandas_handlers() + if is_imported("pyarrow"): + register_arrow_handlers() + if is_imported("google.cloud.bigquery"): + register_bigquery_handlers() + if is_imported("numpy"): + from flytekit.types import numpy # noqa: F401 + @classmethod def to_literal_type(cls, python_type: Type) -> LiteralType: """ diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index c24f40bff3..437d2b71a4 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -6,16 +6,15 @@ from functools import wraps from hashlib import sha224 as _sha224 from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast from flyteidl.core import tasks_pb2 as _core_task -from kubernetes.client import ApiClient -from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements from flytekit.core.pod_template import PodTemplate from flytekit.loggers import logger -from flytekit.models import task as _task_model -from flytekit.models import task as task_models + +if TYPE_CHECKING: + from flytekit.models import task as task_models def _dnsify(value: str) -> str: @@ -60,7 +59,7 @@ def _get_container_definition( image: str, command: List[str], args: Optional[List[str]] = None, - data_loading_config: Optional[task_models.DataLoadingConfig] = None, + data_loading_config: Optional["task_models.DataLoadingConfig"] = None, storage_request: Optional[str] = None, ephemeral_storage_request: Optional[str] = None, cpu_request: Optional[str] = None, @@ -72,7 +71,7 @@ def _get_container_definition( gpu_limit: Optional[str] = None, memory_limit: Optional[str] = None, environment: Optional[Dict[str, str]] = None, -) -> task_models.Container: +) -> "task_models.Container": storage_limit = storage_limit storage_request = storage_request ephemeral_storage_limit = ephemeral_storage_limit @@ -84,6 +83,8 @@ def _get_container_definition( memory_limit = memory_limit memory_request = memory_request + from flytekit.models import task as task_models + # TODO: Use convert_resources_to_resource_model instead of manually fixing the resources. requests = [] if storage_request: @@ -133,12 +134,17 @@ def _get_container_definition( ) -def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: +def _sanitize_resource_name(resource: "task_models.Resources.ResourceEntry") -> str: return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") -def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_model.Container) -> Dict[str, Any]: - containers = cast(PodTemplate, pod_template).pod_spec.containers +def _serialize_pod_spec(pod_template: "PodTemplate", primary_container: "task_models.Container") -> Dict[str, Any]: + from kubernetes.client import ApiClient, V1PodSpec + from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements + + if pod_template.pod_spec is None: + return {} + containers = cast(V1PodSpec, pod_template.pod_spec).containers primary_exists = False for container in containers: @@ -173,7 +179,7 @@ def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_mode container.env or [] ) final_containers.append(container) - cast(PodTemplate, pod_template).pod_spec.containers = final_containers + cast(V1PodSpec, pod_template.pod_spec).containers = final_containers return ApiClient().sanitize_for_serialization(cast(PodTemplate, pod_template).pod_spec) diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index 7b49e98f4c..0d53ec18d6 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -2,8 +2,6 @@ import typing from typing import Optional -from jinja2 import Environment, FileSystemLoader, select_autoescape - from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.loggers import logger @@ -11,12 +9,6 @@ DECK_FILE_NAME = "deck.html" -try: - from IPython.core.display import HTML -except ImportError: - ... - - class Deck: """ Deck enable users to get customizable and default visibility into their tasks. @@ -156,8 +148,12 @@ def _get_deck( If ignore_jupyter is set to True, then it will return a str even in a jupyter environment. """ deck_map = {deck.name: deck.html for deck in new_user_params.decks} - raw_html = template.render(metadata=deck_map) + raw_html = get_deck_template().render(metadata=deck_map) if not ignore_jupyter and _ipython_check(): + try: + from IPython.core.display import HTML + except ImportError: + ... return HTML(raw_html) return raw_html @@ -174,15 +170,18 @@ def _output_deck(task_name: str, new_user_params: ExecutionParameters): logger.info(f"{task_name} task creates flyte deck html to file://{deck_path}") -root = os.path.dirname(os.path.abspath(__file__)) -templates_dir = os.path.join(root, "html") -env = Environment( - loader=FileSystemLoader(templates_dir), - # 🔥 include autoescaping for security purposes - # sources: - # - https://jinja.palletsprojects.com/en/3.0.x/api/#autoescaping - # - https://stackoverflow.com/a/38642558/8474894 (see in comments) - # - https://stackoverflow.com/a/68826578/8474894 - autoescape=select_autoescape(enabled_extensions=("html",)), -) -template = env.get_template("template.html") +def get_deck_template() -> "Template": + from jinja2 import Environment, FileSystemLoader, select_autoescape + + root = os.path.dirname(os.path.abspath(__file__)) + templates_dir = os.path.join(root, "html") + env = Environment( + loader=FileSystemLoader(templates_dir), + # 🔥 include autoescaping for security purposes + # sources: + # - https://jinja.palletsprojects.com/en/3.0.x/api/#autoescaping + # - https://stackoverflow.com/a/38642558/8474894 (see in comments) + # - https://stackoverflow.com/a/68826578/8474894 + autoescape=select_autoescape(enabled_extensions=("html",)), + ) + return env.get_template("template.html") diff --git a/flytekit/deck/renderer.py b/flytekit/deck/renderer.py index dddb88e420..cfea92ec4e 100644 --- a/flytekit/deck/renderer.py +++ b/flytekit/deck/renderer.py @@ -1,14 +1,22 @@ -from typing import Any +from typing import TYPE_CHECKING, Any -import pandas -import pyarrow from typing_extensions import Protocol, runtime_checkable +from flytekit import lazy_module + +if TYPE_CHECKING: + # Always import these modules in type-checking mode or when running pytest + import pandas + import pyarrow +else: + pandas = lazy_module("pandas") + pyarrow = lazy_module("pyarrow") + @runtime_checkable class Renderable(Protocol): def to_html(self, python_value: Any) -> str: - """Convert a object(markdown, pandas.dataframe) to HTML and return HTML as a unicode string. + """Convert an object(markdown, pandas.dataframe) to HTML and return HTML as a unicode string. Returns: An HTML document as a string. """ raise NotImplementedError @@ -27,16 +35,16 @@ def __init__(self, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX self._max_rows = max_rows self._max_cols = max_cols - def to_html(self, df: pandas.DataFrame) -> str: + def to_html(self, df: "pandas.DataFrame") -> str: assert isinstance(df, pandas.DataFrame) return df.to_html(max_rows=self._max_rows, max_cols=self._max_cols) class ArrowRenderer: """ - Render a Arrow dataframe as an HTML table. + Render an Arrow dataframe as an HTML table. """ - def to_html(self, df: pyarrow.Table) -> str: + def to_html(self, df: "pyarrow.Table") -> str: assert isinstance(df, pyarrow.Table) return df.to_string() diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 8e7d8b3b29..ef8013a5da 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -14,7 +14,6 @@ from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor from flytekit.models import task as task_models -from flytekit.types.schema import FlyteSchema def unarchive_file(local_path: str, to_dir: str): @@ -78,12 +77,14 @@ def __init__( query_template: str, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, task_config: typing.Optional[SQLite3Config] = None, - output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None, + output_schema_type: typing.Optional[typing.Type["FlyteSchema"]] = None, # type: ignore container_image: typing.Optional[str] = None, **kwargs, ): if task_config is None or task_config.uri is None: raise ValueError("SQLite DB uri is required.") + from flytekit.types.schema import FlyteSchema + outputs = kwtypes(results=output_schema_type if output_schema_type else FlyteSchema) super().__init__( name=name, diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 0e0ca6ea52..1cddb2a913 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -9,10 +9,8 @@ from typing import List, Optional import click -import docker import requests from dataclasses_json import dataclass_json -from docker.errors import APIError, ImageNotFound DOCKER_HUB = "docker.io" _F_IMG_ID = "_F_IMG_ID" @@ -69,6 +67,9 @@ def exist(self) -> bool: """ Check if the image exists in the registry. """ + import docker + from docker.errors import APIError, ImageNotFound + try: client = docker.from_env() if self.registry: diff --git a/flytekit/lazy_import/__init__.py b/flytekit/lazy_import/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/lazy_import/lazy_module.py b/flytekit/lazy_import/lazy_module.py new file mode 100644 index 0000000000..553386eb52 --- /dev/null +++ b/flytekit/lazy_import/lazy_module.py @@ -0,0 +1,33 @@ +import importlib.util +import sys + +LAZY_MODULES = [] + + +def is_imported(module_name): + """ + This function is used to check if a module has been imported by the regular import. + """ + return module_name in sys.modules and module_name not in LAZY_MODULES + + +def lazy_module(fullname): + """ + This function is used to lazily import modules. It is used in the following way: + .. code-block:: python + from flytekit.lazy_import import lazy_module + sklearn = lazy_module("sklearn") + sklearn.svm.SVC() + :param Text fullname: The full name of the module to import + """ + if fullname in sys.modules: + return sys.modules[fullname] + # https://docs.python.org/3/library/importlib.html#implementing-lazy-imports + spec = importlib.util.find_spec(fullname) + loader = importlib.util.LazyLoader(spec.loader) + spec.loader = loader + module = importlib.util.module_from_spec(spec) + sys.modules[fullname] = module + LAZY_MODULES.append(module) + loader.exec_module(module) + return module diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 52577a650d..86fa19f4f0 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -13,15 +13,9 @@ """ -from flytekit.configuration.internal import LocalSDK +from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer from flytekit.loggers import logger -from .basic_dfs import ( - ArrowToParquetEncodingHandler, - PandasToParquetEncodingHandler, - ParquetToArrowDecodingHandler, - ParquetToPandasDecodingHandler, -) from .structured_dataset import ( StructuredDataset, StructuredDatasetDecoder, @@ -29,15 +23,42 @@ StructuredDatasetTransformerEngine, ) -try: - from .bigquery import ( - ArrowToBQEncodingHandlers, - BQToArrowDecodingHandler, - BQToPandasDecodingHandler, - PandasToBQEncodingHandlers, - ) -except ImportError: - logger.info( - "We won't register bigquery handler for structured dataset because " - "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" - ) + +def register_pandas_handlers(): + import pandas as pd + + from .basic_dfs import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler + + StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) + + +def register_arrow_handlers(): + import pyarrow as pa + + from .basic_dfs import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler + + StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) + + +def register_bigquery_handlers(): + try: + from .bigquery import ( + ArrowToBQEncodingHandlers, + BQToArrowDecodingHandler, + BQToPandasDecodingHandler, + PandasToBQEncodingHandlers, + ) + + StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers()) + StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler()) + StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers()) + StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler()) + except ImportError: + logger.info( + "We won't register bigquery handler for structured dataset because " + "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" + ) diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index c8f4ef3baa..8004867271 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -13,8 +13,6 @@ from flytekit import FlyteContext, logger from flytekit.configuration import DataConfig from flytekit.core.data_persistence import s3_setup_args -from flytekit.deck import TopFrameRenderer -from flytekit.deck.renderer import ArrowRenderer from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType @@ -23,7 +21,6 @@ StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, - StructuredDatasetTransformerEngine, ) T = TypeVar("T") @@ -132,12 +129,3 @@ def decode( if fs is not None: return pq.read_table(path, filesystem=fs, columns=columns) raise e - - -StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True) -StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True) -StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True) -StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True) - -StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) -StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) diff --git a/flytekit/types/structured/bigquery.py b/flytekit/types/structured/bigquery.py index 85cede1544..049a21c07e 100644 --- a/flytekit/types/structured/bigquery.py +++ b/flytekit/types/structured/bigquery.py @@ -14,7 +14,6 @@ StructuredDatasetDecoder, StructuredDatasetEncoder, StructuredDatasetMetadata, - StructuredDatasetTransformerEngine, ) BIGQUERY = "bq" @@ -110,9 +109,3 @@ def decode( current_task_metadata: StructuredDatasetMetadata, ) -> pa.Table: return pa.Table.from_pandas(_read_from_bq(flyte_value, current_task_metadata)) - - -StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers()) -StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler()) -StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers()) -StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 9b4951e084..05df91776c 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -8,14 +8,12 @@ from typing import Dict, Generator, Optional, Type, Union import _datetime -import numpy as _np -import pandas as pd -import pyarrow as pa from dataclasses_json import config, dataclass_json from fsspec.utils import get_protocol from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin +from flytekit import lazy_module from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.deck.renderer import Renderable @@ -25,6 +23,13 @@ from flytekit.models.literals import Literal, Scalar, StructuredDatasetMetadata from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType +if typing.TYPE_CHECKING: + import pandas as pd + import pyarrow as pa +else: + pd = lazy_module("pandas") + pa = lazy_module("pyarrow") + T = typing.TypeVar("T") # StructuredDataset type or a dataframe type DF = typing.TypeVar("DF") # Dataframe type @@ -110,7 +115,7 @@ def iter(self) -> Generator[DF, None, None]: def extract_cols_and_format( t: typing.Any, -) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional[pa.lib.Schema]]: +) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional["pa.lib.Schema"]]: """ Helper function, just used to iterate through Annotations and extract out the following information: - base type, if not Annotated, it will just be the type that was passed in. @@ -144,7 +149,7 @@ def extract_cols_and_format( if ordered_dict_cols is not None: raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}") ordered_dict_cols = aa - elif isinstance(aa, pa.Schema): + elif isinstance(aa, pa.lib.Schema): if pa_schema is not None: raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}") pa_schema = aa @@ -291,16 +296,8 @@ def convert_schema_type_to_structured_dataset_type( raise AssertionError(f"Unrecognized SchemaColumnType: {column_type}") -class DuplicateHandlerError(ValueError): - ... - - -class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): - """ - Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. - If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of - registering with the main type engine, you should register with this transformer instead. - """ +def get_supported_types(): + import numpy as _np _SUPPORTED_TYPES: typing.Dict[Type, LiteralType] = { _np.int32: type_models.LiteralType(simple=type_models.SimpleType.INTEGER), @@ -322,6 +319,19 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): _np.object_: type_models.LiteralType(simple=type_models.SimpleType.STRING), str: type_models.LiteralType(simple=type_models.SimpleType.STRING), } + return _SUPPORTED_TYPES + + +class DuplicateHandlerError(ValueError): + ... + + +class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): + """ + Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. + If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of + registering with the main type engine, you should register with this transformer instead. + """ ENCODERS: Dict[Type, Dict[str, Dict[str, StructuredDatasetEncoder]]] = {} DECODERS: Dict[Type, Dict[str, Dict[str, StructuredDatasetDecoder]]] = {} @@ -552,7 +562,7 @@ def to_literal( ) return Literal(scalar=Scalar(structured_dataset=python_val._literal_sd)) - # 2. A task returns a python StructuredDataset with a uri. + # 2. A task returns a python StructuredDataset with an uri. # Note: this case is also what happens we start a local execution of a task with a python StructuredDataset. # It gets converted into a literal first, then back into a python StructuredDataset. # @@ -798,8 +808,8 @@ def iter_as( return result def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: - if t in self._SUPPORTED_TYPES: - return self._SUPPORTED_TYPES[t] + if t in get_supported_types(): + return get_supported_types()[t] if hasattr(t, "__origin__") and t.__origin__ == list: return type_models.LiteralType(collection_type=self._get_dataset_column_literal_type(t.__args__[0])) if hasattr(t, "__origin__") and t.__origin__ == dict: diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index 1d4a7f0dbd..67ff323e4f 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -5,10 +5,10 @@ from google.protobuf import json_format from google.protobuf.struct_pb2 import Struct -from flytekit import StructuredDataset from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask from flytekit.models import task as _task_model +from flytekit.types.structured import StructuredDataset @dataclass diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py index 30f279b1a9..c090ea6a46 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py @@ -1,15 +1,19 @@ -import base64 -from io import BytesIO -from typing import List, Optional, Union - -import markdown -import pandas as pd -import plotly.express as px -from PIL import Image -from ydata_profiling import ProfileReport +from typing import TYPE_CHECKING, List, Optional, Union +from flytekit import lazy_module from flytekit.types.file import FlyteFile +if TYPE_CHECKING: + import markdown + import pandas as pd + import PIL + import plotly.express as px +else: + pd = lazy_module("pandas") + markdown = lazy_module("markdown") + px = lazy_module("plotly.express") + PIL = lazy_module("PIL") + class FrameProfilingRenderer: """ @@ -19,9 +23,11 @@ class FrameProfilingRenderer: def __init__(self, title: str = "Pandas Profiling Report"): self._title = title - def to_html(self, df: pd.DataFrame) -> str: + def to_html(self, df: "pd.DataFrame") -> str: assert isinstance(df, pd.DataFrame) - profile = ProfileReport(df, title=self._title) + import ydata_profiling + + profile = ydata_profiling.ProfileReport(df, title=self._title) return profile.to_html() @@ -44,7 +50,7 @@ class BoxRenderer: Each box spans from quartile 1 (Q1) to quartile 3 (Q3). The second quartile (Q2) is marked by a line inside the box. By default, the - whiskers correspond to the box' edges +/- 1.5 times the interquartile + whiskers correspond to the box edges +/- 1.5 times the interquartile range (IQR: Q3-Q1), see "points" for other options. """ @@ -52,7 +58,7 @@ class BoxRenderer: def __init__(self, column_name): self._column_name = column_name - def to_html(self, df: pd.DataFrame) -> str: + def to_html(self, df: "pd.DataFrame") -> str: fig = px.box(df, y=self._column_name) return fig.to_html() @@ -62,22 +68,25 @@ class ImageRenderer: represented as a base64-encoded string. """ - def to_html(cls, image_src: Union[FlyteFile, Image.Image]) -> str: - img = cls._get_image_object(image_src) - return cls._image_to_html_string(img) + def to_html(self, image_src: Union[FlyteFile, "PIL.Image.Image"]) -> str: + img = self._get_image_object(image_src) + return self._image_to_html_string(img) @staticmethod - def _get_image_object(image_src: Union[FlyteFile, Image.Image]) -> Image.Image: + def _get_image_object(image_src: Union[FlyteFile, "PIL.Image.Image"]) -> "PIL.Image.Image": if isinstance(image_src, FlyteFile): local_path = image_src.download() - return Image.open(local_path) - elif isinstance(image_src, Image.Image): + return PIL.Image.open(local_path) + elif isinstance(image_src, PIL.Image.Image): return image_src else: raise ValueError("Unsupported image source type") @staticmethod - def _image_to_html_string(img: Image.Image) -> str: + def _image_to_html_string(img: "PIL.Image.Image") -> str: + import base64 + from io import BytesIO + buffered = BytesIO() img.save(buffered, format="PNG") img_base64 = base64.b64encode(buffered.getvalue()).decode() diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py index b58aa4a120..9fa897f90e 100644 --- a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py @@ -13,7 +13,7 @@ from flytekit import FlyteContextManager from flytekit.bin.entrypoint import get_one_of from flytekit.core.context_manager import ExecutionState -from flytekit.deck import TopFrameRenderer +from flytekit.deck.renderer import TopFrameRenderer def metric_to_df(metrics: typing.List[Metric]) -> pd.DataFrame: diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index b8d32e4c8d..da708a8571 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -415,7 +415,7 @@ def t1(a: str) -> str: @workflow def my_wf(a: str) -> str: - return t1(a=a).with_overrides(name="foo") + return t1(a=a).with_overrides(name="foo", node_name="t_1") serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", @@ -427,6 +427,7 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.name == "foo" + assert wf_spec.template.nodes[0].id == "t-1" def test_config_override():