diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 766e4e4a97..c7c9235a4e 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -61,7 +61,7 @@ class ImageSpec: apt_packages: Optional[List[str]] = None cuda: Optional[str] = None cudnn: Optional[str] = None - base_image: Optional[str] = None + base_image: Optional[Union[str, "ImageSpec"]] = None platform: str = "linux/amd64" pip_index: Optional[str] = None pip_extra_index_url: Optional[List[str]] = None @@ -228,6 +228,10 @@ def register(cls, builder_type: str, image_spec_builder: ImageSpecBuilder, prior @classmethod @lru_cache def build(cls, image_spec: ImageSpec) -> str: + if isinstance(image_spec.base_image, ImageSpec): + cls.build(image_spec.base_image) + image_spec.base_image = image_spec.base_image.image_name() + if image_spec.builder is None and cls._REGISTRY: builder = max(cls._REGISTRY, key=lambda name: cls._REGISTRY[name][1]) else: @@ -269,6 +273,8 @@ def calculate_hash_from_image_spec(image_spec: ImageSpec): """ # copy the image spec to avoid modifying the original image spec. otherwise, the hash will be different. spec = copy.deepcopy(image_spec) + if isinstance(spec.base_image, ImageSpec): + spec.base_image = spec.base_image.image_name() spec.source_root = hash_directory(image_spec.source_root) if image_spec.source_root else b"" if spec.requirements: spec.requirements = hashlib.sha1(pathlib.Path(spec.requirements).read_bytes()).hexdigest() diff --git a/plugins/flytekit-envd/tests/test_image_spec.py b/plugins/flytekit-envd/tests/test_image_spec.py index 8e5d8b1631..f7c8e3f370 100644 --- a/plugins/flytekit-envd/tests/test_image_spec.py +++ b/plugins/flytekit-envd/tests/test_image_spec.py @@ -21,17 +21,27 @@ def register_envd_higher_priority(): def test_image_spec(): + base_image = ImageSpec( + packages=["numpy"], + python_version="3.8", + registry="", + base_image="cr.flyte.org/flyteorg/flytekit:py3.8-latest", + ) + # Replace the base image name with the default flytekit image name, + # so Envd can find the base image when building imageSpec below + ImageBuildEngine._IMAGE_NAME_TO_REAL_NAME[base_image.image_name()] = "cr.flyte.org/flyteorg/flytekit:py3.8-latest" + image_spec = ImageSpec( packages=["pandas"], apt_packages=["git"], python_version="3.8", - base_image="cr.flyte.org/flyteorg/flytekit:py3.8-latest", + base_image=base_image, pip_index="https://private-pip-index/simple", ) image_spec = image_spec.with_commands("echo hello") - EnvdImageSpecBuilder().build_image(image_spec) + ImageBuildEngine.build(image_spec) config_path = create_envd_config(image_spec) assert image_spec.platform == "linux/amd64" image_name = image_spec.image_name()