Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

snakemake caching #354

Merged
merged 2 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions latch_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ def dockerfile(pkg_root: str, snakemake: bool = False):
default=None,
help="Path to a Snakefile to register.",
)
@click.option(
"--cache-tasks/--no-cache-tasks",
"-c/-C",
is_flag=True,
default=False,
type=bool,
help=(
"Whether or not to cache snakemake tasks. Ignored if --snakefile is not"
" provided."
),
)
@requires_login
def register(
pkg_root: str,
Expand All @@ -174,6 +185,7 @@ def register(
docker_progress: str,
yes: bool,
snakefile: Optional[Path],
cache_tasks: bool,
):
"""Register local workflow code to Latch.

Expand All @@ -196,6 +208,7 @@ def register(
progress_plain=(docker_progress == "auto" and not sys.stdout.isatty())
or docker_progress == "plain",
use_new_centromere=use_new_centromere,
cache_tasks=cache_tasks,
)


Expand Down
7 changes: 5 additions & 2 deletions latch_cli/services/register/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def _build_and_serialize(
dockerfile: Optional[Path] = None,
*,
progress_plain: bool = False,
cache_tasks: bool = False,
):
assert ctx.pkg_root is not None

Expand All @@ -207,7 +208,7 @@ def _build_and_serialize(
from ...snakemake.serialize import generate_jit_register_code
from ...snakemake.workflow import build_jit_register_wrapper

jit_wf = build_jit_register_wrapper()
jit_wf = build_jit_register_wrapper(cache_tasks)
generate_jit_register_code(
jit_wf,
ctx.pkg_root,
Expand Down Expand Up @@ -268,7 +269,8 @@ def register(
skip_confirmation: bool = False,
snakefile: Optional[Path] = None,
*,
progress_plain=False,
progress_plain: bool = False,
cache_tasks: bool = False,
use_new_centromere: bool = False,
):
"""Registers a workflow, defined as python code, with Latch.
Expand Down Expand Up @@ -415,6 +417,7 @@ def register(
td,
dockerfile=ctx.default_container.dockerfile,
progress_plain=progress_plain,
cache_tasks=cache_tasks,
)

if remote and snakefile is None:
Expand Down
2 changes: 2 additions & 0 deletions latch_cli/snakemake/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def extract_snakemake_workflow(
jit_exec_display_name: str,
local_to_remote_path_mapping: Optional[Dict[str, str]] = None,
non_blob_parameters: Optional[Dict[str, Any]] = None,
cache_tasks: bool = False,
) -> SnakemakeWorkflow:
extractor = snakemake_workflow_extractor(pkg_root, snakefile, non_blob_parameters)
with extractor:
Expand All @@ -223,6 +224,7 @@ def extract_snakemake_workflow(
jit_wf_version,
jit_exec_display_name,
local_to_remote_path_mapping,
cache_tasks,
)
wf.compile()

Expand Down
51 changes: 35 additions & 16 deletions latch_cli/snakemake/workflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import hashlib
import importlib
import json
import sys
import textwrap
import typing
from dataclasses import dataclass, is_dataclass
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
Expand All @@ -17,8 +16,6 @@
Type,
TypeVar,
Union,
get_args,
get_origin,
)
from urllib.parse import urlparse

Expand All @@ -27,8 +24,9 @@
import snakemake.jobs
from flytekit.configuration import SerializationSettings
from flytekit.core import constants as _common_constants
from flytekit.core.base_task import TaskMetadata
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.docstring import Docstring
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.node import Node
Expand Down Expand Up @@ -66,7 +64,7 @@
from latch.resources.tasks import custom_task
from latch.types.directory import LatchDir
from latch.types.file import LatchFile
from latch_cli.snakemake.config.utils import is_primitive_type, type_repr
from latch_cli.snakemake.config.utils import type_repr

from ..utils import identifier_suffix_from_str

Expand Down Expand Up @@ -291,9 +289,9 @@ def interface_to_parameters(
class JITRegisterWorkflow(WorkflowBase, ClassStorageTaskResolver):
out_parameter_name = "o0" # must be "o0"

def __init__(
self,
):
def __init__(self, cache_tasks: bool = False):
self.cache_tasks = cache_tasks

assert metadata._snakemake_metadata is not None

parameter_metadata = metadata._snakemake_metadata.parameters
Expand Down Expand Up @@ -497,7 +495,7 @@ def get_fn_code(
print(f"JIT Workflow Version: {{jit_wf_version}}")
print(f"JIT Execution Display Name: {{jit_exec_display_name}}")

wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping, non_blob_parameters)
wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping, non_blob_parameters, {self.cache_tasks})
wf_name = wf.name
generate_snakemake_entrypoint(wf, pkg_root, snakefile, {repr(remote_output_url)}, non_blob_parameters)

Expand All @@ -520,7 +518,7 @@ def get_fn_code(

protos = _recursive_list(td)
reg_resp = register_serialized_pkg(protos, None, version, account_id)
# _print_reg_resp(reg_resp, image_name, silent=True)
_print_reg_resp(reg_resp, image_name)

wf_spec_remote = f"latch:///.snakemake_latch/workflows/{wf_name}/{version}/spec"
spec_dir = Path("spec")
Expand Down Expand Up @@ -605,6 +603,7 @@ def __init__(
jit_wf_version: str,
jit_exec_display_name: str,
local_to_remote_path_mapping: Optional[Dict[str, str]] = None,
cache_tasks: bool = False,
):
assert metadata._snakemake_metadata is not None
name = metadata._snakemake_metadata.name
Expand All @@ -624,6 +623,7 @@ def __init__(
self.return_files = return_files
self._input_parameters = None
self._dag = dag
self._cache_tasks = cache_tasks
self.snakemake_tasks: List[SnakemakeJobTask] = []

workflow_metadata = WorkflowMetadata(
Expand Down Expand Up @@ -652,6 +652,7 @@ def compile(self, **kwargs):

target_files = [x for job in self._dag.targetjobs for x in job.input]

node_id = 0
for layer in self._dag.toposorted():
for job in layer:
assert isinstance(job, snakemake.jobs.Job)
Expand Down Expand Up @@ -739,6 +740,21 @@ def compile(self, **kwargs):
is_target=is_target,
interface=interface,
)

if getattr(task, "_metadata") is None:
task._metadata = TaskMetadata()

if self._cache_tasks:
task._metadata.cache = True
task._metadata.cache_serialize = True

hash = hashlib.new("sha256")
hash.update(job.properties().encode())
if job.is_script:
hash.update(Path(job.rule.script).read_bytes())

task._metadata.cache_version = hash.hexdigest()

self.snakemake_tasks.append(task)

typed_interface = transform_interface_to_typed_interface(interface)
Expand Down Expand Up @@ -776,14 +792,16 @@ def compile(self, **kwargs):
upstream_nodes.append(node_map[x.jobid])

node = Node(
id=f"n{job.jobid}",
id=f"n{node_id}",
metadata=task.construct_node_metadata(),
bindings=sorted(bindings, key=lambda b: b.var),
upstream_nodes=upstream_nodes,
flyte_entity=task,
)
node_map[job.jobid] = node

node_id += 1

bindings: List[literals_models.Binding] = []
for i, out in enumerate(self.interface.outputs.keys()):
upstream_id, upstream_var = self.find_upstream_node_matching_output_var(out)
Expand Down Expand Up @@ -816,8 +834,8 @@ def execute(self, **kwargs):
return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)


def build_jit_register_wrapper() -> JITRegisterWorkflow:
wrapper_wf = JITRegisterWorkflow()
def build_jit_register_wrapper(cache_tasks: bool = False) -> JITRegisterWorkflow:
wrapper_wf = JITRegisterWorkflow(cache_tasks)
out_parameter_name = wrapper_wf.out_parameter_name

python_interface = wrapper_wf.python_interface
Expand Down Expand Up @@ -1132,10 +1150,11 @@ def get_fn_return_stmt(self, remote_output_url: Optional[str] = None):
)

if not self._is_target:
remote_path = f"latch:///.snakemake_latch/workflows/{self.wf.name}/task_outputs/{self.wf.jit_wf_version}/{self.wf.jit_exec_display_name}/{self.name}/{out_name}"
results.append(
reindent(
rf"""
{out_name}={out_type.__name__}("{target_path}")
{out_name}={out_type.__name__}({repr(target_path)}, {repr(remote_path)})
""",
2,
).rstrip()
Expand Down
Loading