diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 500b5437..20bdf4ac 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -71,7 +71,7 @@ from typing_extensions import TypeAlias, TypedDict import latch.types.metadata as metadata -from latch.resources.tasks import custom_task +from latch.resources.tasks import _get_small_gpu_pod, custom_task from latch.types.directory import LatchDir from latch.types.file import LatchFile from latch_cli.snakemake.config.utils import type_repr @@ -1084,17 +1084,38 @@ def __init__( self._task_function = task_fn_placeholder limits = self.job.resources - cores = limits.get("cpus", 4) - # convert MB to GiB - mem = limits.get("mem_mb", 8589) * 1000 * 1000 // 1024 // 1024 // 1024 + self._uses_gpu = limits.get("nvidia_gpu") is not None + if self._uses_gpu and self.job.container_img_url is not None: + click.secho( + dedent(""" + GPU tasks within container images are not yet supported. To resolve, + use conda for all GPU tasks and OR add all dependencies to your main + Dockerfile. + """), + fg="red", + ) + raise click.exceptions.Exit(1) + + if not self._uses_gpu: + cores = limits.get("cpus", 4) + # convert MB to GiB + mem = limits.get("mem_mb", 8589) * 1000 * 1000 // 1024 // 1024 // 1024 + task_config = custom_task(cpu=cores, memory=mem).keywords["task_config"] + else: + if limits.get("nvidia_gpu") > 1: + click.secho( + "Multi-GPU tasks are not supported. Set nvidia_gpu=1", fg="red" + ) + raise click.exceptions.Exit(1) + task_config = _get_small_gpu_pod() super().__init__( task_type="sidecar", task_type_version=2, name=name, interface=interface, - task_config=custom_task(cpu=cores, memory=mem).keywords["task_config"], + task_config=task_config, task_resolver=SnakemakeJobTaskResolver(), ) @@ -1321,6 +1342,12 @@ def get_fn_code( snakefile_path_in_container, *(["--use-conda"] if self.wf._use_conda else []), *(["--use-singularity"] if self.wf._use_container else []), + # we use docker instead of singularity, so these are docker run arguments + *( + ["--singularity-args", "--gpus all"] + if self.wf._use_container and self._uses_gpu + else [] + ), "--target-jobs", *jobs_cli_args(jobs), "--allowed-rules",