Skip to content

Commit

Permalink
make caching opt in
Browse files Browse the repository at this point in the history
Signed-off-by: Ayush Kamat <[email protected]>
  • Loading branch information
ayushkamat committed Nov 27, 2023
1 parent b1f536c commit 420ef84
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 16 deletions.
13 changes: 13 additions & 0 deletions latch_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ def dockerfile(pkg_root: str, snakemake: bool = False):
default=None,
help="Path to a Snakefile to register.",
)
@click.option(
"--cache-tasks/--no-cache-tasks",
"-c/-C",
is_flag=True,
default=False,
type=bool,
help=(
"Whether or not to cache snakemake tasks. Ignored if --snakefile is not"
" provided."
),
)
@requires_login
def register(
pkg_root: str,
Expand All @@ -174,6 +185,7 @@ def register(
docker_progress: str,
yes: bool,
snakefile: Optional[Path],
cache_tasks: bool,
):
"""Register local workflow code to Latch.
Expand All @@ -196,6 +208,7 @@ def register(
progress_plain=(docker_progress == "auto" and not sys.stdout.isatty())
or docker_progress == "plain",
use_new_centromere=use_new_centromere,
cache_tasks=cache_tasks,
)


Expand Down
7 changes: 5 additions & 2 deletions latch_cli/services/register/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def _build_and_serialize(
dockerfile: Optional[Path] = None,
*,
progress_plain: bool = False,
cache_tasks: bool = False,
):
assert ctx.pkg_root is not None

Expand All @@ -207,7 +208,7 @@ def _build_and_serialize(
from ...snakemake.serialize import generate_jit_register_code
from ...snakemake.workflow import build_jit_register_wrapper

jit_wf = build_jit_register_wrapper()
jit_wf = build_jit_register_wrapper(cache_tasks)
generate_jit_register_code(
jit_wf,
ctx.pkg_root,
Expand Down Expand Up @@ -268,7 +269,8 @@ def register(
skip_confirmation: bool = False,
snakefile: Optional[Path] = None,
*,
progress_plain=False,
progress_plain: bool = False,
cache_tasks: bool = False,
use_new_centromere: bool = False,
):
"""Registers a workflow, defined as python code, with Latch.
Expand Down Expand Up @@ -415,6 +417,7 @@ def register(
td,
dockerfile=ctx.default_container.dockerfile,
progress_plain=progress_plain,
cache_tasks=cache_tasks,
)

if remote and snakefile is None:
Expand Down
2 changes: 2 additions & 0 deletions latch_cli/snakemake/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def extract_snakemake_workflow(
jit_exec_display_name: str,
local_to_remote_path_mapping: Optional[Dict[str, str]] = None,
non_blob_parameters: Optional[Dict[str, Any]] = None,
cache_tasks: bool = False,
) -> SnakemakeWorkflow:
extractor = snakemake_workflow_extractor(pkg_root, snakefile, non_blob_parameters)
with extractor:
Expand All @@ -223,6 +224,7 @@ def extract_snakemake_workflow(
jit_wf_version,
jit_exec_display_name,
local_to_remote_path_mapping,
cache_tasks,
)
wf.compile()

Expand Down
31 changes: 17 additions & 14 deletions latch_cli/snakemake/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ def interface_to_parameters(
class JITRegisterWorkflow(WorkflowBase, ClassStorageTaskResolver):
out_parameter_name = "o0" # must be "o0"

def __init__(
self,
):
def __init__(self, cache_tasks: bool = False):
self.cache_tasks = cache_tasks

assert metadata._snakemake_metadata is not None

parameter_metadata = metadata._snakemake_metadata.parameters
Expand Down Expand Up @@ -495,7 +495,7 @@ def get_fn_code(
print(f"JIT Workflow Version: {{jit_wf_version}}")
print(f"JIT Execution Display Name: {{jit_exec_display_name}}")
wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping, non_blob_parameters)
wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping, non_blob_parameters, {self.cache_tasks})
wf_name = wf.name
generate_snakemake_entrypoint(wf, pkg_root, snakefile, {repr(remote_output_url)}, non_blob_parameters)
Expand All @@ -518,7 +518,7 @@ def get_fn_code(
protos = _recursive_list(td)
reg_resp = register_serialized_pkg(protos, None, version, account_id)
_print_reg_resp(reg_resp, image_name, silent=True)
_print_reg_resp(reg_resp, image_name)
wf_spec_remote = f"latch:///.snakemake_latch/workflows/{wf_name}/{version}/spec"
spec_dir = Path("spec")
Expand Down Expand Up @@ -603,6 +603,7 @@ def __init__(
jit_wf_version: str,
jit_exec_display_name: str,
local_to_remote_path_mapping: Optional[Dict[str, str]] = None,
cache_tasks: bool = False,
):
assert metadata._snakemake_metadata is not None
name = metadata._snakemake_metadata.name
Expand All @@ -622,6 +623,7 @@ def __init__(
self.return_files = return_files
self._input_parameters = None
self._dag = dag
self._cache_tasks = cache_tasks
self.snakemake_tasks: List[SnakemakeJobTask] = []

workflow_metadata = WorkflowMetadata(
Expand Down Expand Up @@ -742,15 +744,16 @@ def compile(self, **kwargs):
if getattr(task, "_metadata") is None:
task._metadata = TaskMetadata()

task._metadata.cache = True
task._metadata.cache_serialize = True
if self._cache_tasks:
task._metadata.cache = True
task._metadata.cache_serialize = True

hash = hashlib.new("sha256")
hash.update(job.properties().encode())
if job.is_script:
hash.update(Path(job.rule.script).read_bytes())
hash = hashlib.new("sha256")
hash.update(job.properties().encode())
if job.is_script:
hash.update(Path(job.rule.script).read_bytes())

task._metadata.cache_version = hash.hexdigest()
task._metadata.cache_version = hash.hexdigest()

self.snakemake_tasks.append(task)

Expand Down Expand Up @@ -831,8 +834,8 @@ def execute(self, **kwargs):
return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)


def build_jit_register_wrapper() -> JITRegisterWorkflow:
wrapper_wf = JITRegisterWorkflow()
def build_jit_register_wrapper(cache_tasks: bool = False) -> JITRegisterWorkflow:
wrapper_wf = JITRegisterWorkflow(cache_tasks)
out_parameter_name = wrapper_wf.out_parameter_name

python_interface = wrapper_wf.python_interface
Expand Down

0 comments on commit 420ef84

Please sign in to comment.