Skip to content

Commit

Permalink
add gpu support for Snakemake when not running in containers
Browse files Browse the repository at this point in the history
  • Loading branch information
rahuldesai1 committed Jan 25, 2024
1 parent a523c52 commit 5e995d3
Showing 1 changed file with 32 additions and 5 deletions.
37 changes: 32 additions & 5 deletions latch_cli/snakemake/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
from typing_extensions import TypeAlias, TypedDict

import latch.types.metadata as metadata
from latch.resources.tasks import custom_task
from latch.resources.tasks import _get_small_gpu_pod, custom_task
from latch.types.directory import LatchDir
from latch.types.file import LatchFile
from latch_cli.snakemake.config.utils import type_repr
Expand Down Expand Up @@ -1084,17 +1084,38 @@ def __init__(
self._task_function = task_fn_placeholder

limits = self.job.resources
cores = limits.get("cpus", 4)

# convert MB to GiB
mem = limits.get("mem_mb", 8589) * 1000 * 1000 // 1024 // 1024 // 1024
self._uses_gpu = limits.get("nvidia_gpu") is not None
if self._uses_gpu and self.job.container_img_url is not None:
click.secho(
dedent("""
GPU tasks within container images are not yet supported. To resolve,
use conda for all GPU tasks and OR add all dependencies to your main
Dockerfile.
"""),
fg="red",
)
raise click.exceptions.Exit(1)

if not self._uses_gpu:
cores = limits.get("cpus", 4)
# convert MB to GiB
mem = limits.get("mem_mb", 8589) * 1000 * 1000 // 1024 // 1024 // 1024
task_config = custom_task(cpu=cores, memory=mem).keywords["task_config"]
else:
if limits.get("nvidia_gpu") > 1:
click.secho(
"Multi-GPU tasks are not supported. Set nvidia_gpu=1", fg="red"
)
raise click.exceptions.Exit(1)
task_config = _get_small_gpu_pod()

super().__init__(
task_type="sidecar",
task_type_version=2,
name=name,
interface=interface,
task_config=custom_task(cpu=cores, memory=mem).keywords["task_config"],
task_config=task_config,
task_resolver=SnakemakeJobTaskResolver(),
)

Expand Down Expand Up @@ -1321,6 +1342,12 @@ def get_fn_code(
snakefile_path_in_container,
*(["--use-conda"] if self.wf._use_conda else []),
*(["--use-singularity"] if self.wf._use_container else []),
# we use docker instead of singularity, so these are docker run arguments
*(
["--singularity-args", "--gpus all"]
if self.wf._use_container and self._uses_gpu
else []
),
"--target-jobs",
*jobs_cli_args(jobs),
"--allowed-rules",
Expand Down

0 comments on commit 5e995d3

Please sign in to comment.