Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy load modules #1590

Merged
merged 31 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@

from rich import traceback

from flytekit.lazy_import.lazy_module import lazy_module

if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
Expand Down Expand Up @@ -224,7 +226,6 @@
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras import pytorch, sklearn, tensorflow
from flytekit.image_spec import ImageSpec
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
Expand All @@ -233,7 +234,7 @@
from flytekit.models.documentation import Description, Documentation, SourceCode
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types import directory, file, numpy, schema
from flytekit.types import directory, file
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetFormat,
Expand Down
3 changes: 2 additions & 1 deletion flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@

import yaml
from dataclasses_json import dataclass_json
from docker_image import reference

from flytekit.configuration import internal as _internal
from flytekit.configuration.default_images import DefaultImages
Expand Down Expand Up @@ -208,6 +207,8 @@ def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image
:param Text tag: e.g. somedocker.com/myimage:someversion123
:rtype: Text
"""
from docker_image import reference

if pathlib.Path(tag).is_file():
with open(tag, "r") as f:
image_spec_dict = yaml.safe_load(f)
Expand Down
3 changes: 2 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from flytekit.core.tracker import TrackedInstance
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError
from flytekit.core.utils import timeit
from flytekit.deck.deck import Deck
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import interface as _interface_models
Expand Down Expand Up @@ -590,6 +589,8 @@ def dispatch_execute(
raise TypeError(msg) from e

if self._disable_deck is False:
from flytekit.deck.deck import Deck

INPUT = "input"
OUTPUT = "output"

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
metadata_format: MetadataFormat = MetadataFormat.JSON,
io_strategy: Optional[IOStrategy] = None,
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
**kwargs,
):
Expand Down
10 changes: 5 additions & 5 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from enum import Enum
from typing import Generator, List, Optional, Union

from flytekit.clients import friendly as friendly_client # noqa
from flytekit.configuration import Config, SecretsConfig, SerializationSettings
from flytekit.core import mock_stats, utils
from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint
Expand All @@ -39,7 +38,8 @@
from flytekit.models.core import identifier as _identifier

if typing.TYPE_CHECKING:
from flytekit.deck.deck import Deck
from flytekit import Deck
from flytekit.clients import friendly as friendly_client # noqa

# TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin

Expand Down Expand Up @@ -262,7 +262,7 @@ def decks(self) -> typing.List:

@property
def default_deck(self) -> Deck:
from flytekit.deck.deck import Deck
from flytekit import Deck

return Deck("default")

Expand Down Expand Up @@ -551,7 +551,7 @@ class FlyteContext(object):

file_access: FileAccessProvider
level: int = 0
flyte_client: Optional[friendly_client.SynchronousFlyteClient] = None
flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None
compilation_state: Optional[CompilationState] = None
execution_state: Optional[ExecutionState] = None
serialization_settings: Optional[SerializationSettings] = None
Expand Down Expand Up @@ -660,7 +660,7 @@ class Builder(object):
level: int = 0
compilation_state: Optional[CompilationState] = None
execution_state: Optional[ExecutionState] = None
flyte_client: Optional[friendly_client.SynchronousFlyteClient] = None
flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None
serialization_settings: Optional[SerializationSettings] = None
in_a_condition: bool = False

Expand Down
4 changes: 3 additions & 1 deletion flytekit/core/local_cache.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional

import joblib
from diskcache import Cache

from flytekit import lazy_module
from flytekit.models.literals import Literal, LiteralCollection, LiteralMap

joblib = lazy_module("joblib")

# Location on the filesystem where serialized objects will be stored
# TODO: read from config
CACHE_LOCATION = "~/.flyte/local-cache"
Expand Down
17 changes: 11 additions & 6 deletions flytekit/core/pod_template.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from dataclasses import dataclass, field
from typing import Dict, Optional

from kubernetes.client.models import V1PodSpec
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional

from flytekit.exceptions import user as _user_exceptions

if TYPE_CHECKING:
from kubernetes.client import V1PodSpec

PRIMARY_CONTAINER_DEFAULT_NAME = "primary"


@dataclass(init=True, repr=True, eq=True, frozen=True)
@dataclass(init=True, repr=True, eq=True, frozen=False)
class PodTemplate(object):
"""Custom PodTemplate specification for a Task."""

pod_spec: V1PodSpec = field(default_factory=lambda: V1PodSpec(containers=[]))
pod_spec: Optional["V1PodSpec"] = None
primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME
labels: Optional[Dict[str, str]] = None
annotations: Optional[Dict[str, str]] = None

def __post_init__(self):
if self.pod_spec is None:
from kubernetes.client import V1PodSpec

self.pod_spec = V1PodSpec(containers=[])
if not self.primary_container_name:
raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined")
2 changes: 1 addition & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def task(
task_resolver: Optional[TaskResolverMixin] = None,
docs: Optional[Documentation] = None,
disable_deck: bool = True,
pod_template: Optional[PodTemplate] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
) -> Union[Callable, PythonFunctionTask]:
"""
Expand Down
42 changes: 39 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.annotation import FlyteAnnotation
Expand All @@ -31,6 +30,7 @@
from flytekit.core.type_helpers import load_type_from_tag
from flytekit.core.utils import timeit
from flytekit.exceptions import user as user_exceptions
from flytekit.lazy_import.lazy_module import is_imported
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import types as _type_models
Expand Down Expand Up @@ -329,6 +329,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
from marshmallow_jsonschema import JSONSchema

schema = JSONSchema().dump(s)
except Exception as e:
# https://github.com/lovasoa/marshmallow_dataclass/issues/13
Expand Down Expand Up @@ -376,7 +378,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T:
# In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict,
# so here we convert it back to the Structured Dataset.
from flytekit import StructuredDataset
from flytekit.types.structured import StructuredDataset

if python_type == StructuredDataset and type(python_val) == dict:
return StructuredDataset(**python_val)
Expand Down Expand Up @@ -672,6 +674,7 @@ class TypeEngine(typing.Generic[T]):
_REGISTRY: typing.Dict[type, TypeTransformer[T]] = {}
_RESTRICTED_TYPES: typing.List[type] = []
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
has_lazy_import = False

@classmethod
def register(
Expand Down Expand Up @@ -729,7 +732,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
Step 4:
if v is of type data class, use the dataclass transformer
"""

cls.lazy_import_transformers()
# Step 1
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
Expand Down Expand Up @@ -771,6 +774,39 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:

raise ValueError(f"Type {python_type} not supported currently in Flytekit. Please register a new transformer")

@classmethod
def lazy_import_transformers(cls):
"""
Only load the transformers if needed.
"""
if cls.has_lazy_import:
return
cls.has_lazy_import = True
from flytekit.types.structured import (
register_arrow_handlers,
register_bigquery_handlers,
register_pandas_handlers,
)

if is_imported("tensorflow"):
from flytekit.extras import tensorflow # noqa: F401
if is_imported("torch"):
from flytekit.extras import pytorch # noqa: F401
if is_imported("sklearn"):
from flytekit.extras import sklearn # noqa: F401
if is_imported("pandas"):
try:
from flytekit.types import schema # noqa: F401
except ValueError:
logger.debug("Transformer for pandas is already registered.")
register_pandas_handlers()
if is_imported("pyarrow"):
register_arrow_handlers()
if is_imported("google.cloud.bigquery"):
register_bigquery_handlers()
if is_imported("numpy"):
from flytekit.types import numpy # noqa: F401

@classmethod
def to_literal_type(cls, python_type: Type) -> LiteralType:
"""
Expand Down
28 changes: 17 additions & 11 deletions flytekit/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@
from functools import wraps
from hashlib import sha224 as _sha224
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast

from flyteidl.core import tasks_pb2 as _core_task
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

from flytekit.core.pod_template import PodTemplate
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models import task as task_models

if TYPE_CHECKING:
from flytekit.models import task as task_models


def _dnsify(value: str) -> str:
Expand Down Expand Up @@ -60,7 +59,7 @@ def _get_container_definition(
image: str,
command: List[str],
args: Optional[List[str]] = None,
data_loading_config: Optional[task_models.DataLoadingConfig] = None,
data_loading_config: Optional["task_models.DataLoadingConfig"] = None,
storage_request: Optional[str] = None,
ephemeral_storage_request: Optional[str] = None,
cpu_request: Optional[str] = None,
Expand All @@ -72,7 +71,7 @@ def _get_container_definition(
gpu_limit: Optional[str] = None,
memory_limit: Optional[str] = None,
environment: Optional[Dict[str, str]] = None,
) -> task_models.Container:
) -> "task_models.Container":
storage_limit = storage_limit
storage_request = storage_request
ephemeral_storage_limit = ephemeral_storage_limit
Expand All @@ -84,6 +83,8 @@ def _get_container_definition(
memory_limit = memory_limit
memory_request = memory_request

from flytekit.models import task as task_models

# TODO: Use convert_resources_to_resource_model instead of manually fixing the resources.
requests = []
if storage_request:
Expand Down Expand Up @@ -133,12 +134,17 @@ def _get_container_definition(
)


def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
def _sanitize_resource_name(resource: "task_models.Resources.ResourceEntry") -> str:
return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")


def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_model.Container) -> Dict[str, Any]:
containers = cast(PodTemplate, pod_template).pod_spec.containers
def _serialize_pod_spec(pod_template: "PodTemplate", primary_container: "task_models.Container") -> Dict[str, Any]:
from kubernetes.client import ApiClient, V1PodSpec
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

if pod_template.pod_spec is None:
return {}
containers = cast(V1PodSpec, pod_template.pod_spec).containers
primary_exists = False

for container in containers:
Expand Down Expand Up @@ -173,7 +179,7 @@ def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_mode
container.env or []
)
final_containers.append(container)
cast(PodTemplate, pod_template).pod_spec.containers = final_containers
cast(V1PodSpec, pod_template.pod_spec).containers = final_containers

return ApiClient().sanitize_for_serialization(cast(PodTemplate, pod_template).pod_spec)

Expand Down
Loading