Skip to content

Commit

Permalink
Adds conda support to envd plugin (#2020)
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas J. Fan <[email protected]>
  • Loading branch information
thomasjpfan authored Jan 13, 2024
1 parent ccc515e commit 892b474
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 8 deletions.
4 changes: 4 additions & 0 deletions flytekit/image_spec/image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class ImageSpec:
env: environment variables of the image.
registry: registry of the image.
packages: list of python packages to install.
conda_packages: list of conda packages to install.
conda_channels: list of conda channels.
requirements: path to the requirements.txt file.
apt_packages: list of apt packages to install.
cuda: version of cuda to install.
Expand All @@ -47,6 +49,8 @@ class ImageSpec:
env: Optional[typing.Dict[str, str]] = None
registry: Optional[str] = None
packages: Optional[List[str]] = None
conda_packages: Optional[List[str]] = None
conda_channels: Optional[List[str]] = None
requirements: Optional[str] = None
apt_packages: Optional[List[str]] = None
cuda: Optional[str] = None
Expand Down
18 changes: 18 additions & 0 deletions plugins/flytekit-envd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,21 @@ Example
# def t1() -> str:
# return "hello"
```

This plugin also supports install packages from `conda`:

```python
from flytekit import task, ImageSpec

image_spec = ImageSpec(
base_image="ubuntu:20.04",
python_version="3.11",
packages=["flytekit"],
conda_packages=["pytorch", "pytorch-cuda=12.1"],
conda_channels=["pytorch", "nvidia"]
)

@task(container_image=image_spec)
def run_pytorch():
...
```
24 changes: 18 additions & 6 deletions plugins/flytekit-envd/flytekitplugins/envd/image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,25 @@ def build_image(self, image_spec: ImageSpec):
self.execute_command(build_command)


def _create_str_from_package_list(packages):
if packages is None:
return ""
return ", ".join(f'"{name}"' for name in packages)


def create_envd_config(image_spec: ImageSpec) -> str:
base_image = DefaultImages.default_image() if image_spec.base_image is None else image_spec.base_image
if image_spec.cuda:
if image_spec.python_version is None:
raise Exception("python_version is required when cuda and cudnn are specified")
base_image = "ubuntu20.04"

packages = [] if image_spec.packages is None else image_spec.packages
apt_packages = [] if image_spec.apt_packages is None else image_spec.apt_packages
python_packages = _create_str_from_package_list(image_spec.packages)
conda_packages = _create_str_from_package_list(image_spec.conda_packages)
run_commands = _create_str_from_package_list(image_spec.commands)
conda_channels = _create_str_from_package_list(image_spec.conda_channels)
apt_packages = _create_str_from_package_list(image_spec.apt_packages)
env = {"PYTHONPATH": "/root", _F_IMG_ID: image_spec.image_name()}
commands = [] if image_spec.commands is None else image_spec.commands

if image_spec.env:
env.update(image_spec.env)
Expand All @@ -70,16 +78,20 @@ def create_envd_config(image_spec: ImageSpec) -> str:
def build():
base(image="{base_image}", dev=False)
run(commands=[{', '.join(map(str, map(lambda x: f'"{x}"', commands)))}])
install.python_packages(name=[{', '.join(map(str, map(lambda x: f'"{x}"', packages)))}])
install.apt_packages(name=[{', '.join(map(str, map(lambda x: f'"{x}"', apt_packages)))}])
run(commands=[{run_commands}])
install.python_packages(name=[{python_packages}])
install.apt_packages(name=[{apt_packages}])
runtime.environ(env={env}, extra_path=['/root'])
config.pip_index(url="{pip_index}")
"""
ctx = context_manager.FlyteContextManager.current_context()
cfg_path = ctx.file_access.get_random_local_path("build.envd")
pathlib.Path(cfg_path).parent.mkdir(parents=True, exist_ok=True)

if conda_packages:
envd_config += " install.conda(use_mamba=True)\n"
envd_config += f" install.conda_packages(name=[{conda_packages}], channel=[{conda_channels}])\n"

if image_spec.requirements:
requirement_path = f"{pathlib.Path(cfg_path).parent}{os.sep}{REQUIREMENTS_FILE_NAME}"
shutil.copyfile(image_spec.requirements, requirement_path)
Expand Down
35 changes: 35 additions & 0 deletions plugins/flytekit-envd/tests/test_image_spec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from textwrap import dedent

from flytekitplugins.envd.image_builder import EnvdImageSpecBuilder, create_envd_config

Expand Down Expand Up @@ -35,3 +36,37 @@ def build():
install.python(version="3.8")
"""
)


def test_image_spec_conda():
image_spec = ImageSpec(
base_image="ubuntu:20.04",
python_version="3.11",
packages=["flytekit"],
conda_packages=["pytorch", "cpuonly"],
conda_channels=["pytorch"],
)

EnvdImageSpecBuilder().build_image(image_spec)
config_path = create_envd_config(image_spec)
assert image_spec.platform == "linux/amd64"
image_name = image_spec.image_name()
contents = Path(config_path).read_text()
expected_contents = dedent(
f"""\
# syntax=v1
def build():
base(image="ubuntu:20.04", dev=False)
run(commands=[])
install.python_packages(name=["flytekit"])
install.apt_packages(name=[])
runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root'])
config.pip_index(url="https://pypi.org/simple")
install.conda(use_mamba=True)
install.conda_packages(name=["pytorch", "cpuonly"], channel=["pytorch"])
install.python(version="3.11")
"""
)

assert contents == expected_contents
4 changes: 2 additions & 2 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,9 @@ def test_list_default_arguments(wf_path):
)

ic_result_4 = ImageConfig(
default_image=Image(name="default", fqn="flytekit", tag="tbcFqCcdAEyJqPcyYsQ15A.."),
default_image=Image(name="default", fqn="flytekit", tag="DgQMqIi61py4I4P5iOeS0Q.."),
images=[
Image(name="default", fqn="flytekit", tag="tbcFqCcdAEyJqPcyYsQ15A.."),
Image(name="default", fqn="flytekit", tag="DgQMqIi61py4I4P5iOeS0Q.."),
Image(name="xyz", fqn="docker.io/xyz", tag="latest"),
Image(name="abc", fqn="docker.io/abc", tag=None),
],
Expand Down

0 comments on commit 892b474

Please sign in to comment.