Skip to content

Commit

Permalink
test stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
wild-endeavor committed Aug 31, 2023
1 parent c0b954f commit 1570071
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 109 deletions.
123 changes: 78 additions & 45 deletions flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@

from flyteidl.artifact import artifacts_pb2
from flyteidl.core.identifier_pb2 import (
ArtifactAlias,
ArtifactID,
ArtifactKey,
ArtifactQuery,
ArtifactTag,
Partitions,
TaskExecutionIdentifier,
WorkflowExecutionIdentifier,
)
from flyteidl.core.literals_pb2 import Literal
from flyteidl.core.types_pb2 import LiteralType

from flytekit.core.context_manager import FlyteContextManager
from flytekit.loggers import logger
from flytekit.models.literals import Literal
from flytekit.models.types import LiteralType

if typing.TYPE_CHECKING:
from flytekit.remote.remote import FlyteRemote
Expand All @@ -34,22 +38,23 @@ class Artifact(object):
Control creation parameters at task/workflow execution time ::
@task
def t1() -> Annotated[nn.Module, Artifact(name="my.artifact.name", tags={type: "validation"},
aliases={"version": "latest", "semver": "1.0.0"})]:
def t1() -> Annotated[nn.Module, Artifact(name="my.artifact.name",
tags=["latest", "1.0.0"])]:
...
"""

def __init__(
self,
project: Optional[str] = None,
domain: Optional[str] = None,
suffix: Optional[str] = None,
name: Optional[str] = None,
version: Optional[str] = None,
partitions: Optional[typing.Dict[str, str]] = None,
tags: Optional[typing.List[str]] = None,
python_val: Optional[typing.Any] = None,
python_type: Optional[typing.Type] = None,
literal: Optional[Literal] = None,
literal_type: Optional[LiteralType] = None,
aliases: Optional[typing.List[str]] = None,
short_description: Optional[str] = None,
long_description: Optional[str] = None,
source: Optional[typing.Union[WorkflowExecutionIdentifier, TaskExecutionIdentifier, str]] = None,
Expand All @@ -61,30 +66,29 @@ def __init__(
:param project:
:param domain:
:param suffix: The key portion of the key value store. We expect users to not be too concerned with this.
:param name: Name is special because it doesn't exist in the IDL. In the backend, the primary uniqueness
constraint is project/domain/key (aka suffix). But the suffix is often not user-friendly so expose a
name field instead that resolves to an Alias.
:param name: The name of the Artifact.
:param version: Version of the Artifact, typically the execution ID.
"""
self.project = project
self.domain = domain
self.name = name
self.suffix = suffix
self.version = version
self.partitions = partitions or None # Don't let users set empty partitions
self.python_val = python_val
self.python_type = python_type
self.literal = literal
self.literal_type = literal_type
self.aliases = aliases
self.tags = tags
self.short_description = short_description
self.long_description = long_description
self.source = source

def __str__(self):
return (
f"Artifact: project={self.project}, domain={self.domain}, suffix={self.suffix}\n"
f"Artifact: project={self.project}, domain={self.domain}, name={self.name}, version={self.version}\n"
f" name={self.name}\n"
f" aliases={self.aliases}\n"
f" partitions={self.partitions}\n"
f" tags={self.tags}\n"
f" literal_type={self.literal_type}, literal={self.literal})"
)

Expand All @@ -100,15 +104,17 @@ def artifact_id(self) -> Optional[ArtifactID]:
artifact_key=ArtifactKey(
project=self.project,
domain=self.domain,
suffix=self.suffix,
name=self.name,
),
version=self.version,
partitions=self.partitions,
)

@classmethod
def get(
cls,
uri: Optional[str],
artifact_id: Optional[artifacts_pb2.ArtifactID],
artifact_id: Optional[ArtifactID],
remote: FlyteRemote,
get_details: bool = False,
) -> Optional[Artifact]:
Expand All @@ -122,25 +128,42 @@ def get(

def as_query(self, project: Optional[str] = None, domain: Optional[str] = None) -> ArtifactQuery:
"""
model_artifact = Artifact(name="models.nn.lidar", alias=["latest"])
model_artifact = Artifact(name="models.nn.lidar", tags=["latest"])
@task
def t1() -> Annotated[nn.Module, model_artifact]: ...
@workflow
def wf(model: nn.Module = model_artifact.as_query()): ...
"""
# todo: add artifact by ID or key when added to IDL
if not self.name or not self.aliases:
raise ValueError(f"Cannot bind artifact {self} as query, name or aliases are missing")
if (not self.project and not project) or (not self.domain and not domain):
raise ValueError(f"Cannot bind artifact {self} as query, project or domain are missing")

# just use the first alias for now
return ArtifactQuery(
project=project or self.project,
domain=domain or self.domain,
alias=ArtifactAlias(name=self.name, value=self.aliases[0]),
# if (not self.project and not project) or (not self.domain and not domain):
# raise ValueError(f"Cannot bind artifact {self} as query, project or domain are missing")
if not self.name:
raise ValueError(f"Cannot bind artifact {self} as query, name is missing")
ak = ArtifactKey(
project=project or self.project or None,
domain=domain or self.domain or None,
name=self.name,
)
if self.tags:
# If tags are present, assume it's a query by tag
if len(self.tags) > 1:
logger.warning(f"Multiple tags specified: {self.tags}, only using the first one")
return ArtifactQuery(
artifact_tag=ArtifactTag(
artifact_key=ak,
value=self.tags[0],
)
)
else:
# Otherwise assume it's a query by ArtifactID. Keep in mind that not all fields need to be specified; if they
# all are, it's just a fully specified get request, not a search of any kind.
return ArtifactQuery(
artifact_id=ArtifactID(
artifact_key=ak,
version=self.version,
partitions=Partitions(value=self.partitions) if self.partitions else None,
)
)

def download(self):
"""
Expand Down Expand Up @@ -168,7 +191,7 @@ def initialize(
python_type: typing.Type,
name: Optional[str] = None,
literal_type: Optional[LiteralType] = None,
aliases: Optional[typing.List[str]] = None,
tags: Optional[typing.List[str]] = None,
) -> Artifact:
"""
Use this for when you have a Python value you want to get an Artifact object out of.
Expand All @@ -178,7 +201,7 @@ def initialize(
remote.create_artifact(Artifact.initialize(...))
Artifact.initialize("/path/to/file", tags={"tag1": "val1"})
Artifact.initialize("/path/to/parquet", type=pd.DataFrame, aliases={"ver": "0.1.0"})
Artifact.initialize("/path/to/parquet", type=pd.DataFrame, tags=["0.1.0"])
What's set here is everything that isn't set by the server. What is set by the server?
- name, version, if not set by user.
Expand All @@ -191,58 +214,68 @@ def initialize(
python_val=python_val,
python_type=python_type,
literal_type=literal_type,
aliases=aliases,
tags=tags,
name=name,
)

@property
def as_artifact_id(self) -> ArtifactID:
if self.name is None or self.project is None or self.domain is None or self.version is None:
raise ValueError("Cannot create artifact id without name, project, domain, version")
return self.to_flyte_idl().artifact_id

def to_flyte_idl(self) -> artifacts_pb2.Artifact:
"""
Converts this object to the IDL representation.
This is here instead of translator because it's in the interface, a relatively simple proto object
that's exposed to the user.
todo: where is this called besides in as_artifact_id?
"""
return artifacts_pb2.Artifact(
artifact_id=ArtifactID(
artifact_key=ArtifactKey(
project=self.project,
domain=self.domain,
suffix=self.suffix,
name=self.name,
),
version=self.version,
partitions=Partitions(value=self.partitions) if self.partitions else None,
),
spec=artifacts_pb2.ArtifactSpec(aliases=[ArtifactAlias(name=self.name, value=a) for a in self.aliases]),
spec=artifacts_pb2.ArtifactSpec(),
tags=self.tags,
)

def as_create_request(self) -> artifacts_pb2.CreateArtifactRequest:
if not self.project or not self.domain:
raise ValueError("Project and domain are required to create an artifact")
suffix = self.suffix or UUID(int=random.getrandbits(128)).hex
ak = ArtifactKey(project=self.project, domain=self.domain, suffix=suffix)
name = self.name or UUID(int=random.getrandbits(128)).hex
ak = ArtifactKey(project=self.project, domain=self.domain, name=name)

spec = artifacts_pb2.ArtifactSpec(
value=self.literal,
type=self.literal_type,
aliases=[ArtifactAlias(name=self.name, value=a) for a in self.aliases],
)
return artifacts_pb2.CreateArtifactRequest(artifact_key=ak, spec=spec)
partitions = self.partitions
tag = self.tags[0] if self.tags else None
return artifacts_pb2.CreateArtifactRequest(artifact_key=ak, spec=spec, partitions=partitions, tag=tag)

@classmethod
def from_flyte_idl(cls, pb2: artifacts_pb2.Artifact) -> Artifact:
"""
Converts the IDL representation to this object.
"""
from flytekit.models.literals import Literal
from flytekit.models.types import LiteralType

aliases = [a.value for a in pb2.spec.aliases] if len(pb2.spec.aliases) > 0 else None
alias_name = pb2.spec.aliases[0].name if len(pb2.spec.aliases) > 0 else None
tags = [t for t in pb2.tags] if pb2.tags else None
a = Artifact(
project=pb2.artifact_id.artifact_key.project,
domain=pb2.artifact_id.artifact_key.domain,
suffix=pb2.artifact_id.artifact_key.suffix,
name=alias_name,
aliases=aliases,
name=pb2.artifact_id.artifact_key.name,
version=pb2.artifact_id.version,
tags=tags,
literal_type=LiteralType.from_flyte_idl(pb2.spec.type),
literal=Literal.from_flyte_idl(pb2.spec.value),
# source=pb2.spec.source, # todo: source isn't installed in artifact service yet
)
if pb2.artifact_id.HasField("partitions"):
a.partitions = pb2.artifact_id.partitions.value if len(pb2.artifact_id.partitions.value) > 0 else None

return a
43 changes: 23 additions & 20 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,9 @@ def transform_inputs_to_parameters(
if isinstance(_default, identifier_pb2.ArtifactQuery):
params[k] = _interface_models.Parameter(var=v, required=False, artifact_query=_default)
elif isinstance(_default, Artifact):
# todo: move this and code in remote to Artifact. Add checks
ak = identifier_pb2.ArtifactKey(
project=_default.project, domain=_default.domain, suffix=_default.suffix
)
artifact_id = identifier_pb2.ArtifactID(artifact_key=ak)
artifact_id = _default.as_artifact_id
lit = Literal(artifact_id=artifact_id)
params[k] = _interface_models.Parameter(var=v, required=False, default=lit)
params[k] = _interface_models.Parameter(var=v, required=False) # fix this, placeholder
else:
required = _default is None
default_lv = None
Expand Down Expand Up @@ -359,25 +355,32 @@ def transform_variable_map(
return res


def detect_artifact(ts: typing.Tuple[typing.Any]) -> typing.List[identifier_pb2.ArtifactAlias]:
aliases = []
def detect_artifact(
ts: typing.Tuple[typing.Any],
) -> Tuple[Optional[identifier_pb2.ArtifactID], Optional[identifier_pb2.ArtifactTag]]:
"""
If the user wishes to control how Artifacts are created (i.e. naming them, etc.) this is where we pick it up and
store it in the interface. There are two fields, the ID and a tag. For this to take effect, the name field
must have been specified.
"""
for t in ts:
# TODO: Maybe make this an Alias object
if isinstance(t, Artifact):
if not t.name or len(t.aliases) == 0:
logger.info(f"Incorrect Artifact specified, skipping alias detection, {t}")
continue
for a in t.aliases:
if not isinstance(a, str):
logger.info(f"Aliases should be strings, skipping alias, {a}")
else:
aliases.append(identifier_pb2.ArtifactAlias(artifact_id=None, name=t.name, value=a))
return aliases
if isinstance(t, Artifact) and t.name:
if t.tags:
tag = identifier_pb2.ArtifactTag(value=t.tags[0])
else:
tag = None

artifact_id = t.to_flyte_idl().artifact_id

return artifact_id, tag

return None, None


def transform_type(x: type, description: Optional[str] = None) -> _interface_models.Variable:
artifact_id, tag = detect_artifact(get_args(x))
return _interface_models.Variable(
type=TypeEngine.to_literal_type(x), description=description, aliases=detect_artifact(get_args(x))
type=TypeEngine.to_literal_type(x), description=description, artifact_partial_id=artifact_id, artifact_tag=tag
)


Expand Down
Loading

0 comments on commit 1570071

Please sign in to comment.