Skip to content

Commit

Permalink
save state
Browse files Browse the repository at this point in the history
Signed-off-by: Ayush Kamat <[email protected]>
  • Loading branch information
ayushkamat committed Jan 18, 2024
1 parent 5527ba8 commit a2670cb
Show file tree
Hide file tree
Showing 12 changed files with 138 additions and 543 deletions.
12 changes: 10 additions & 2 deletions latch/types/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,15 @@ class SnakemakeFileParameter(SnakemakeParameter[Union[LatchFile, LatchDir]]):
"""


@dataclass
class NextflowParameter(Generic[T], LatchParameter):
    """Parameter description used as the value type of ``NextflowMetadata.parameters``.

    Generic over ``T``; adds an optional ``type`` and an optional ``default``
    on top of whatever ``LatchParameter`` declares. Both fields default to
    ``None``.
    """

    type: Optional[Type[T]] = None
    """
    The python type of the parameter.
    """
    # NOTE(review): presumably the value used when the parameter is not
    # supplied by the caller — confirm against the code that consumes it.
    default: Optional[T] = None


@dataclass
class LatchMetadata:
"""Class for organizing workflow metadata
Expand Down Expand Up @@ -541,7 +550,7 @@ def __post_init__(self):
@dataclass
class NextflowMetadata(LatchMetadata):
name: Optional[str] = None
parameters: Dict[str, SnakemakeParameter] = field(default_factory=dict)
parameters: Dict[str, NextflowParameter] = field(default_factory=dict)

def __post_init__(self):
if self.name is None:
Expand All @@ -551,5 +560,4 @@ def __post_init__(self):
_nextflow_metadata = self


_snakemake_metadata: Optional[SnakemakeMetadata] = None
_nextflow_metadata: Optional[NextflowMetadata] = None
119 changes: 22 additions & 97 deletions latch_cli/extras/common/serialize.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,27 @@
import hashlib
import importlib
import json
import re
import textwrap
import typing
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_args,
)
from urllib.parse import urlparse
from typing import Dict, Optional, Type, Union, get_args

import click
import snakemake
import snakemake.io
import snakemake.jobs
from flytekit import LaunchPlan
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core import constants as _common_constants
from flytekit.core import constants as common_constants
from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.docstring import Docstring
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.interface import Interface
from flytekit.core.node import Node
from flytekit.core.promise import NodeOutput, Promise
from flytekit.core.python_auto_container import (
DefaultTaskResolver,
PythonAutoContainerTask,
)
from flytekit.core.type_engine import TypeEngine
from flytekit.core.utils import _dnsify
from flytekit.core.workflow import (
WorkflowBase,
WorkflowFailurePolicy,
WorkflowMetadata,
WorkflowMetadataDefaults,
)
from flytekit.exceptions import scopes as exception_scopes
from flytekit.core.workflow import WorkflowBase
from flytekit.models import common as common_models
from flytekit.models import interface as interface_models
from flytekit.models import launch_plan as launch_plan_models
from flytekit.models import literals as literals_models
from flytekit.models import task as _task_models
from flytekit.models import task as task_models
from flytekit.models import types as type_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.core import identifier as identifier_model
from flytekit.models.core import workflow as workflow_model
from flytekit.models.core.types import BlobType
from flytekit.models.core.workflow import TaskNodeOverrides
from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralMap, Scalar
from flytekit.tools.serialize_helpers import persist_registrable_entities
from flytekitplugins.pod.task import (
_PRIMARY_CONTAINER_NAME_FIELD,
Pod,
_sanitize_resource_name,
)
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements
from snakemake.dag import DAG
from snakemake.jobs import GroupJob, Job
from typing_extensions import TypeAlias, TypedDict

from latch_cli.utils import urljoins

RegistrableEntity = Union[
task_models.TaskSpec,
launch_plan_models.LaunchPlan,
admin_workflow_models.WorkflowSpec,
]


def should_register_with_admin(entity: RegistrableEntity) -> bool:
    """Return whether ``entity`` is an instance of a ``RegistrableEntity`` member type.

    Args:
        entity: Object to test.

    Returns:
        ``True`` when ``entity`` is an instance of any type listed in the
        ``RegistrableEntity`` union, ``False`` otherwise.
    """
    # ``get_args`` unpacks the Union into the tuple of its member types,
    # which is the form ``isinstance`` accepts.
    registrable_types = get_args(RegistrableEntity)
    return isinstance(entity, registrable_types)

from typing_extensions import TypeAlias

FlyteLocalEntity: TypeAlias = Union[
PythonTask,
Expand All @@ -104,28 +40,6 @@ def should_register_with_admin(entity: RegistrableEntity) -> bool:
EntityCache: TypeAlias = Dict[FlyteLocalEntity, FlyteSerializableModel]


def binding_data_from_python(
    expected_literal_type: type_models.LiteralType,
    t_value: typing.Any,
    t_value_type: Optional[Type] = None,
) -> Optional[literals_models.BindingData]:
    """Build binding data for a promise that has not been resolved yet.

    Args:
        expected_literal_type: Not read by this implementation.
        t_value: Value to bind. Only a ``Promise`` whose ``is_ready`` is
            falsy yields a result.
        t_value_type: Not read by this implementation.

    Returns:
        ``BindingData`` wrapping ``t_value.ref`` for an unready ``Promise``;
        ``None`` for every other input.
    """
    # Short-circuit keeps ``is_ready`` from being touched on non-promises.
    is_pending_promise = isinstance(t_value, Promise) and not t_value.is_ready
    if not is_pending_promise:
        return None

    return literals_models.BindingData(promise=t_value.ref)


def binding_from_python(
    var_name: str,
    expected_literal_type: type_models.LiteralType,
    t_value: typing.Any,
    t_value_type: Type,
) -> literals_models.Binding:
    """Create a ``Binding`` that associates ``var_name`` with ``t_value``.

    Args:
        var_name: Name stored in the binding's ``var`` field.
        expected_literal_type: Forwarded to ``binding_data_from_python``.
        t_value: Forwarded to ``binding_data_from_python``.
        t_value_type: Forwarded to ``binding_data_from_python``.

    Returns:
        A ``Binding`` whose ``binding`` field is whatever
        ``binding_data_from_python`` returned for the forwarded arguments.
    """
    return literals_models.Binding(
        var=var_name,
        binding=binding_data_from_python(
            expected_literal_type, t_value, t_value_type
        ),
    )


def transform_type(
x: Type, description: Optional[str] = None
) -> interface_models.Variable:
Expand Down Expand Up @@ -354,8 +268,19 @@ def get_serializable_workflow(
return admin_wf


def serialize_jit_register_workflow(
jit_wf: WorkflowBase,
RegistrableEntity = Union[
task_models.TaskSpec,
launch_plan_models.LaunchPlan,
admin_workflow_models.WorkflowSpec,
]


def should_register_with_admin(entity: RegistrableEntity) -> bool:
    """Return whether ``entity`` is an instance of a ``RegistrableEntity`` member type.

    Args:
        entity: Object to test.

    Returns:
        ``True`` when ``entity`` is an instance of any type listed in the
        ``RegistrableEntity`` union, ``False`` otherwise.
    """
    # ``get_args`` unpacks the Union into the tuple of its member types,
    # which is the form ``isinstance`` accepts.
    registrable_types = get_args(RegistrableEntity)
    return isinstance(entity, registrable_types)


def serialize(
wf: WorkflowBase,
output_dir: str,
image_name: str,
dkr_repo: str,
Expand All @@ -372,12 +297,12 @@ def serialize_jit_register_workflow(

registrable_entity_cache: EntityCache = {}

get_serializable_workflow(jit_wf, settings, registrable_entity_cache)
get_serializable_workflow(wf, settings, registrable_entity_cache)

parameter_map = interface_to_parameters(jit_wf.python_interface)
parameter_map = interface_to_parameters(wf.python_interface)
lp = LaunchPlan(
name=jit_wf.name,
workflow=jit_wf,
name=wf.name,
workflow=wf,
parameters=parameter_map,
fixed_inputs=literals_models.LiteralMap(literals={}),
)
Expand Down
8 changes: 8 additions & 0 deletions latch_cli/extras/common/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import textwrap


# todo(maximsmol): use a stateful writer that keeps track of indent level
def reindent(x: str, level: int) -> str:
    """Dedent a text block and re-indent it to ``level`` indent units.

    A single leading newline is dropped first, so triple-quoted literals
    that start on the line after the opening quotes can be passed directly.

    Args:
        x: Text to re-indent. May be empty.
        level: Number of indent units to prefix each line with.

    Returns:
        The dedented text with every line that is not whitespace-only
        prefixed by ``level`` copies of the indent unit.
    """
    # ``startswith`` rather than ``x[0]``: indexing an empty string raises
    # IndexError, while an empty input should simply come back empty.
    if x.startswith("\n"):
        x = x[1:]

    # NOTE(review): the indent unit is reproduced exactly as it appears in
    # this view; whitespace inside the literal may have been collapsed when
    # the diff was rendered — confirm against the repository.
    return textwrap.indent(textwrap.dedent(x), " " * level)
48 changes: 2 additions & 46 deletions latch_cli/extras/nextflow/serialize.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
import click
from flytekit import LaunchPlan
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.models import literals as literals_models
from flytekit.tools.serialize_helpers import persist_registrable_entities

from latch_cli.extras.common.serialize import serialize
from latch_cli.extras.nextflow.workflow import NextflowWorkflow
from latch_cli.extras.snakemake.serialize import should_register_with_admin
from latch_cli.extras.snakemake.serialize_utils import (
EntityCache,
get_serializable_launch_plan,
get_serializable_workflow,
)
from latch_cli.extras.snakemake.workflow import interface_to_parameters


def serialize_nf(
Expand All @@ -20,36 +8,4 @@ def serialize_nf(
image_name: str,
dkr_repo: str,
):
image_name_no_version, version = image_name.split(":")
default_img = Image(
name=image_name,
fqn=f"{dkr_repo}/{image_name_no_version}",
tag=version,
)
settings = SerializationSettings(
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)

registrable_entity_cache: EntityCache = {}

get_serializable_workflow(nf_wf, settings, registrable_entity_cache)

parameter_map = interface_to_parameters(nf_wf.python_interface)
lp = LaunchPlan(
name=nf_wf.name,
workflow=nf_wf,
parameters=parameter_map,
fixed_inputs=literals_models.LiteralMap(literals={}),
)
admin_lp = get_serializable_launch_plan(lp, settings, registrable_entity_cache)

registrable_entities = [
x.to_flyte_idl()
for x in list(
filter(should_register_with_admin, list(registrable_entity_cache.values()))
)
+ [admin_lp]
]

click.secho("\nSerializing workflow entities", bold=True)
persist_registrable_entities(registrable_entities, output_dir)
serialize(nf_wf, output_dir, image_name, dkr_repo)
Loading

0 comments on commit a2670cb

Please sign in to comment.