diff --git a/flytekit/__init__.py b/flytekit/__init__.py index b37f51ea13..b23001b4d6 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -226,6 +226,7 @@ from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck from flytekit.extras import pytorch, sklearn, tensorflow +from flytekit.image_spec import ImageSpec from flytekit.loggers import logger from flytekit.models.common import Annotations, AuthRole, Labels from flytekit.models.core.execution import WorkflowExecutionPhase diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 50eca10e58..ea52e64b89 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -795,6 +795,7 @@ def new_builder(self) -> Builder: flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, + source_root=self.source_root, ) def should_fast_serialize(self) -> bool: @@ -845,6 +846,7 @@ class Builder(object): flytekit_virtualenv_root: Optional[str] = None python_interpreter: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None + source_root: Optional[str] = None def with_fast_serialization_settings(self, fss: fast_serialization_settings) -> SerializationSettings.Builder: self.fast_serialization_settings = fss @@ -861,4 +863,5 @@ def build(self) -> SerializationSettings: flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, + source_root=self.source_root, ) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index a33a673d48..66a49a819e 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -181,8 +181,9 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain for elem in (settings.env, self.environment): if elem: env.update(elem) - if isinstance(self.container_image, ImageSpec): - self.container_image.source_root = settings.source_root + if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: + if isinstance(self.container_image, ImageSpec): + self.container_image.source_root = settings.source_root return _get_container_definition( image=get_registerable_container_image(self.container_image, settings.image_config), command=[], diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 3ba118741b..e35fd5c597 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -4,15 +4,18 @@ import sys import typing from abc import abstractmethod +from copy import copy from dataclasses import dataclass from functools import lru_cache from typing import List, Optional import click import docker +import requests from dataclasses_json import dataclass_json from docker.errors import APIError, ImageNotFound +DOCKER_HUB = "docker.io" _F_IMG_ID = "_F_IMG_ID" @@ -62,12 +65,13 @@ def is_container(self) -> bool: return os.environ.get(_F_IMG_ID) == self.image_name() return True + @lru_cache def exist(self) -> bool: """ Check if the image exists in the registry. """ - client = docker.from_env() try: + client = docker.from_env() if self.registry: client.images.get_registry_data(self.image_name()) else: @@ -76,12 +80,23 @@ def exist(self) -> bool: except APIError as e: if e.response.status_code == 404: return False - if e.response.status_code == 403: - click.secho("Permission denied. Please login you docker registry first.", fg="red") - raise e - return False except ImageNotFound: return False + except Exception as e: + tag = calculate_hash_from_image_spec(self) + # if docker engine is not running locally + container_registry = DOCKER_HUB + if "/" in self.registry: + container_registry = self.registry.split("/")[0] + if container_registry == DOCKER_HUB: + url = f"https://hub.docker.com/v2/repositories/{self.registry}/{self.name}/tags/{tag}" + response = requests.get(url) + if response.status_code == 200: + return True + + click.secho(f"Failed to check if the image exists with error : {e}", fg="red") + click.secho("Flytekit assumes that the image already exists.", fg="blue") + return True def __hash__(self): return hash(self.to_json()) @@ -121,15 +136,16 @@ def build(cls, image_spec: ImageSpec): click.secho(f"Image {image_spec.image_name()} found. Skip building.", fg="blue") -@lru_cache(maxsize=None) +@lru_cache def calculate_hash_from_image_spec(image_spec: ImageSpec): """ Calculate the hash from the image spec. """ + # copy the image spec to avoid modifying the original image spec. otherwise, the hash will be different. + spec = copy(image_spec) + spec.source_root = hash_directory(image_spec.source_root) if image_spec.source_root else b"" image_spec_bytes = bytes(image_spec.to_json(), "utf-8") - source_root_bytes = hash_directory(image_spec.source_root) if image_spec.source_root else b"" - h = hashlib.md5(image_spec_bytes + source_root_bytes) - tag = base64.urlsafe_b64encode(h.digest()).decode("ascii") + tag = base64.urlsafe_b64encode(hashlib.md5(image_spec_bytes).digest()).decode("ascii") # replace "=" with "." to make it a valid tag return tag.replace("=", ".") diff --git a/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py b/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py index e5426b8731..e861b69310 100644 --- a/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py +++ b/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py @@ -38,8 +38,9 @@ def create_envd_config(image_spec: ImageSpec) -> str: base_image = DefaultImages.default_image() if image_spec.base_image is None else image_spec.base_image packages = [] if image_spec.packages is None else image_spec.packages apt_packages = [] if image_spec.apt_packages is None else image_spec.apt_packages - env = {} if image_spec.env is None else image_spec.env - env.update({"PYTHONPATH": "/root", _F_IMG_ID: image_spec.image_name()}) + env = {"PYTHONPATH": "/root", _F_IMG_ID: image_spec.image_name()} + if image_spec.env: + env.update(image_spec.env) envd_config = f"""# syntax=v1 diff --git a/tests/flytekit/unit/core/image_spec/test_image_spec.py b/tests/flytekit/unit/core/image_spec/test_image_spec.py index 8a1cfc75f6..be8ea61427 100644 --- a/tests/flytekit/unit/core/image_spec/test_image_spec.py +++ b/tests/flytekit/unit/core/image_spec/test_image_spec.py @@ -43,6 +43,7 @@ def build_image(self, img): ImageBuildEngine._REGISTRY["dummy"].build_image(image_spec) assert "dummy" in ImageBuildEngine._REGISTRY assert calculate_hash_from_image_spec(image_spec) == "yZ8jICcDTLoDArmNHbWNwg.." + assert image_spec.exist() is False with pytest.raises(Exception): image_spec.builder = "flyte" diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 644b3c0963..160efb0504 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -78,9 +78,6 @@ def build_image(self, img): == "flytekit:0N8X-XowtpEkDYWDlb8Abg.." ) - with pytest.raises(Exception): - get_registerable_container_image(ImageSpec(builder="test", python_version="3.7", registry="hello"), cfg) - def test_get_registerable_container_image_no_images(): cfg = ImageConfig()