Plugins lazy module loading (#2049)
* all flytekit plugins use lazy module loading

Signed-off-by: Niels Bantilan <[email protected]>

* fix test issues, lint

Signed-off-by: Niels Bantilan <[email protected]>

* add --show-fixes arg to ruff

Signed-off-by: Niels Bantilan <[email protected]>

* debug linter

Signed-off-by: Niels Bantilan <[email protected]>

* fix modin test

Signed-off-by: Niels Bantilan <[email protected]>

* run ci

Signed-off-by: Niels Bantilan <[email protected]>

* run ci

Signed-off-by: Niels Bantilan <[email protected]>

---------

Signed-off-by: Niels Bantilan <[email protected]>
cosmicBboy authored Dec 18, 2023
1 parent 628ac89 commit 8fd3dfa
Showing 34 changed files with 210 additions and 179 deletions.
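For context on the diffs below: the core of this change is flytekit's `lazy_module` helper, which binds a module name to a proxy object and defers the actual import until the first attribute access. The following is a minimal sketch of how such a helper can be built on the standard library's `importlib.util.LazyLoader`; the exact flytekit implementation may differ (for example, in how it handles modules that are not installed).

```python
import importlib.util
import sys
import types


def lazy_module(fullname: str) -> types.ModuleType:
    """Return `fullname` as a module whose real import runs on first attribute access."""
    if fullname in sys.modules:
        # Already imported (lazily or otherwise): reuse it.
        return sys.modules[fullname]
    spec = importlib.util.find_spec(fullname)
    if spec is None:
        raise ImportError(f"module {fullname!r} not found")
    # LazyLoader wraps the real loader and postpones executing the module
    # body until one of its attributes is first touched.
    loader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    sys.modules[fullname] = module
    loader.exec_module(module)  # registers the lazy module; no module code runs yet
    return module
```

This mirrors the lazy-import recipe in the `importlib` documentation, so a plugin only pays for `import airflow` or `import great_expectations` when a code path actually uses those libraries.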
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -5,7 +5,7 @@ repos:
     hooks:
       # Run the linter.
       - id: ruff
-        args: [--fix]
+        args: [--fix, --show-fixes, --show-source]
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/pre-commit/pre-commit-hooks
38 changes: 23 additions & 15 deletions plugins/flytekit-airflow/flytekitplugins/airflow/task.py
@@ -6,13 +6,7 @@
 
 import jsonpickle
 
-from airflow import DAG
-from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
-from airflow.sensors.base import BaseSensorOperator
-from airflow.triggers.base import BaseTrigger
-from airflow.utils.context import Context
-from flytekit import FlyteContextManager, logger
+from flytekit import FlyteContextManager, lazy_module, logger
 from flytekit.configuration import SerializationSettings
 from flytekit.core.base_task import PythonTask, TaskResolverMixin
 from flytekit.core.interface import Interface
@@ -21,6 +15,12 @@
 from flytekit.core.utils import timeit
 from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
 
+airflow = lazy_module("airflow")
+airflow_models = lazy_module("airflow.models")
+airflow_sensors = lazy_module("airflow.sensors.base")
+airflow_triggers = lazy_module("airflow.triggers.base")
+airflow_context = lazy_module("airflow.utils.context")
+
 
 @dataclass
 class AirflowObj(object):
@@ -51,7 +51,7 @@ def name(self) -> str:
         return "AirflowTaskResolver"
 
     @timeit("Load airflow task")
-    def load_task(self, loader_args: typing.List[str]) -> typing.Union[BaseOperator, BaseSensorOperator, BaseTrigger]:
+    def load_task(
+        self, loader_args: typing.List[str]
+    ) -> typing.Union[airflow_models.BaseOperator, airflow_sensors.BaseSensorOperator, airflow_triggers.BaseTrigger]:
         """
         This method is used to load an Airflow task.
         """
@@ -103,7 +105,7 @@ def __init__(
 
     def execute(self, **kwargs) -> Any:
         logger.info("Executing Airflow task")
-        _get_airflow_instance(self.task_config).execute(context=Context())
+        _get_airflow_instance(self.task_config).execute(context=airflow_context.Context())
 
 
 class AirflowTask(AsyncAgentExecutorMixin, PythonTask[AirflowObj]):
@@ -134,18 +136,24 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
         return {"task_config_pkl": jsonpickle.encode(self.task_config)}
 
 
-def _get_airflow_instance(airflow_obj: AirflowObj) -> typing.Union[BaseOperator, BaseSensorOperator, BaseTrigger]:
+def _get_airflow_instance(
+    airflow_obj: AirflowObj
+) -> typing.Union[airflow_models.BaseOperator, airflow_sensors.BaseSensorOperator, airflow_triggers.BaseTrigger]:
     # Set the GET_ORIGINAL_TASK attribute to True so that obj_def will return the original
     # airflow task instead of the Flyte task.
     ctx = FlyteContextManager.current_context()
    ctx.user_space_params.builder().add_attr("GET_ORIGINAL_TASK", True).build()
 
     obj_module = importlib.import_module(name=airflow_obj.module)
     obj_def = getattr(obj_module, airflow_obj.name)
-    if issubclass(obj_def, BaseOperator) and not issubclass(obj_def, BaseSensorOperator) and _is_deferrable(obj_def):
+    if (
+        issubclass(obj_def, airflow_models.BaseOperator)
+        and not issubclass(obj_def, airflow_sensors.BaseSensorOperator)
+        and _is_deferrable(obj_def)
+    ):
         try:
             return obj_def(**airflow_obj.parameters, deferrable=True)
-        except AirflowException as e:
+        except airflow.exceptions.AirflowException as e:
             logger.debug(f"Failed to create operator {airflow_obj.name} with err: {e}.")
     logger.debug(f"Airflow operator {airflow_obj.name} does not support deferring.")
 
@@ -212,9 +220,9 @@ def _flyte_xcom_push(*args, **kwargs):
     params.builder().add_attr("GET_ORIGINAL_TASK", False).add_attr("XCOM_DATA", {}).build()
 
 # Monkey patch the Airflow operator. Instead of creating an airflow task, it returns a Flyte task.
-BaseOperator.__new__ = _flyte_operator
-BaseOperator.xcom_push = _flyte_xcom_push
+airflow_models.BaseOperator.__new__ = _flyte_operator
+airflow_models.BaseOperator.xcom_push = _flyte_xcom_push
 # Monkey patch the xcom_push method to store the data in the Flyte context.
 # Create a dummy DAG to avoid Airflow errors. This DAG is not used.
 # TODO: Add support using Airflow DAG in Flyte workflow. We can probably convert the Airflow DAG to a Flyte subworkflow.
-BaseSensorOperator.dag = DAG(dag_id="flyte_dag")
+airflow_sensors.BaseSensorOperator.dag = airflow.DAG(dag_id="flyte_dag")
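With the module-level imports replaced as above, importing `flytekitplugins.airflow.task` no longer pays Airflow's substantial import cost up front. A small illustration of the intended semantics, assuming the `LazyLoader`-style behavior sketched earlier:

```python
from flytekit import lazy_module

airflow = lazy_module("airflow")  # cheap: binds a lazy proxy; airflow is not imported yet


def make_demo_dag():
    # The first attribute access triggers the real import, here and only here.
    return airflow.DAG(dag_id="demo")
```

Note that module-level statements that touch attributes, such as the monkey-patching of `BaseOperator.__new__` at the bottom of this file, would still trigger the real import when the plugin module loads under the semantics sketched above.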
4 changes: 3 additions & 1 deletion plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py
@@ -1,16 +1,18 @@
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Type
 
-from google.cloud import bigquery
 from google.protobuf import json_format
 from google.protobuf.struct_pb2 import Struct
 
+from flytekit import lazy_module
 from flytekit.configuration import SerializationSettings
 from flytekit.extend import SQLTask
 from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
 from flytekit.models import task as _task_model
 from flytekit.types.structured import StructuredDataset
 
+bigquery = lazy_module("google.cloud.bigquery")
+
 
 @dataclass
 class BigQueryConfig(object):
8 changes: 5 additions & 3 deletions plugins/flytekit-dolt/flytekitplugins/dolt/schema.py
@@ -4,18 +4,20 @@
 from typing import Type
 
 import dolt_integrations.core as dolt_int
-import doltcli as dolt
-import pandas
 from dataclasses_json import DataClassJsonMixin
 from google.protobuf.json_format import MessageToDict
 from google.protobuf.struct_pb2 import Struct
 
-from flytekit import FlyteContext
+from flytekit import FlyteContext, lazy_module
 from flytekit.extend import TypeEngine, TypeTransformer
 from flytekit.models import types as _type_models
 from flytekit.models.literals import Literal, Scalar
 from flytekit.models.types import LiteralType
 
+# dolt_int = lazy_module("dolt_integrations")
+dolt = lazy_module("doltcli")
+pandas = lazy_module("pandas")
+
 
 @dataclass
 class DoltConfig(DataClassJsonMixin):
10 changes: 5 additions & 5 deletions plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py
@@ -1,14 +1,14 @@
 import json
 from typing import Dict, List, NamedTuple, Optional, Union
 
-import pandas as pd
-import pyarrow as pa
-
-import duckdb
-from flytekit import PythonInstanceTask
+from flytekit import PythonInstanceTask, lazy_module
 from flytekit.extend import Interface
 from flytekit.types.structured.structured_dataset import StructuredDataset
 
+duckdb = lazy_module("duckdb")
+pd = lazy_module("pandas")
+pa = lazy_module("pyarrow")
+
 
 class QueryOutput(NamedTuple):
     counter: int = -1
@@ -6,22 +6,19 @@
 
 from dataclasses_json import DataClassJsonMixin
 
-import great_expectations as ge
-from flytekit import FlyteContext
+from flytekit import FlyteContext, lazy_module
 from flytekit.extend import TypeEngine, TypeTransformer
 from flytekit.loggers import logger
 from flytekit.models import types as _type_models
 from flytekit.models.literals import Literal, Primitive, Scalar
 from flytekit.models.types import LiteralType
 from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
 from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer
-from great_expectations.checkpoint import SimpleCheckpoint
-from great_expectations.core.run_identifier import RunIdentifier
-from great_expectations.core.util import convert_to_json_serializable
-from great_expectations.exceptions import ValidationError
 
 from .task import BatchRequestConfig
 
+ge = lazy_module("great_expectations")
+
 
 @dataclass
 class GreatExpectationsFlyteConfig(DataClassJsonMixin):
@@ -281,16 +278,16 @@ def to_python_value(
         )
 
         if ge_conf.checkpoint_params:
-            checkpoint = SimpleCheckpoint(
+            checkpoint = ge.checkpoint.SimpleCheckpoint(
                 f"_tmp_checkpoint_{ge_conf.expectation_suite_name}",
                 context,
                 **ge_conf.checkpoint_params,
             )
         else:
-            checkpoint = SimpleCheckpoint(f"_tmp_checkpoint_{ge_conf.expectation_suite_name}", context)
+            checkpoint = ge.checkpoint.SimpleCheckpoint(f"_tmp_checkpoint_{ge_conf.expectation_suite_name}", context)
 
         # identify every run uniquely
-        run_id = RunIdentifier(
+        run_id = ge.core.run_identifier.RunIdentifier(
             **{
                 "run_name": ge_conf.datasource_name + "_run",
                 "run_time": datetime.datetime.utcnow(),
@@ -306,7 +303,7 @@ def to_python_value(
                 }
             ],
         )
-        final_result = convert_to_json_serializable(checkpoint_result.list_validation_results())[0]
+        final_result = ge.core.util.convert_to_json_serializable(checkpoint_result.list_validation_results())[0]
 
         result_string = ""
         if final_result["success"] is False:
@@ -320,7 +317,7 @@ def to_python_value(
             )
 
             # raise a Great Expectations' exception
-            raise ValidationError("Validation failed!\nCOLUMN\t\tFAILED EXPECTATION\n" + result_string)
+            raise ge.exceptions.ValidationError("Validation failed!\nCOLUMN\t\tFAILED EXPECTATION\n" + result_string)
 
         logger.info("Validation succeeded!")
 
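One detail worth noting in the great_expectations hunks above: `from great_expectations.checkpoint import SimpleCheckpoint` cannot simply stay, because a `from ... import ...` statement resolves the attribute immediately and would force the import at module load time. The rewrite therefore binds only the lazy top-level module and fully qualifies names at call time. Roughly (a sketch, not the plugin's exact code):

```python
from flytekit import lazy_module

ge = lazy_module("great_expectations")  # nothing imported yet


def make_checkpoint(suite_name: str, context):
    # Attribute access on `ge` triggers the real import on first call.
    return ge.checkpoint.SimpleCheckpoint(f"_tmp_checkpoint_{suite_name}", context)
```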
@@ -6,17 +6,14 @@
 
 from dataclasses_json import DataClassJsonMixin
 
-import great_expectations as ge
-from flytekit import PythonInstanceTask
+from flytekit import PythonInstanceTask, lazy_module
 from flytekit.core.context_manager import FlyteContext
 from flytekit.extend import Interface
 from flytekit.loggers import logger
 from flytekit.types.file.file import FlyteFile
 from flytekit.types.schema import FlyteSchema
-from great_expectations.checkpoint import SimpleCheckpoint
-from great_expectations.core.run_identifier import RunIdentifier
-from great_expectations.core.util import convert_to_json_serializable
-from great_expectations.exceptions import ValidationError
 
+ge = lazy_module("great_expectations")
+
 
 @dataclass
@@ -204,19 +201,19 @@ def execute(self, **kwargs) -> Any:
         )
 
         if self._checkpoint_params:
-            checkpoint = SimpleCheckpoint(
+            checkpoint = ge.checkpoint.SimpleCheckpoint(
                 f"_tmp_checkpoint_{self._expectation_suite_name}",
                 context,
                 **self._checkpoint_params,
             )
         else:
-            checkpoint = SimpleCheckpoint(
+            checkpoint = ge.checkpoint.SimpleCheckpoint(
                 f"_tmp_checkpoint_{self._expectation_suite_name}",
                 context,
             )
 
         # identify every run uniquely
-        run_id = RunIdentifier(
+        run_id = ge.core.run_identifier.RunIdentifier(
             **{
                 "run_name": self._datasource_name + "_run",
                 "run_time": datetime.datetime.utcnow(),
@@ -232,7 +229,7 @@ def execute(self, **kwargs) -> Any:
                 }
             ],
         )
-        final_result = convert_to_json_serializable(checkpoint_result.list_validation_results())[0]
+        final_result = ge.core.util.convert_to_json_serializable(checkpoint_result.list_validation_results())[0]
 
         result_string = ""
         if final_result["success"] is False:
@@ -246,7 +243,7 @@ def execute(self, **kwargs) -> Any:
             )
 
             # raise a Great Expectations' exception
-            raise ValidationError("Validation failed!\nCOLUMN\t\tFAILED EXPECTATION\n" + result_string)
+            raise ge.exceptions.ValidationError("Validation failed!\nCOLUMN\t\tFAILED EXPECTATION\n" + result_string)
 
         logger.info("Validation succeeded!")
 
@@ -1,9 +1,7 @@
 import os
 import typing
 
-import datasets
-
-from flytekit import FlyteContext
+from flytekit import FlyteContext, lazy_module
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import StructuredDatasetType
@@ -15,6 +13,8 @@
     StructuredDatasetTransformerEngine,
 )
 
+datasets = lazy_module("datasets")
+
 
 class HuggingFaceDatasetRenderer:
     """
22 changes: 12 additions & 10 deletions plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py
@@ -2,10 +2,8 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 from flyteidl.core import tasks_pb2 as _core_task
-from kubernetes.client import ApiClient
-from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements
 
-from flytekit import FlyteContext, PythonFunctionTask
+from flytekit import FlyteContext, PythonFunctionTask, lazy_module
 from flytekit.configuration import SerializationSettings
 from flytekit.exceptions import user as _user_exceptions
 from flytekit.extend import Promise, TaskPlugins
@@ -16,6 +14,10 @@
 
 PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
 
 
+k8s_client = lazy_module("kubernetes.client")
+k8s_models = lazy_module("kubernetes.client.models")
+
+
 def _sanitize_resource_name(resource: _task_models.Resources.ResourceEntry) -> str:
     return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")
 
@@ -35,7 +37,7 @@ class Pod(object):
     :param Optional[Dict[str, str]] annotations: Annotations are key/value pairs that are attached to arbitrary non-identifying metadata to pod spec.
     """
 
-    pod_spec: V1PodSpec
+    pod_spec: k8s_models.V1PodSpec
     primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME
     labels: Optional[Dict[str, str]] = None
     annotations: Optional[Dict[str, str]] = None
@@ -66,7 +68,7 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]
                 break
         if not primary_exists:
             # insert a placeholder primary container if it is not defined in the pod spec.
-            containers.append(V1Container(name=self.task_config.primary_container_name))
+            containers.append(k8s_models.V1Container(name=self.task_config.primary_container_name))
 
         final_containers = []
         for container in containers:
@@ -87,20 +89,20 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]
             for resource in sdk_default_container.resources.requests:
                 requests[_sanitize_resource_name(resource)] = resource.value
 
-            resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)
+            resource_requirements = k8s_models.V1ResourceRequirements(limits=limits, requests=requests)
             if len(limits) > 0 or len(requests) > 0:
                 # Important! Only copy over resource requirements if they are non-empty.
                 container.resources = resource_requirements
 
-            container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + (
-                container.env or []
-            )
+            container.env = [
+                k8s_models.V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()
+            ] + (container.env or [])
 
             final_containers.append(container)
 
         self.task_config.pod_spec.containers = final_containers
 
-        return ApiClient().sanitize_for_serialization(self.task_config.pod_spec)
+        return k8s_client.ApiClient().sanitize_for_serialization(self.task_config.pod_spec)
 
     def get_k8s_pod(self, settings: SerializationSettings) -> _task_models.K8sPod:
         return _task_models.K8sPod(
5 changes: 3 additions & 2 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -7,13 +7,12 @@
 from enum import Enum
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
 
-import cloudpickle
 from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
 from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task
 from google.protobuf.json_format import MessageToDict
 
 import flytekit
-from flytekit import PythonFunctionTask, Resources
+from flytekit import PythonFunctionTask, Resources, lazy_module
 from flytekit.configuration import SerializationSettings
 from flytekit.core.resources import convert_resources_to_resource_model
 from flytekit.exceptions.user import FlyteRecoverableException
@@ -22,6 +21,8 @@
 
 from .error_handling import create_recoverable_error_file, is_recoverable_worker_error
 
+cloudpickle = lazy_module("cloudpickle")
+
 TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`."
 
 