Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy load modules #1590

Merged
merged 31 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,14 @@
from flytekit.core.task import Secret, reference_task, task
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras import pytorch, sklearn, tensorflow
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
from flytekit.models.core.execution import WorkflowExecutionPhase
from flytekit.models.core.types import BlobType
from flytekit.models.documentation import Description, Documentation, SourceCode
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types import directory, file, numpy, schema
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetFormat,
StructuredDatasetTransformerEngine,
StructuredDatasetType,
)
from flytekit.types import directory, file

__version__ = "0.0.0+develop"

Expand Down
3 changes: 2 additions & 1 deletion flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@
from typing import Dict, List, Optional

from dataclasses_json import dataclass_json
from docker_image import reference

from flytekit.configuration import internal as _internal
from flytekit.configuration.default_images import DefaultImages
Expand Down Expand Up @@ -205,6 +204,8 @@ def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image
:param Text tag: e.g. somedocker.com/myimage:someversion123
:rtype: Text
"""
from docker_image import reference

ref = reference.Reference.parse(tag)
if not optional_tag and ref["tag"] is None:
raise AssertionError(f"Incorrectly formatted image {tag}, missing tag value")
Expand Down
3 changes: 2 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
)
from flytekit.core.tracker import TrackedInstance
from flytekit.core.type_engine import TypeEngine
from flytekit.deck.deck import Deck
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import interface as _interface_models
Expand Down Expand Up @@ -574,6 +573,8 @@ def dispatch_execute(
) from e

if self._disable_deck is False:
from flytekit.deck.deck import Deck

INPUT = "input"
OUTPUT = "output"

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
metadata_format: MetadataFormat = MetadataFormat.JSON,
io_strategy: Optional[IOStrategy] = None,
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
**kwargs,
):
Expand Down
7 changes: 6 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from enum import Enum
from typing import Generator, List, Optional, Union

from flytekit.clients import friendly as friendly_client # noqa
import lazy_import

from flytekit.configuration import Config, SecretsConfig, SerializationSettings
from flytekit.core import mock_stats, utils
from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint
Expand All @@ -39,7 +40,11 @@
from flytekit.models.core import identifier as _identifier

if typing.TYPE_CHECKING:
from flytekit.clients import friendly as friendly_client
from flytekit.deck.deck import Deck
else:
friendly_client = lazy_import.lazy_module("flytekit.clients.friendly")


# TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin

Expand Down
4 changes: 3 additions & 1 deletion flytekit/core/local_cache.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional

import joblib
import lazy_import
from diskcache import Cache

from flytekit.models.literals import Literal, LiteralCollection, LiteralMap

joblib = lazy_import.lazy_module("joblib")

# Location on the filesystem where serialized objects will be stored
# TODO: read from config
CACHE_LOCATION = "~/.flyte/local-cache"
Expand Down
17 changes: 11 additions & 6 deletions flytekit/core/pod_template.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from dataclasses import dataclass, field
from typing import Dict, Optional

from kubernetes.client.models import V1PodSpec
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional

from flytekit.exceptions import user as _user_exceptions

if TYPE_CHECKING:
from kubernetes.client import V1PodSpec

PRIMARY_CONTAINER_DEFAULT_NAME = "primary"


@dataclass(init=True, repr=True, eq=True, frozen=True)
@dataclass(init=True, repr=True, eq=True, frozen=False)
class PodTemplate(object):
"""Custom PodTemplate specification for a Task."""

pod_spec: V1PodSpec = field(default_factory=lambda: V1PodSpec(containers=[]))
pod_spec: "V1PodSpec" = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs optional now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME
labels: Optional[Dict[str, str]] = None
annotations: Optional[Dict[str, str]] = None

def __post_init__(self):
if self.pod_spec is None:
from kubernetes.client import V1PodSpec

self.pod_spec = V1PodSpec(containers=[])
if not self.primary_container_name:
raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined")
2 changes: 1 addition & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def task(
task_resolver: Optional[TaskResolverMixin] = None,
docs: Optional[Documentation] = None,
disable_deck: bool = True,
pod_template: Optional[PodTemplate] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
) -> Union[Callable, PythonFunctionTask]:
"""
Expand Down
32 changes: 30 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.annotation import FlyteAnnotation
Expand Down Expand Up @@ -326,6 +325,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
from marshmallow_jsonschema import JSONSchema

schema = JSONSchema().dump(s)
except Exception as e:
# https://github.com/lovasoa/marshmallow_dataclass/issues/13
Expand Down Expand Up @@ -373,7 +374,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T:
# In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict,
# so here we convert it back to the Structured Dataset.
from flytekit import StructuredDataset
from flytekit.types.structured import StructuredDataset

if python_type == StructuredDataset and type(python_val) == dict:
return StructuredDataset(**python_val)
Expand Down Expand Up @@ -768,11 +769,38 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:

raise ValueError(f"Type {python_type} not supported currently in Flytekit. Please register a new transformer")

@classmethod
def lazy_import_transformers(cls, python_type: Type):
"""
Only load the transformers if needed.
"""
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
if not hasattr(python_type, "__name__"):
return
name = python_type.__name__
if name == "tensorflow":
from flytekit.extras import tensorflow # noqa: F401
elif name == "torch":
from flytekit.extras import pytorch # noqa: F401
elif name == "sklearn":
from flytekit.extras import sklearn # noqa: F401
elif name in ["pandas", "pyarrow"]:
from flytekit.types.structured.structured_dataset import ( # noqa: F401
StructuredDataset,
StructuredDatasetFormat,
StructuredDatasetTransformerEngine,
StructuredDatasetType,
)
elif name == "numpy":
from flytekit.types import numpy # noqa: F401

@classmethod
def to_literal_type(cls, python_type: Type) -> LiteralType:
"""
Converts a python type into a flyte specific ``LiteralType``
"""
cls.lazy_import_transformers(python_type)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a comment for why this is in to literal type but not in to python value and to literal?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to get_transformer

transformer = cls.get_transformer(python_type)
res = transformer.get_literal_type(python_type)
data = None
Expand Down
19 changes: 12 additions & 7 deletions flytekit/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
import time as _time
from hashlib import sha224 as _sha224
from pathlib import Path
from typing import Any, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

import lazy_import
from flyteidl.core import tasks_pb2 as _core_task
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

from flytekit.core.pod_template import PodTemplate
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models import task as task_models

if TYPE_CHECKING:
from flytekit.models import task as task_models
else:
task_models = lazy_import.lazy_module("flytekit.models.task")


def _dnsify(value: str) -> str:
Expand Down Expand Up @@ -131,11 +133,14 @@ def _get_container_definition(
)


def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
def _sanitize_resource_name(resource: task_models.Resources.ResourceEntry) -> str:
return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")


def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_model.Container) -> Dict[str, Any]:
def _serialize_pod_spec(pod_template: "PodTemplate", primary_container: task_models.Container) -> Dict[str, Any]:
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

containers = cast(PodTemplate, pod_template).pod_spec.containers
primary_exists = False

Expand Down
40 changes: 20 additions & 20 deletions flytekit/deck/deck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,12 @@
import typing
from typing import Optional

from jinja2 import Environment, FileSystemLoader, select_autoescape

from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager
from flytekit.loggers import logger

OUTPUT_DIR_JUPYTER_PREFIX = "jupyter"
DECK_FILE_NAME = "deck.html"

try:
from IPython.core.display import HTML
except ImportError:
...


class Deck:
"""
Expand Down Expand Up @@ -103,8 +96,12 @@ def _get_deck(
If ignore_jupyter is set to True, then it will return a str even in a jupyter environment.
"""
deck_map = {deck.name: deck.html for deck in new_user_params.decks}
raw_html = template.render(metadata=deck_map)
raw_html = get_deck_template().render(metadata=deck_map)
if not ignore_jupyter and _ipython_check():
try:
from IPython.core.display import HTML
except ImportError:
...
return HTML(raw_html)
return raw_html

Expand All @@ -121,15 +118,18 @@ def _output_deck(task_name: str, new_user_params: ExecutionParameters):
logger.info(f"{task_name} task creates flyte deck html to file://{deck_path}")


root = os.path.dirname(os.path.abspath(__file__))
templates_dir = os.path.join(root, "html")
env = Environment(
loader=FileSystemLoader(templates_dir),
# 🔥 include autoescaping for security purposes
# sources:
# - https://jinja.palletsprojects.com/en/3.0.x/api/#autoescaping
# - https://stackoverflow.com/a/38642558/8474894 (see in comments)
# - https://stackoverflow.com/a/68826578/8474894
autoescape=select_autoescape(enabled_extensions=("html",)),
)
template = env.get_template("template.html")
def get_deck_template() -> "Template":
from jinja2 import Environment, FileSystemLoader, select_autoescape

root = os.path.dirname(os.path.abspath(__file__))
templates_dir = os.path.join(root, "html")
env = Environment(
loader=FileSystemLoader(templates_dir),
# 🔥 include autoescaping for security purposes
# sources:
# - https://jinja.palletsprojects.com/en/3.0.x/api/#autoescaping
# - https://stackoverflow.com/a/38642558/8474894 (see in comments)
# - https://stackoverflow.com/a/68826578/8474894
autoescape=select_autoescape(enabled_extensions=("html",)),
)
return env.get_template("template.html")
10 changes: 6 additions & 4 deletions flytekit/deck/renderer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any

import pandas
import pyarrow
import lazy_import
from typing_extensions import Protocol, runtime_checkable

pandas = lazy_import.lazy_module("pandas")
pyarrow = lazy_import.lazy_module("pyarrow")


@runtime_checkable
class Renderable(Protocol):
Expand All @@ -27,7 +29,7 @@ def __init__(self, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX
self._max_rows = max_rows
self._max_cols = max_cols

def to_html(self, df: pandas.DataFrame) -> str:
def to_html(self, df: "pandas.DataFrame") -> str:
assert isinstance(df, pandas.DataFrame)
return df.to_html(max_rows=self._max_rows, max_cols=self._max_cols)

Expand All @@ -37,6 +39,6 @@ class ArrowRenderer:
Render a Arrow dataframe as an HTML table.
"""

def to_html(self, df: pyarrow.Table) -> str:
def to_html(self, df: "pyarrow.Table") -> str:
assert isinstance(df, pyarrow.Table)
return df.to_string()
3 changes: 1 addition & 2 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from flytekit import FlyteContext, logger
from flytekit.configuration import DataConfig
from flytekit.core.data_persistence import s3_setup_args
from flytekit.deck import TopFrameRenderer
from flytekit.deck.renderer import ArrowRenderer
from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit import StructuredDataset
from flytekit.configuration import SerializationSettings
from flytekit.extend import SQLTask
from flytekit.models import task as _task_model
from flytekit.types.structured import StructuredDataset


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion plugins/flytekit-bigquery/tests/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit import StructuredDataset, kwtypes, workflow
from flytekit import kwtypes, workflow
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.extend import get_serializable
from flytekit.models.literals import StructuredDataset

query_template = "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions` WHERE @version = 1 LIMIT 10"

Expand Down
Loading