diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 81c68e8d5b..38e7f34132 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -65,7 +65,6 @@ jobs: - flytekit-k8s-pod - flytekit-kf-pytorch - flytekit-kf-tensorflow - - flytekit-papermill - flytekit-spark - flytekit-sqlalchemy - flytekit-pandera diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 6153ed6fa9..19749f3e1b 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -37,11 +37,11 @@ from flytekit.interfaces import random as _flyte_random from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.interfaces.stats.taggable import get_stats as _get_stats -from flytekit.models import dynamic_job as _dynamic_job -from flytekit.models import literals as _literal_models +from flytekit.models.core import dynamic_job as _dynamic_job from flytekit.models.core import errors as _error_models from flytekit.models.core import execution as _execution_models from flytekit.models.core import identifier as _identifier +from flytekit.models.core import literals as _literal_models from flytekit.tools.fast_registration import download_distribution as _download_distribution from flytekit.tools.module_loader import load_object_from_module diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index a34bf31a48..ca121f9b03 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -13,16 +13,16 @@ from flyteidl.admin import workflow_pb2 as _workflow_pb2 from flytekit.clients.raw import RawSynchronousFlyteClient as _RawSynchronousFlyteClient -from flytekit.models import common as _common -from flytekit.models import execution as _execution from flytekit.models import filters as _filters -from flytekit.models import launch_plan as _launch_plan -from flytekit.models import node_execution as _node_execution -from flytekit.models import project as _project -from flytekit.models import task as _task from flytekit.models.admin import common as _admin_common +from flytekit.models.admin import execution as _execution +from flytekit.models.admin import launch_plan as _launch_plan +from flytekit.models.admin import node_execution as _node_execution +from flytekit.models.admin import project as _project +from flytekit.models.admin import task as _task from flytekit.models.admin import task_execution as _task_execution from flytekit.models.admin import workflow as _workflow +from flytekit.models.admin.common import NamedEntityIdentifier as _namedEntityIdentifier from flytekit.models.core import identifier as _identifier @@ -64,7 +64,7 @@ def create_task(self, task_identifer, task_spec): remains identical, calling this method multiple times will result in success. :param flytekit.models.core.identifier.Identifier task_identifer: The identifier for this task. - :param flytekit.models.task.TaskSpec task_spec: This is the actual definition of the task that + :param flytekit.models.admin.task.TaskSpec task_spec: This is the actual definition of the task that should be created. :raises flytekit.common.exceptions.user.FlyteEntityAlreadyExistsException: If an identical version of the task is found, this exception is raised. The client might choose to ignore this exception because the @@ -112,7 +112,7 @@ def list_task_ids_paginated(self, project, domain, limit=100, token=None, sort_b ) ) return ( - [_common.NamedEntityIdentifier.from_flyte_idl(identifier_pb) for identifier_pb in identifier_list.entities], + [_namedEntityIdentifier.from_flyte_idl(identifier_pb) for identifier_pb in identifier_list.entities], str(identifier_list.token), ) @@ -131,7 +131,7 @@ def list_tasks_paginated(self, identifier, limit=100, token=None, filters=None, If entries are added to the database between requests for different pages, it is possible to receive entries on the second page that also appeared on the first. - :param flytekit.models.common.NamedEntityIdentifier identifier: NamedEntityIdentifier to list. + :param flytekit.models.admin.common.NamedEntityIdentifier identifier: NamedEntityIdentifier to list. :param int limit: [Optional] The maximum number of entries to return. Must be greater than 0. The maximum page size is determined by the Flyte Admin Service configuration. If limit is greater than the maximum page size, an exception will be raised. @@ -167,7 +167,7 @@ def get_task(self, id): :param flytekit.models.core.identifier.Identifier id: The ID representing a given task. :raises: TODO - :rtype: flytekit.models.task.Task + :rtype: flytekit.models.admin.task.Task """ return _task.Task.from_flyte_idl( super(SynchronousFlyteClient, self).get_task(_common_pb2.ObjectGetRequest(id=id.to_flyte_idl())) @@ -241,7 +241,7 @@ def list_workflow_ids_paginated(self, project, domain, limit=100, token=None, so ) ) return ( - [_common.NamedEntityIdentifier.from_flyte_idl(identifier_pb) for identifier_pb in identifier_list.entities], + [_namedEntityIdentifier.from_flyte_idl(identifier_pb) for identifier_pb in identifier_list.entities], str(identifier_list.token), ) @@ -260,7 +260,7 @@ def list_workflows_paginated(self, identifier, limit=100, token=None, filters=No If entries are added to the database between requests for different pages, it is possible to receive entries on the second page that also appeared on the first. - :param flytekit.models.common.NamedEntityIdentifier identifier: NamedEntityIdentifier to list. + :param flytekit.models.admin.common.NamedEntityIdentifier identifier: NamedEntityIdentifier to list. :param int limit: [Optional] The maximum number of entries to return. Must be greater than 0. The maximum page size is determined by the Flyte Admin Service configuration. If limit is greater than the maximum page size, an exception will be raised. @@ -350,7 +350,7 @@ def get_active_launch_plan(self, identifier): Retrieves the active launch plan entity given a named entity identifier (project, domain, name). Raises an error if no active launch plan exists. - :param flytekit.models.common.NamedEntityIdentifier identifier: NamedEntityIdentifier to list. + :param flytekit.models.admin.common.NamedEntityIdentifier identifier: NamedEntityIdentifier to list. :rtype: flytekit.models.launch_plan.LaunchPlan """ return _launch_plan.LaunchPlan.from_flyte_idl( @@ -396,7 +396,7 @@ def list_launch_plan_ids_paginated(self, project, domain, limit=100, token=None, ) ) return ( - [_common.NamedEntityIdentifier.from_flyte_idl(identifier_pb) for identifier_pb in identifier_list.entities], + [_namedEntityIdentifier.from_flyte_idl(identifier_pb) for identifier_pb in identifier_list.entities], str(identifier_list.token), ) @@ -415,7 +415,7 @@ def list_launch_plans_paginated(self, identifier, limit=100, token=None, filters If entries are added to the database between requests for different pages, it is possible to receive entries on the second page that also appeared on the first. - :param flytekit.models.common.NamedEntityIdentifier identifier: NamedEntityIdentifier to list. + :param flytekit.models.admin.common.NamedEntityIdentifier identifier: NamedEntityIdentifier to list. :param int limit: [Optional] The maximum number of entries to return. Must be greater than 0. The maximum page size is determined by the Flyte Admin Service configuration. If limit is greater than the maximum page size, an exception will be raised. diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index cfce244c0a..2e0bf67f83 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -19,6 +19,7 @@ from google.protobuf.json_format import MessageToJson from google.protobuf.pyext.cpp_message import GeneratedProtocolMessageType as _GeneratedProtocolMessageType +import flytekit.models.admin.common from flytekit import __version__ from flytekit.clients import friendly as _friendly_client from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map @@ -38,27 +39,25 @@ from flytekit.configuration import set_flyte_config_file from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.interfaces.data.data_proxy import Data -from flytekit.models import common as _common_models from flytekit.models import filters as _filters -from flytekit.models import launch_plan as _launch_plan -from flytekit.models import literals as _literals -from flytekit.models import named_entity as _named_entity from flytekit.models.admin import common as _admin_common -from flytekit.models.common import AuthRole as _AuthRole -from flytekit.models.common import RawOutputDataConfig as _RawOutputDataConfig +from flytekit.models.admin import launch_plan as _launch_plan +from flytekit.models.admin.common import AuthRole as _AuthRole +from flytekit.models.admin.common import RawOutputDataConfig as _RawOutputDataConfig +from flytekit.models.admin.execution import ExecutionMetadata as _ExecutionMetadata +from flytekit.models.admin.execution import ExecutionSpec as _ExecutionSpec +from flytekit.models.admin.matchable_resource import ClusterResourceAttributes as _ClusterResourceAttributes +from flytekit.models.admin.matchable_resource import ExecutionClusterLabel as _ExecutionClusterLabel +from flytekit.models.admin.matchable_resource import ExecutionQueueAttributes as _ExecutionQueueAttributes +from flytekit.models.admin.matchable_resource import MatchableResource as _MatchableResource +from flytekit.models.admin.matchable_resource import MatchingAttributes as _MatchingAttributes +from flytekit.models.admin.matchable_resource import PluginOverride as _PluginOverride +from flytekit.models.admin.matchable_resource import PluginOverrides as _PluginOverrides +from flytekit.models.admin.project import Project as _Project +from flytekit.models.admin.schedule import Schedule as _Schedule from flytekit.models.core import execution as _core_execution_models from flytekit.models.core import identifier as _core_identifier -from flytekit.models.execution import ExecutionMetadata as _ExecutionMetadata -from flytekit.models.execution import ExecutionSpec as _ExecutionSpec -from flytekit.models.matchable_resource import ClusterResourceAttributes as _ClusterResourceAttributes -from flytekit.models.matchable_resource import ExecutionClusterLabel as _ExecutionClusterLabel -from flytekit.models.matchable_resource import ExecutionQueueAttributes as _ExecutionQueueAttributes -from flytekit.models.matchable_resource import MatchableResource as _MatchableResource -from flytekit.models.matchable_resource import MatchingAttributes as _MatchingAttributes -from flytekit.models.matchable_resource import PluginOverride as _PluginOverride -from flytekit.models.matchable_resource import PluginOverrides as _PluginOverrides -from flytekit.models.project import Project as _Project -from flytekit.models.schedule import Schedule as _Schedule +from flytekit.models.core import literals as _literals from flytekit.tools.fast_registration import get_additional_distribution_loc as _get_additional_distribution_loc try: # Python 3 @@ -699,7 +698,7 @@ def list_task_versions(project, domain, name, host, insecure, token, limit, show _click.echo("{:50} {:40}".format("Version", "Urn")) while True: task_list, next_token = client.list_tasks_paginated( - _common_models.NamedEntityIdentifier(project, domain, name), + flytekit.models.admin.common.NamedEntityIdentifier(project, domain, name), limit=limit, token=token, filters=[_filters.Filter.from_python_std(f) for f in filter], @@ -854,7 +853,7 @@ def list_workflow_versions(project, domain, name, host, insecure, token, limit, _click.echo("{:50} {:40}".format("Version", "Urn")) while True: wf_list, next_token = client.list_workflows_paginated( - _common_models.NamedEntityIdentifier(project, domain, name), + flytekit.models.admin.common.NamedEntityIdentifier(project, domain, name), limit=limit, token=token, filters=[_filters.Filter.from_python_std(f) for f in filter], @@ -1032,7 +1031,7 @@ def list_launch_plan_versions( while True: lp_list, next_token = client.list_launch_plans_paginated( - _common_models.NamedEntityIdentifier(project, domain, name), + flytekit.models.admin.common.NamedEntityIdentifier(project, domain, name), limit=limit, token=token, filters=[_filters.Filter.from_python_std(f) for f in filter], @@ -1100,7 +1099,7 @@ def get_active_launch_plan(project, domain, name, host, insecure): _welcome_message() client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) - lp = client.get_active_launch_plan(_common_models.NamedEntityIdentifier(project, domain, name)) + lp = client.get_active_launch_plan(flytekit.models.admin.common.NamedEntityIdentifier(project, domain, name)) _click.echo("Active Launch Plan for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name))) _click.echo(lp) _click.echo("") @@ -2089,13 +2088,13 @@ def update_workflow_meta(description, state, host, insecure, project, domain, na _welcome_message() client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) if state == "active": - state = _named_entity.NamedEntityState.ACTIVE + state = flytekit.models.admin.common.NamedEntityState.ACTIVE elif state == "archived": - state = _named_entity.NamedEntityState.ARCHIVED + state = flytekit.models.admin.common.NamedEntityState.ARCHIVED client.update_named_entity( _core_identifier.ResourceType.WORKFLOW, - _named_entity.NamedEntityIdentifier(project, domain, name), - _named_entity.NamedEntityMetadata(description, state), + flytekit.models.admin.common.NamedEntityIdentifier(project, domain, name), + flytekit.models.admin.common.NamedEntityMetadata(description, state), ) _click.echo("Successfully updated workflow") @@ -2115,8 +2114,10 @@ def update_task_meta(description, host, insecure, project, domain, name): client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) client.update_named_entity( _core_identifier.ResourceType.TASK, - _named_entity.NamedEntityIdentifier(project, domain, name), - _named_entity.NamedEntityMetadata(description, _named_entity.NamedEntityState.ACTIVE), + flytekit.models.admin.common.NamedEntityIdentifier(project, domain, name), + flytekit.models.admin.common.NamedEntityMetadata( + description, flytekit.models.admin.common.NamedEntityState.ACTIVE + ), ) _click.echo("Successfully updated task") @@ -2136,8 +2137,10 @@ def update_launch_plan_meta(description, host, insecure, project, domain, name): client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) client.update_named_entity( _core_identifier.ResourceType.LAUNCH_PLAN, - _named_entity.NamedEntityIdentifier(project, domain, name), - _named_entity.NamedEntityMetadata(description, _named_entity.NamedEntityState.ACTIVE), + flytekit.models.admin.common.NamedEntityIdentifier(project, domain, name), + flytekit.models.admin.common.NamedEntityMetadata( + description, flytekit.models.admin.common.NamedEntityState.ACTIVE + ), ) _click.echo("Successfully updated launch plan") diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index c0ab7397c2..7c87a20519 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -8,7 +8,7 @@ from flytekit.clis.sdk_in_container.serialize import _DOMAIN_PLACEHOLDER, _PROJECT_PLACEHOLDER, _VERSION_PLACEHOLDER from flytekit.common.types.helpers import get_sdk_type_from_literal_type as _get_sdk_type_from_literal_type -from flytekit.models import literals as _literals +from flytekit.models.core import literals as _literals def construct_literal_map_from_variable_map(variable_dict, text_args): diff --git a/flytekit/clis/sdk_in_container/launch_plan.py b/flytekit/clis/sdk_in_container/launch_plan.py index e367fa08d8..df15bcb5eb 100644 --- a/flytekit/clis/sdk_in_container/launch_plan.py +++ b/flytekit/clis/sdk_in_container/launch_plan.py @@ -21,7 +21,7 @@ from flytekit.configuration.internal import PROJECT as _PROJECT from flytekit.configuration.internal import VERSION as _VERSION from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag -from flytekit.models import launch_plan as _launch_plan_model +from flytekit.models.admin import launch_plan as _launch_plan_model from flytekit.models.core import identifier as _identifier from flytekit.tools.module_loader import iterate_registerable_entities_in_order diff --git a/flytekit/clis/sdk_in_container/serialize.py b/flytekit/clis/sdk_in_container/serialize.py index f02526a5dd..c86e67dcbd 100644 --- a/flytekit/clis/sdk_in_container/serialize.py +++ b/flytekit/clis/sdk_in_container/serialize.py @@ -26,9 +26,9 @@ from flytekit.core.base_task import PythonTask from flytekit.core.launch_plan import LaunchPlan from flytekit.core.workflow import WorkflowBase -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import task as task_models +from flytekit.models.admin import launch_plan as _launch_plan_models from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.admin.task import TaskSpec as _taskSpec from flytekit.tools.fast_registration import compute_digest as _compute_digest from flytekit.tools.fast_registration import filter_tar_file_fn as _filter_tar_file_fn from flytekit.tools.module_loader import iterate_registerable_entities_in_order @@ -100,17 +100,15 @@ def _should_register_with_admin(entity) -> bool: This is used in the code below. The translator.py module produces lots of objects (namely nodes and BranchNodes) that do not/should not be written to .pb file to send to admin. This function filters them out. """ - return isinstance( - entity, (task_models.TaskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec) - ) + return isinstance(entity, (_taskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec)) -def _find_duplicate_tasks(tasks: typing.List[task_models.TaskSpec]) -> typing.Set[task_models.TaskSpec]: +def _find_duplicate_tasks(tasks: typing.List[_taskSpec]) -> typing.Set[_taskSpec]: """ Given a list of `TaskSpec`, this function returns a set containing the duplicated `TaskSpec` if any exists. """ seen: typing.Set[_identifier.Identifier] = set() - duplicate_tasks: typing.Set[task_models.TaskSpec] = set() + duplicate_tasks: typing.Set[_taskSpec] = set() for task in tasks: if task.template.id not in seen: seen.add(task.template.id) @@ -137,8 +135,8 @@ def get_registrable_entities(ctx: flyte_context.FlyteContext) -> typing.List: new_api_model_values = list(new_api_serializable_entities.values()) entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values)) - serializable_tasks: typing.List[task_models.TaskSpec] = [ - entity for entity in entities_to_be_serialized if isinstance(entity, task_models.TaskSpec) + serializable_tasks: typing.List[_taskSpec] = [ + entity for entity in entities_to_be_serialized if isinstance(entity, _taskSpec) ] # Detect if any of the tasks is duplicated. Duplicate tasks are defined as having the same metadata identifiers # (see :py:class:`flytekit.common.core.identifier.Identifier`). Duplicate tasks are considered invalid at registration diff --git a/flytekit/common/component_nodes.py b/flytekit/common/component_nodes.py index ea39a28dea..1f978d09ea 100644 --- a/flytekit/common/component_nodes.py +++ b/flytekit/common/component_nodes.py @@ -35,7 +35,7 @@ def promote_from_model(cls, base_model, tasks): engine. :param flytekit.models.core.workflow.TaskNode base_model: - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks: + :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.core.task.TaskTemplate] tasks: :rtype: SdkTaskNode """ from flytekit.common.tasks import task as _task @@ -121,7 +121,7 @@ def promote_from_model(cls, base_model, sub_workflows, tasks): :param flytekit.models.core.workflow.WorkflowNode base_model: :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.core.workflow.WorkflowTemplate] sub_workflows: - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks: + :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.core.task.TaskTemplate] tasks: :rtype: SdkWorkflowNode """ # put the import statement here to prevent circular dependency error diff --git a/flytekit/common/interface.py b/flytekit/common/interface.py index c8f5160e15..14b5389af7 100644 --- a/flytekit/common/interface.py +++ b/flytekit/common/interface.py @@ -6,8 +6,8 @@ from flytekit.common.types import containers as _containers from flytekit.common.types import helpers as _type_helpers from flytekit.common.types import primitives as _primitives -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models class BindingData(_literal_models.BindingData, metaclass=_sdk_bases.ExtendedSdkType): @@ -40,7 +40,7 @@ def promote_from_model(cls, model): @classmethod def from_python_std(cls, literal_type, t_value, upstream_nodes=None): """ - :param flytekit.models.types.LiteralType literal_type: + :param flytekit.models.core.types.LiteralType literal_type: :param T t_value: :param list[flytekit.common.nodes.SdkNode] upstream_nodes: [Optional] Keeps track of the nodes upstream, if applicable. diff --git a/flytekit/common/launch_plan.py b/flytekit/common/launch_plan.py index adb9ff979d..a75d41d544 100644 --- a/flytekit/common/launch_plan.py +++ b/flytekit/common/launch_plan.py @@ -5,6 +5,8 @@ import six as _six from deprecated import deprecated as _deprecated +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan from flytekit.common import interface as _interface from flytekit.common import nodes as _nodes from flytekit.common import promise as _promises @@ -20,13 +22,14 @@ from flytekit.configuration import auth as _auth_config from flytekit.configuration import sdk as _sdk_config from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.models import common as _common_models -from flytekit.models import execution as _execution_models -from flytekit.models import interface as _interface_models -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import literals as _literal_models -from flytekit.models import schedule as _schedule_model +from flytekit.models.admin import common as _common +from flytekit.models.admin import execution as _execution_models +from flytekit.models.admin import launch_plan as _launch_plan_models +from flytekit.models.admin import schedule as _schedule_model +from flytekit.models.admin.common import NamedEntityIdentifier as _namedEntityIdentifier from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_models @@ -111,9 +114,7 @@ def fetch(cls, project, domain, name, version=None): if launch_plan_id.version: lp = _flyte_engine.get_client().get_launch_plan(launch_plan_id) else: - named_entity_id = _common_models.NamedEntityIdentifier( - launch_plan_id.project, launch_plan_id.domain, launch_plan_id.name - ) + named_entity_id = _namedEntityIdentifier(launch_plan_id.project, launch_plan_id.domain, launch_plan_id.name) lp = _flyte_engine.get_client().get_active_launch_plan(named_entity_id) sdk_lp = cls.promote_from_model(lp.spec) @@ -168,7 +169,7 @@ def is_scheduled(self): @property def auth_role(self): """ - :rtype: flytekit.models.common.AuthRole + :rtype: flytekit.models.admin.common.AuthRole """ fixed_auth = super(SdkLaunchPlan, self).auth_role if fixed_auth is not None and ( @@ -184,7 +185,7 @@ def auth_role(self): "Using deprecated `role` from config. Please update your config to use `assumable_iam_role` instead" ) assumable_iam_role = _sdk_config.ROLE.get() - return _common_models.AuthRole( + return flytekit.models.admin.common.AuthRole( assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, ) @@ -224,14 +225,14 @@ def entity_type_text(self): @property def raw_output_data_config(self): """ - :rtype: flytekit.models.common.RawOutputDataConfig + :rtype: flytekit.models.admin.common.RawOutputDataConfig """ raw_output_data_config = super(SdkLaunchPlan, self).raw_output_data_config if raw_output_data_config is not None and raw_output_data_config.output_location_prefix != "": return raw_output_data_config # If it was not set explicitly then let's use the value found in the configuration. - return _common_models.RawOutputDataConfig(_auth_config.RAW_OUTPUT_DATA_PREFIX.get()) + return _common.RawOutputDataConfig(_auth_config.RAW_OUTPUT_DATA_PREFIX.get()) @_exception_scopes.system_entry_point def validate(self): @@ -309,10 +310,10 @@ def launch_with_literals( :param list[flytekit.common.notifications.Notification] notification_overrides: [Optional] If specified, these are the notifications that will be honored for this execution. An empty list signals to disable all notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: + :param flytekit.models.admin.common.Labels label_overrides: + :param flytekit.models.admin.common.Annotations annotation_overrides: :rtype: flytekit.common.workflow_execution.SdkWorkflowExecution - :param flytekit.models.common.AuthRole auth_role: + :param flytekit.models.admin.common.AuthRole auth_role: """ # Kubernetes requires names starting with an alphabet for some resources. name = name or "f" + _uuid.uuid4().hex[:19] @@ -407,13 +408,13 @@ def __init__( :param Text role: Deprecated. IAM role to execute this launch plan with. :param flytekit.models.schedule.Schedule: Schedule to apply to this workflow. :param list[flytekit.models.common.Notification]: List of notifications to apply to this launch plan. - :param flytekit.models.common.Labels labels: Any custom kubernetes labels to apply to workflows executed by this + :param flytekit.models.admin.common.Labels labels: Any custom kubernetes labels to apply to workflows executed by this launch plan. - :param flytekit.models.common.Annotations annotations: Any custom kubernetes annotations to apply to workflows + :param flytekit.models.admin.common.Annotations annotations: Any custom kubernetes annotations to apply to workflows executed by this launch plan. Any custom kubernetes annotations to apply to workflows executed by this launch plan. - :param flytekit.models.common.Authrole auth_role: The auth method with which to execute the workflow. - :param flytekit.models.common.RawOutputDataConfig raw_output_data_config: Config for offloading data + :param flytekit.models.admin.common.Authrole auth_role: The auth method with which to execute the workflow. + :param flytekit.models.admin.common.RawOutputDataConfig raw_output_data_config: Config for offloading data """ if role and auth_role: raise ValueError("Cannot set both role and auth. Role is deprecated, use auth instead.") @@ -422,7 +423,7 @@ def __init__( default_inputs = default_inputs or {} if role: - auth_role = _common_models.AuthRole(assumable_iam_role=role) + auth_role = flytekit.models.admin.common.AuthRole(assumable_iam_role=role) # The constructor for SdkLaunchPlan sets the id to None anyways so we don't bother passing in an ID. The ID # should be set in one of three places, @@ -445,10 +446,10 @@ def __init__( if k in fixed_inputs }, ), - labels or _common_models.Labels({}), - annotations or _common_models.Annotations({}), + labels or _common.Labels({}), + annotations or _common.Annotations({}), auth_role, - raw_output_data_config or _common_models.RawOutputDataConfig(""), + raw_output_data_config or _common.RawOutputDataConfig(""), ) self._interface = _interface.TypedInterface( {k: v.var for k, v in _six.iteritems(default_inputs)}, diff --git a/flytekit/common/local_workflow.py b/flytekit/common/local_workflow.py index eb2067578f..059533488e 100644 --- a/flytekit/common/local_workflow.py +++ b/flytekit/common/local_workflow.py @@ -4,6 +4,8 @@ import six as _six from six.moves import queue as _queue +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan from flytekit.common import interface as _interface from flytekit.common import launch_plan as _launch_plan from flytekit.common import nodes as _nodes @@ -13,11 +15,11 @@ from flytekit.common.types import helpers as _type_helpers from flytekit.common.workflow import SdkWorkflow from flytekit.configuration import internal as _internal_config -from flytekit.models import common as _common_models -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models -from flytekit.models import schedule as _schedule_models +from flytekit.models.admin import common as _admin_common +from flytekit.models.admin import schedule as _schedule_models from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_models @@ -246,9 +248,9 @@ def create_launch_plan( fixed_inputs: Dict[str, Any] = None, schedule: _schedule_models.Schedule = None, role: str = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, + notifications: List[_admin_common.Notification] = None, + labels: _admin_common.Labels = None, + annotations: _admin_common.Annotations = None, assumable_iam_role: str = None, kubernetes_service_account: str = None, raw_output_data_prefix: str = None, @@ -259,10 +261,10 @@ def create_launch_plan( :param dict[Text,T] fixed_inputs: :param flytekit.models.schedule.Schedule schedule: A schedule on which to execute this launch plan. :param Text role: Deprecated. Use assumable_iam_role instead. - :param list[flytekit.models.common.Notification] notifications: A list of notifications to enact by default for + :param list[flytekit.models.admin.common.Notification] notifications: A list of notifications to enact by default for this launch plan. - :param flytekit.models.common.Labels labels: - :param flytekit.models.common.Annotations annotations: + :param flytekit.models.admin.common.Labels labels: + :param flytekit.models.admin.common.Annotations annotations: :param cls: This parameter can be used by users to define an extension of a launch plan to instantiate. The class provided should be a subclass of flytekit.common.launch_plan.SdkLaunchPlan. :param Text assumable_iam_role: The IAM role to execute the workflow with. @@ -279,12 +281,12 @@ class provided should be a subclass of flytekit.common.launch_plan.SdkLaunchPlan if role: assumable_iam_role = role # For backwards compatibility - auth_role = _common_models.AuthRole( + auth_role = flytekit.models.admin.common.AuthRole( assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, ) - raw_output_config = _common_models.RawOutputDataConfig(raw_output_data_prefix or "") + raw_output_config = _admin_common.RawOutputDataConfig(raw_output_data_prefix or "") return _launch_plan.SdkRunnableLaunchPlan( sdk_workflow=self, diff --git a/flytekit/common/mixins/launchable.py b/flytekit/common/mixins/launchable.py index 110ba663af..2c42138e0c 100644 --- a/flytekit/common/mixins/launchable.py +++ b/flytekit/common/mixins/launchable.py @@ -28,9 +28,9 @@ def launch( :param list[flytekit.common.notifications.Notification] notification_overrides: [Optional] If specified, these are the notifications that will be honored for this execution. An empty list signals to disable all notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: + :param flytekit.models.admin.common.Labels label_overrides: + :param flytekit.models.admin.common.Annotations annotation_overrides: + :param flytekit.models.admin.common.AuthRole auth_role: :rtype: T """ @@ -97,9 +97,9 @@ def launch_with_literals( :param list[flytekit.common.notifications.Notification] notification_overrides: [Optional] If specified, these are the notifications that will be honored for this execution. An empty list signals to disable all notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: + :param flytekit.models.admin.common.Labels label_overrides: + :param flytekit.models.admin.common.Annotations annotation_overrides: + :param flytekit.models.admin.common.AuthRole auth_role: :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier: """ pass diff --git a/flytekit/common/nodes.py b/flytekit/common/nodes.py index fb68de2bc2..c57e983aa1 100644 --- a/flytekit/common/nodes.py +++ b/flytekit/common/nodes.py @@ -23,9 +23,9 @@ from flytekit.engines.flyte import engine as _flyte_engine from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.models import common as _common_models -from flytekit.models import literals as _literal_models -from flytekit.models import node_execution as _node_execution_models +from flytekit.models.admin import node_execution as _node_execution_models from flytekit.models.core import execution as _execution_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_model @@ -184,7 +184,7 @@ def promote_from_model(cls, model, sub_workflows, tasks): :param flytekit.models.core.workflow.Node model: :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.core.workflow.WorkflowTemplate] sub_workflows: - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks: If specified, + :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.core.task.TaskTemplate] tasks: If specified, these task templates will be passed to the SdkTaskNode promote_from_model call, and used instead of fetching from Admin. :rtype: SdkNode diff --git a/flytekit/common/notifications.py b/flytekit/common/notifications.py index 09b1d11358..c4833ff7e4 100644 --- a/flytekit/common/notifications.py +++ b/flytekit/common/notifications.py @@ -1,6 +1,6 @@ from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import common as _common_model +from flytekit.models.admin import common as _common_model from flytekit.models.core import execution as _execution_model @@ -59,7 +59,7 @@ def __init__(self, phases, recipients_email): @classmethod def promote_from_model(cls, base_model): """ - :param flytekit.models.common.Notification base_model: + :param flytekit.models.admin.common.Notification base_model: :rtype: Notification """ return cls(base_model.phases, base_model.pager_duty.recipients_email) @@ -77,7 +77,7 @@ def __init__(self, phases, recipients_email): @classmethod def promote_from_model(cls, base_model): """ - :param flytekit.models.common.Notification base_model: + :param flytekit.models.admin.common.Notification base_model: :rtype: Notification """ return cls(base_model.phases, base_model.email.recipients_email) @@ -95,7 +95,7 @@ def __init__(self, phases, recipients_email): @classmethod def promote_from_model(cls, base_model): """ - :param flytekit.models.common.Notification base_model: + :param flytekit.models.admin.common.Notification base_model: :rtype: Notification """ return cls(base_model.phases, base_model.slack.recipients_email) diff --git a/flytekit/common/promise.py b/flytekit/common/promise.py index 79352dd638..052089eb5c 100644 --- a/flytekit/common/promise.py +++ b/flytekit/common/promise.py @@ -1,9 +1,9 @@ +import flytekit.models.core.types from flytekit.common import constants as _constants from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import helpers as _type_helpers -from flytekit.models import interface as _interface_models -from flytekit.models import types as _type_models +from flytekit.models.core import interface as _interface_models class Input(_interface_models.Parameter, metaclass=_sdk_bases.ExtendedSdkType): @@ -42,7 +42,7 @@ def __init__(self, name, sdk_type, help=None, **kwargs): self._sdk_default = default self._help = help self._sdk_type = sdk_type - self._promise = _type_models.OutputReference(_constants.GLOBAL_INPUT_NODE_ID, name) + self._promise = flytekit.models.core.types.OutputReference(_constants.GLOBAL_INPUT_NODE_ID, name) self._name = name super(Input, self).__init__( _interface_models.Variable(type=sdk_type.to_flyte_literal_type(), description=help or ""), @@ -120,7 +120,7 @@ def promote_from_model(cls, model): return cls("", sdk_type, help=model.var.description, required=True) -class NodeOutput(_type_models.OutputReference, metaclass=_sdk_bases.ExtendedSdkType): +class NodeOutput(flytekit.models.core.types.OutputReference, metaclass=_sdk_bases.ExtendedSdkType): def __init__(self, sdk_node, sdk_type, var): """ :param sdk_node: diff --git a/flytekit/common/schedules.py b/flytekit/common/schedules.py index d6e71c6f83..176cc9534d 100644 --- a/flytekit/common/schedules.py +++ b/flytekit/common/schedules.py @@ -5,7 +5,7 @@ from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import schedule as _schedule_models +from flytekit.models.admin import schedule as _schedule_models class _ExtendedSchedule(_schedule_models.Schedule): diff --git a/flytekit/common/tasks/executions.py b/flytekit/common/tasks/executions.py index d87c558c09..b11ce3ef70 100644 --- a/flytekit/common/tasks/executions.py +++ b/flytekit/common/tasks/executions.py @@ -11,9 +11,9 @@ from flytekit.common.types import helpers as _type_helpers from flytekit.engines.flyte import engine as _flyte_engine from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import literals as _literal_models from flytekit.models.admin import task_execution as _task_execution_model from flytekit.models.core import execution as _execution_models +from flytekit.models.core import literals as _literal_models class SdkTaskExecution( diff --git a/flytekit/common/tasks/generic_spark_task.py b/flytekit/common/tasks/generic_spark_task.py index 95df1b9204..37a7b35ec3 100644 --- a/flytekit/common/tasks/generic_spark_task.py +++ b/flytekit/common/tasks/generic_spark_task.py @@ -3,6 +3,7 @@ import six as _six from google.protobuf.json_format import MessageToDict as _MessageToDict +import flytekit.models.core.task from flytekit import __version__ from flytekit.common import interface as _interface from flytekit.common.exceptions import scopes as _exception_scopes @@ -11,8 +12,10 @@ from flytekit.common.types import helpers as _helpers from flytekit.common.types import primitives as _primitives from flytekit.configuration import internal as _internal_config -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models +from flytekit.models.core import literals as _literal_models +from flytekit.models.core.task import RuntimeMetadata as _runtimeMetadata +from flytekit.models.core.task import TaskMetadata as _taskMetadata +from flytekit.models.plugins import task as _task_models input_types_supported = { _primitives.Integer, @@ -75,10 +78,10 @@ def __init__( super(SdkGenericSparkTask, self).__init__( task_type, - _task_models.TaskMetadata( + _taskMetadata( discoverable, - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + _runtimeMetadata( + _runtimeMetadata.RuntimeType.FLYTE_SDK, __version__, "spark", ), @@ -139,11 +142,11 @@ def _get_container_definition( args.append("--{}".format(k)) args.append("{{{{.Inputs.{}}}}}".format(k)) - return _task_models.Container( + return flytekit.models.core.task.Container( image=_internal_config.IMAGE.get(), command=[], args=args, - resources=_task_models.Resources([], []), + resources=flytekit.models.core.task.Resources([], []), env=environment, config={}, architecture=architecture, diff --git a/flytekit/common/tasks/hive_task.py b/flytekit/common/tasks/hive_task.py index 6a7a6b7d48..c9af28eb39 100644 --- a/flytekit/common/tasks/hive_task.py +++ b/flytekit/common/tasks/hive_task.py @@ -13,11 +13,11 @@ from flytekit.common.tasks import sdk_runnable as _sdk_runnable from flytekit.common.tasks import task as _base_task from flytekit.common.types import helpers as _type_helpers -from flytekit.models import dynamic_job as _dynamic_job -from flytekit.models import interface as _interface_model -from flytekit.models import literals as _literal_models -from flytekit.models import qubole as _qubole +from flytekit.models.core import dynamic_job as _dynamic_job +from flytekit.models.core import interface as _interface_model +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_model +from flytekit.models.plugins import qubole as _qubole ALLOWED_TAGS_COUNT = int(6) MAX_TAG_LENGTH = int(20) @@ -259,7 +259,7 @@ def _create_hive_job_node(name, hive_job, metadata): """ :param Text name: :param _qubole.QuboleHiveJob hive_job: Hive job spec - :param flytekit.models.task.TaskMetadata metadata: This contains information needed at runtime to determine + :param flytekit.models.core.task.TaskMetadata metadata: This contains information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, and retries. :rtype: _nodes.SdkNode: """ diff --git a/flytekit/common/tasks/presto_task.py b/flytekit/common/tasks/presto_task.py index 47e198494d..7c6e240151 100644 --- a/flytekit/common/tasks/presto_task.py +++ b/flytekit/common/tasks/presto_task.py @@ -3,17 +3,18 @@ import six as _six from google.protobuf.json_format import MessageToDict as _MessageToDict +import flytekit.models.core.task +import flytekit.models.core.types from flytekit import __version__ from flytekit.common import constants as _constants from flytekit.common import interface as _interface from flytekit.common.exceptions import scopes as _exception_scopes from flytekit.common.tasks import task as _base_task from flytekit.common.types import helpers as _type_helpers -from flytekit.models import interface as _interface_model -from flytekit.models import literals as _literals -from flytekit.models import presto as _presto_models -from flytekit.models import task as _task_model -from flytekit.models import types as _types +from flytekit.models.core import interface as _interface_model +from flytekit.models.core import literals as _literals +from flytekit.models.core import task as _task_model +from flytekit.models.plugins import presto as _presto_models class SdkPrestoTask(_base_task.SdkTask): @@ -60,7 +61,9 @@ def __init__( metadata = _task_model.TaskMetadata( discoverable, # This needs to have the proper version reflected in it - _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"), + flytekit.models.core.task.RuntimeMetadata( + flytekit.models.core.task.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python" + ), timeout or _datetime.timedelta(seconds=0), _literals.RetryStrategy(retries), interruptible, @@ -80,22 +83,22 @@ def __init__( i = _interface.TypedInterface( { "__implicit_routing_group": _interface_model.Variable( - type=_types.LiteralType(simple=_types.SimpleType.STRING), + type=flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.STRING), description="The routing group set as an implicit input", ), "__implicit_catalog": _interface_model.Variable( - type=_types.LiteralType(simple=_types.SimpleType.STRING), + type=flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.STRING), description="The catalog set as an implicit input", ), "__implicit_schema": _interface_model.Variable( - type=_types.LiteralType(simple=_types.SimpleType.STRING), + type=flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.STRING), description="The schema set as an implicit input", ), }, { # Set the schema for the Presto query as an output "results": _interface_model.Variable( - type=_types.LiteralType(schema=output_schema.schema_type), + type=flytekit.models.core.types.LiteralType(schema=output_schema.schema_type), description="The schema for the Presto query", ) }, diff --git a/flytekit/common/tasks/pytorch_task.py b/flytekit/common/tasks/pytorch_task.py index b88ecc7121..4a688f850d 100644 --- a/flytekit/common/tasks/pytorch_task.py +++ b/flytekit/common/tasks/pytorch_task.py @@ -1,7 +1,7 @@ from google.protobuf.json_format import MessageToDict as _MessageToDict from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import task as _task_models +from flytekit.models.plugins import task as _task_models class SdkRunnablePytorchContainer(_sdk_runnable.SdkRunnableContainer): diff --git a/flytekit/common/tasks/raw_container.py b/flytekit/common/tasks/raw_container.py index c290b36205..58108bbb8b 100644 --- a/flytekit/common/tasks/raw_container.py +++ b/flytekit/common/tasks/raw_container.py @@ -2,15 +2,17 @@ from typing import Dict, List import flytekit +import flytekit.models.core.task from flytekit.common import constants as _constants from flytekit.common import interface as _interface from flytekit.common.exceptions import scopes as _exception_scopes from flytekit.common.tasks import task as _base_task from flytekit.common.types.base_sdk_types import FlyteSdkType from flytekit.configuration import resources as _resource_config -from flytekit.models import literals as _literals -from flytekit.models import task as _task_models -from flytekit.models.interface import Variable +from flytekit.models.core import literals as _literals +from flytekit.models.core.interface import Variable +from flytekit.models.core.task import RuntimeMetadata as _runtimeMetadata +from flytekit.models.core.task import TaskMetadata as _taskMetadata def types_to_variable(t: Dict[str, FlyteSdkType]) -> Dict[str, Variable]: @@ -25,7 +27,7 @@ def _get_container_definition( image: str, command: List[str], args: List[str], - data_loading_config: _task_models.DataLoadingConfig, + data_loading_config: flytekit.models.core.task.DataLoadingConfig, storage_request: str = None, ephemeral_storage_request: str = None, cpu_request: str = None, @@ -37,7 +39,7 @@ def _get_container_definition( gpu_limit: str = None, memory_limit: str = None, environment: Dict[str, str] = None, -) -> _task_models.Container: +) -> flytekit.models.core.task.Container: storage_limit = storage_limit or _resource_config.DEFAULT_STORAGE_LIMIT.get() storage_request = storage_request or _resource_config.DEFAULT_STORAGE_REQUEST.get() ephemeral_storage_limit = ephemeral_storage_limit or _resource_config.DEFAULT_EPHEMERAL_STORAGE_LIMIT.get() @@ -52,47 +54,75 @@ def _get_container_definition( requests = [] if storage_request: requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.STORAGE, storage_request + ) ) if ephemeral_storage_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request ) ) if cpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) + requests.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.CPU, cpu_request + ) + ) if gpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) + requests.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.GPU, gpu_request + ) + ) if memory_request: requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.MEMORY, memory_request + ) ) limits = [] if storage_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit)) + limits.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.STORAGE, storage_limit + ) + ) if ephemeral_storage_limit: limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit ) ) if cpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) + limits.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.CPU, cpu_limit + ) + ) if gpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) + limits.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.GPU, gpu_limit + ) + ) if memory_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit)) + limits.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.MEMORY, memory_limit + ) + ) if environment is None: environment = {} - return _task_models.Container( + return flytekit.models.core.task.Container( image=image, command=command, args=args, - resources=_task_models.Resources(limits=limits, requests=requests), + resources=flytekit.models.core.task.Resources(limits=limits, requests=requests), env=environment, config={}, data_loading_config=data_loading_config, @@ -105,9 +135,9 @@ class SdkRawContainerTask(_base_task.SdkTask): separately as a container completely separate from the container where your Flyte workflow is defined. """ - METADATA_FORMAT_JSON = _task_models.DataLoadingConfig.LITERALMAP_FORMAT_JSON - METADATA_FORMAT_YAML = _task_models.DataLoadingConfig.LITERALMAP_FORMAT_YAML - METADATA_FORMAT_PROTO = _task_models.DataLoadingConfig.LITERALMAP_FORMAT_PROTO + METADATA_FORMAT_JSON = flytekit.models.core.task.DataLoadingConfig.LITERALMAP_FORMAT_JSON + METADATA_FORMAT_YAML = flytekit.models.core.task.DataLoadingConfig.LITERALMAP_FORMAT_YAML + METADATA_FORMAT_PROTO = flytekit.models.core.task.DataLoadingConfig.LITERALMAP_FORMAT_PROTO def __init__( self, @@ -117,7 +147,7 @@ def __init__( input_data_dir: str = None, output_data_dir: str = None, metadata_format: int = METADATA_FORMAT_JSON, - io_strategy: _task_models.IOStrategy = None, + io_strategy: flytekit.models.core.task.IOStrategy = None, command: List[str] = None, args: List[str] = None, storage_request: str = None, @@ -162,7 +192,7 @@ def __init__( # Set as class fields which are used down below to configure implicit # parameters - self._data_loading_config = _task_models.DataLoadingConfig( + self._data_loading_config = flytekit.models.core.task.DataLoadingConfig( input_path=input_data_dir, output_path=output_data_dir, format=metadata_format, @@ -170,11 +200,11 @@ def __init__( io_strategy=io_strategy, ) - metadata = _task_models.TaskMetadata( + metadata = _taskMetadata( discoverable, # This needs to have the proper version reflected in it - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + _runtimeMetadata( + _runtimeMetadata.RuntimeType.FLYTE_SDK, flytekit.__version__, "python", ), diff --git a/flytekit/common/tasks/sagemaker/built_in_training_job_task.py b/flytekit/common/tasks/sagemaker/built_in_training_job_task.py index 356f933c3e..82c30d0b9a 100644 --- a/flytekit/common/tasks/sagemaker/built_in_training_job_task.py +++ b/flytekit/common/tasks/sagemaker/built_in_training_job_task.py @@ -2,15 +2,15 @@ from google.protobuf.json_format import MessageToDict +import flytekit.models.core.task +import flytekit.models.core.types from flytekit import __version__ from flytekit.common import interface as _interface from flytekit.common.constants import SdkTaskType from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import task as _sdk_task -from flytekit.models import interface as _interface_model -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models -from flytekit.models import types as _idl_types +from flytekit.models.core import interface as _interface_model +from flytekit.models.core import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.sagemaker import training_job as _training_job_models @@ -51,9 +51,9 @@ def __init__( super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__( type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK, - metadata=_task_models.TaskMetadata( - runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + metadata=flytekit.models.core.task.TaskMetadata( + runtime=flytekit.models.core.task.RuntimeMetadata( + type=flytekit.models.core.task.RuntimeMetadata.RuntimeType.FLYTE_SDK, version=__version__, flavor="sagemaker", ), @@ -67,11 +67,13 @@ def __init__( interface=_interface.TypedInterface( inputs={ "static_hyperparameters": _interface_model.Variable( - type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), + type=flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.STRUCT + ), description="", ), "train": _interface_model.Variable( - type=_idl_types.LiteralType( + type=flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format=_content_type_to_blob_format(algorithm_specification.input_content_type), dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, @@ -80,7 +82,7 @@ def __init__( description="", ), "validation": _interface_model.Variable( - type=_idl_types.LiteralType( + type=flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format=_content_type_to_blob_format(algorithm_specification.input_content_type), dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, @@ -91,7 +93,7 @@ def __init__( }, outputs={ "model": _interface_model.Variable( - type=_idl_types.LiteralType( + type=flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, diff --git a/flytekit/common/tasks/sagemaker/hpo_job_task.py b/flytekit/common/tasks/sagemaker/hpo_job_task.py index fd4d0a3ce1..d57b0bedcd 100644 --- a/flytekit/common/tasks/sagemaker/hpo_job_task.py +++ b/flytekit/common/tasks/sagemaker/hpo_job_task.py @@ -3,6 +3,8 @@ from google.protobuf.json_format import MessageToDict +import flytekit.models.core.task +import flytekit.models.core.types from flytekit import __version__ from flytekit.common import interface as _interface from flytekit.common.constants import SdkTaskType @@ -10,10 +12,8 @@ from flytekit.common.tasks.sagemaker.built_in_training_job_task import SdkBuiltinAlgorithmTrainingJobTask from flytekit.common.tasks.sagemaker.custom_training_job_task import CustomTrainingJobTask from flytekit.common.tasks.sagemaker.types import HyperparameterTuningJobConfig, ParameterRange -from flytekit.models import interface as _interface_model -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models -from flytekit.models import types as _types_models +from flytekit.models.core import interface as _interface_model +from flytekit.models.core import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.sagemaker import hpo_job as _hpo_job_model @@ -79,9 +79,9 @@ def __init__( super().__init__( type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK, - metadata=_task_models.TaskMetadata( - runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + metadata=flytekit.models.core.task.TaskMetadata( + runtime=flytekit.models.core.task.RuntimeMetadata( + type=flytekit.models.core.task.RuntimeMetadata.RuntimeType.FLYTE_SDK, version=__version__, flavor="sagemaker", ), @@ -96,7 +96,7 @@ def __init__( inputs=inputs, outputs={ "model": _interface_model.Variable( - type=_types_models.LiteralType( + type=flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, diff --git a/flytekit/common/tasks/sdk_dynamic.py b/flytekit/common/tasks/sdk_dynamic.py index 2e64051488..18f15553a9 100644 --- a/flytekit/common/tasks/sdk_dynamic.py +++ b/flytekit/common/tasks/sdk_dynamic.py @@ -18,9 +18,9 @@ from flytekit.common.types import helpers as _type_helpers from flytekit.common.utils import _dnsify from flytekit.configuration import internal as _internal_config -from flytekit.models import array_job as _array_job -from flytekit.models import dynamic_job as _dynamic_job -from flytekit.models import literals as _literal_models +from flytekit.models.core import dynamic_job as _dynamic_job +from flytekit.models.core import literals as _literal_models +from flytekit.models.plugins import array_job as _array_job class PromiseOutputReference(_task_output.OutputReference): diff --git a/flytekit/common/tasks/sdk_runnable.py b/flytekit/common/tasks/sdk_runnable.py index 0ca5e3db1a..e2eb151b59 100644 --- a/flytekit/common/tasks/sdk_runnable.py +++ b/flytekit/common/tasks/sdk_runnable.py @@ -12,6 +12,7 @@ import six as _six +import flytekit.models.core.task from flytekit.common import constants as _constants from flytekit.common import interface as _interface from flytekit.common import sdk_bases as _sdk_bases @@ -28,8 +29,9 @@ from flytekit.configuration import secrets from flytekit.engines import loader as _engine_loader from flytekit.interfaces.stats import taggable -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models +from flytekit.models.core import literals as _literal_models +from flytekit.models.core.task import RuntimeMetadata as _runtimeMetadata +from flytekit.models.core.task import TaskMetadata as _taskMetadata class SecretsManager(object): @@ -243,7 +245,7 @@ def get(self, key: str) -> typing.Any: return self.__getattr__(attr_name=key) -class SdkRunnableContainer(_task_models.Container, metaclass=_sdk_bases.ExtendedSdkType): +class SdkRunnableContainer(flytekit.models.core.task.Container, metaclass=_sdk_bases.ExtendedSdkType): """ This is not necessarily a local-only Container object. So long as configuration is present, you can use this object """ @@ -324,32 +326,56 @@ def get_resources( requests = [] if storage_request: requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.STORAGE, storage_request + ) ) if cpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) + requests.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.CPU, cpu_request + ) + ) if gpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) + requests.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.GPU, gpu_request + ) + ) if memory_request: requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.MEMORY, memory_request + ) ) limits = [] if storage_limit: limits.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit) + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.STORAGE, storage_limit + ) ) if cpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) + limits.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.CPU, cpu_limit + ) + ) if gpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) + limits.append( + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.GPU, gpu_limit + ) + ) if memory_limit: limits.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit) + flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.MEMORY, memory_limit + ) ) - return _task_models.Resources(limits=limits, requests=requests) + return flytekit.models.core.task.Resources(limits=limits, requests=requests) class SdkRunnableTaskStyle(enum.Enum): @@ -414,10 +440,10 @@ def __init__( self._task_function = task_function super(SdkRunnableTask, self).__init__( task_type, - _task_models.TaskMetadata( + _taskMetadata( discoverable, - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + _runtimeMetadata( + _runtimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python", ), diff --git a/flytekit/common/tasks/sidecar_task.py b/flytekit/common/tasks/sidecar_task.py index d79bf42e93..7311e4a123 100644 --- a/flytekit/common/tasks/sidecar_task.py +++ b/flytekit/common/tasks/sidecar_task.py @@ -2,11 +2,11 @@ from flyteidl.core import tasks_pb2 as _core_task from google.protobuf.json_format import MessageToDict as _MessageToDict +import flytekit.models.core.task from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import task as _task_models from flytekit.plugins import k8s as _lazy_k8s @@ -98,7 +98,7 @@ def reconcile_partial_pod_spec_and_task(self, pod_spec, primary_container_name, primary_exists = True break if not primary_exists: - containers.extend([_lazy_k8s.io.api.core.v1.generated_pb2.Container(name=primary_container_name)]) + containers.extend([flytekit.models.core.task.Container(name=primary_container_name)]) final_containers = [] for container in containers: @@ -139,7 +139,7 @@ def reconcile_partial_pod_spec_and_task(self, pod_spec, primary_container_name, del pod_spec.containers[:] pod_spec.containers.extend(final_containers) - sidecar_job_plugin = _task_models.SidecarJob( + sidecar_job_plugin = flytekit.models.core.task.SidecarJob( pod_spec=pod_spec, primary_container_name=primary_container_name, annotations=annotations, diff --git a/flytekit/common/tasks/spark_task.py b/flytekit/common/tasks/spark_task.py index 4b994286c1..68e522dc62 100644 --- a/flytekit/common/tasks/spark_task.py +++ b/flytekit/common/tasks/spark_task.py @@ -20,8 +20,8 @@ from flytekit.common.tasks import output as _task_output from flytekit.common.tasks import sdk_runnable as _sdk_runnable from flytekit.common.types import helpers as _type_helpers -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models +from flytekit.models.core import literals as _literal_models +from flytekit.models.plugins import task as _task_models from flytekit.plugins import pyspark as _pyspark diff --git a/flytekit/common/tasks/task.py b/flytekit/common/tasks/task.py index ba55399382..bf0118dade 100644 --- a/flytekit/common/tasks/task.py +++ b/flytekit/common/tasks/task.py @@ -7,6 +7,8 @@ from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan from flytekit.common import interface as _interfaces from flytekit.common import nodes as _nodes from flytekit.common import sdk_bases as _sdk_bases @@ -22,19 +24,20 @@ from flytekit.configuration import internal as _internal_config from flytekit.configuration import sdk as _sdk_config from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.models import common as _common_model -from flytekit.models import execution as _admin_execution_models -from flytekit.models import task as _task_model from flytekit.models.admin import common as _admin_common +from flytekit.models.admin import execution as _admin_execution_models +from flytekit.models.admin.common import NamedEntityIdentifier as _namedEntityIdentifier +from flytekit.models.admin.task import TaskSpec as _taskSpec from flytekit.models.core import identifier as _identifier_model from flytekit.models.core import workflow as _workflow_model +from flytekit.models.core.task import TaskTemplate as _taskTemplate class SdkTask( _hash_mixin.HashOnReferenceMixin, _registerable.RegisterableEntity, _launchable_mixin.LaunchableEntity, - _task_model.TaskTemplate, + _taskTemplate, metaclass=_sdk_bases.ExtendedSdkType, ): def __init__( @@ -96,7 +99,7 @@ def entity_type_text(self): @classmethod def promote_from_model(cls, base_model): """ - :param flytekit.models.task.TaskTemplate base_model: + :param flytekit.models.core.task.TaskTemplate base_model: :rtype: SdkTask """ t = cls( @@ -170,7 +173,7 @@ def register(self, project, domain, name, version): client = _flyte_engine.get_client() try: self._id = id_to_register - client.create_task(id_to_register, _task_model.TaskSpec(self)) + client.create_task(id_to_register, _taskSpec(self)) self._id = old_id self._has_registered = True return str(id_to_register) @@ -185,7 +188,7 @@ def serialize(self): """ :rtype: flyteidl.admin.task_pb2.TaskSpec """ - return _task_model.TaskSpec(self).to_flyte_idl() + return _taskSpec(self).to_flyte_idl() @classmethod @_exception_scopes.system_entry_point @@ -216,7 +219,7 @@ def fetch_latest(cls, project, domain, name): :param Text name: :rtype: SdkTask """ - named_task = _common_model.NamedEntityIdentifier(project, domain, name) + named_task = _namedEntityIdentifier(project, domain, name) client = _flyte_engine.get_client() task_list, _ = client.list_tasks_paginated( named_task, @@ -366,9 +369,9 @@ def launch_with_literals( :param list[flytekit.common.notifications.Notification] notification_overrides: [Optional] If specified, these are the notifications that will be honored for this execution. An empty list signals to disable all notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: + :param flytekit.models.admin.common.Labels label_overrides: + :param flytekit.models.admin.common.Annotations annotation_overrides: + :param flytekit.models.admin.common.AuthRole auth_role: :rtype: flytekit.common.workflow_execution.SdkWorkflowExecution """ disable_all = notification_overrides == [] @@ -390,7 +393,7 @@ def launch_with_literals( "Please update your config to use `assumable_iam_role` instead" ) assumable_iam_role = _sdk_config.ROLE.get() - auth_role = _common_model.AuthRole( + auth_role = flytekit.models.admin.common.AuthRole( assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, ) diff --git a/flytekit/common/tasks/tensorflow_task.py b/flytekit/common/tasks/tensorflow_task.py index fdf582dbb9..48387d229b 100644 --- a/flytekit/common/tasks/tensorflow_task.py +++ b/flytekit/common/tasks/tensorflow_task.py @@ -1,7 +1,7 @@ from google.protobuf.json_format import MessageToDict as _MessageToDict from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import task as _task_models +from flytekit.models.plugins import task as _task_models class SdkRunnableTensorflowContainer(_sdk_runnable.SdkRunnableContainer): diff --git a/flytekit/common/translator.py b/flytekit/common/translator.py index 62d9e192dd..bcab26ea03 100644 --- a/flytekit/common/translator.py +++ b/flytekit/common/translator.py @@ -1,6 +1,9 @@ from collections import OrderedDict from typing import Callable, Dict, List, Optional, Tuple, Union +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan +import flytekit.models.core.task from flytekit.common import constants as _common_constants from flytekit.common.utils import _dnsify from flytekit.core.base_task import PythonTask @@ -12,12 +15,12 @@ from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate from flytekit.core.task import ReferenceTask from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase -from flytekit.models import common as _common_models -from flytekit.models import interface as interface_models -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import task as task_models +from flytekit.models.admin import common as _common +from flytekit.models.admin import launch_plan as _launch_plan_models +from flytekit.models.admin import task as task_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import interface as interface_models from flytekit.models.core import workflow as _core_wf from flytekit.models.core import workflow as workflow_model from flytekit.models.core.workflow import BranchNode as BranchNodeModel @@ -105,7 +108,7 @@ def get_serializable_task( # tasks that rely on user code defined in the container. This should be encapsulated by the auto container # parent class entity.set_command_fn(_fast_serialize_command_fn(settings, entity)) - tt = task_models.TaskTemplate( + tt = flytekit.models.core.task.TaskTemplate( id=task_id, type=entity.task_type, metadata=entity.metadata.to_taskmetadata_model(), @@ -206,10 +209,10 @@ def get_serializable_launch_plan( ), default_inputs=entity.parameters, fixed_inputs=entity.fixed_inputs, - labels=entity.labels or _common_models.Labels({}), - annotations=entity.annotations or _common_models.Annotations({}), - auth_role=entity._auth_role or _common_models.AuthRole(), - raw_output_data_config=entity.raw_output_data_config or _common_models.RawOutputDataConfig(""), + labels=entity.labels or _common.Labels({}), + annotations=entity.annotations or _common.Annotations({}), + auth_role=entity._auth_role or flytekit.models.admin.common.AuthRole(), + raw_output_data_config=entity.raw_output_data_config or _common.RawOutputDataConfig(""), max_parallelism=entity.max_parallelism, ) lp_id = _identifier_model.Identifier( @@ -403,7 +406,7 @@ def get_serializable( def gather_dependent_entities( serialized: OrderedDict, ) -> Tuple[ - Dict[_identifier_model.Identifier, task_models.TaskTemplate], + Dict[_identifier_model.Identifier, flytekit.models.core.task.TaskTemplate], Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec], Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec], ]: @@ -416,7 +419,7 @@ def gather_dependent_entities( :param serialized: This should be the filled in OrderedDict used in the get_serializable function above. :return: """ - task_templates: Dict[_identifier_model.Identifier, task_models.TaskTemplate] = {} + task_templates: Dict[_identifier_model.Identifier, flytekit.models.core.task.TaskTemplate] = {} workflow_specs: Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec] = {} launch_plan_specs: Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec] = {} diff --git a/flytekit/common/types/base_sdk_types.py b/flytekit/common/types/base_sdk_types.py index ab12969865..6b99f2be2f 100644 --- a/flytekit/common/types/base_sdk_types.py +++ b/flytekit/common/types/base_sdk_types.py @@ -5,7 +5,7 @@ from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions from flytekit.models import common as _common_models -from flytekit.models import literals as _literal_models +from flytekit.models.core import literals as _literal_models class FlyteSdkType(_sdk_bases.ExtendedSdkType, metaclass=_common_models.FlyteABCMeta): @@ -46,7 +46,7 @@ def promote_from_model(cls, literal): @_abc.abstractmethod def to_flyte_literal_type(cls): """ - :rtype: flytekit.models.types.LiteralType + :rtype: flytekit.models.core.types.LiteralType """ pass @@ -103,7 +103,7 @@ def from_python_std(cls, t_value): @classmethod def to_flyte_literal_type(cls): """ - :rtype: flytekit.models.types.LiteralType + :rtype: flytekit.models.core.types.LiteralType """ raise _user_exceptions.FlyteAssertion( "A Void type does not have a literal type and cannot be used in this " "manner." diff --git a/flytekit/common/types/blobs.py b/flytekit/common/types/blobs.py index 7870cb75bd..7a77e6b0b6 100644 --- a/flytekit/common/types/blobs.py +++ b/flytekit/common/types/blobs.py @@ -1,8 +1,8 @@ +import flytekit.models.core.types from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types from flytekit.common.types.impl import blobs as _blob_impl -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types +from flytekit.models.core import literals as _literals from flytekit.models.core import types as _core_types @@ -78,9 +78,9 @@ def from_python_std(cls, t_value): @classmethod def to_flyte_literal_type(cls): """ - :rtype: flytekit.models.types.LiteralType + :rtype: flytekit.models.core.types.LiteralType """ - return _idl_types.LiteralType( + return flytekit.models.core.types.LiteralType( blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE) ) @@ -201,7 +201,7 @@ def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType( + return flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, @@ -324,7 +324,7 @@ def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType( + return flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, @@ -435,7 +435,7 @@ def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType( + return flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, diff --git a/flytekit/common/types/containers.py b/flytekit/common/types/containers.py index 9267273b1e..07468abbfa 100644 --- a/flytekit/common/types/containers.py +++ b/flytekit/common/types/containers.py @@ -2,10 +2,10 @@ import six as _six +import flytekit.models.core.types from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types +from flytekit.models.core import literals as _literals class CollectionType(_base_sdk_types.FlyteSdkType): @@ -100,9 +100,9 @@ def from_python_std(cls, t_value): @classmethod def to_flyte_literal_type(cls): """ - :rtype: flytekit.models.types.LiteralType + :rtype: flytekit.models.core.types.LiteralType """ - return _idl_types.LiteralType(collection_type=cls.sub_type.to_flyte_literal_type()) + return flytekit.models.core.types.LiteralType(collection_type=cls.sub_type.to_flyte_literal_type()) @classmethod def promote_from_model(cls, literal_model): diff --git a/flytekit/common/types/helpers.py b/flytekit/common/types/helpers.py index 92294f38fd..75c8bad95a 100644 --- a/flytekit/common/types/helpers.py +++ b/flytekit/common/types/helpers.py @@ -5,7 +5,7 @@ from flytekit.common.exceptions import scopes as _exception_scopes from flytekit.common.exceptions import user as _user_exceptions from flytekit.configuration import sdk as _sdk_config -from flytekit.models import literals as _literal_models +from flytekit.models.core import literals as _literal_models class _TypeEngineLoader(object): @@ -59,7 +59,7 @@ def python_std_to_sdk_type(t): def get_sdk_type_from_literal_type(literal_type): """ - :param flytekit.models.types.LiteralType literal_type: + :param flytekit.models.core.types.LiteralType literal_type: :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType """ for e in _TypeEngineLoader.iterate_engines_in_order(): @@ -86,7 +86,7 @@ def infer_sdk_type_from_literal(literal): def get_sdk_value_from_literal(literal, sdk_type=None): """ :param flytekit.models.literals.Literal literal: - :param flytekit.models.types.LiteralType sdk_type: + :param flytekit.models.core.types.LiteralType sdk_type: :rtype: flytekit.common.types.base_sdk_types.FlyteSdkValue """ # The spec states everything must be nullable, so if we receive a null value, swap to the null type behavior. diff --git a/flytekit/common/types/impl/blobs.py b/flytekit/common/types/impl/blobs.py index 45210da62a..430d616fd0 100644 --- a/flytekit/common/types/impl/blobs.py +++ b/flytekit/common/types/impl/blobs.py @@ -10,7 +10,7 @@ from flytekit.common.exceptions import scopes as _exception_scopes from flytekit.common.exceptions import user as _user_exceptions from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import literals as _literal_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import types as _core_types diff --git a/flytekit/common/types/impl/schema.py b/flytekit/common/types/impl/schema.py index 04e4109d1c..3c14658e35 100644 --- a/flytekit/common/types/impl/schema.py +++ b/flytekit/common/types/impl/schema.py @@ -4,6 +4,7 @@ import six as _six +import flytekit.models.core.types from flytekit.common import sdk_bases as _sdk_bases from flytekit.common import utils as _utils from flytekit.common.exceptions import scopes as _exception_scopes @@ -14,8 +15,7 @@ from flytekit.common.types.impl import blobs as _blob_impl from flytekit.configuration import sdk as _sdk_config from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import literals as _literal_models -from flytekit.models import types as _type_models +from flytekit.models.core import literals as _literal_models from flytekit.plugins import numpy as _np from flytekit.plugins import pandas as _pd @@ -396,14 +396,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): return super(_SchemaBackingMpBlob, self).__exit__(exc_type, exc_val, exc_tb) -class SchemaType(_type_models.SchemaType, metaclass=_sdk_bases.ExtendedSdkType): +class SchemaType(flytekit.models.core.types.SchemaType, metaclass=_sdk_bases.ExtendedSdkType): _LITERAL_TYPE_TO_PROTO_ENUM = { - _primitives.Integer.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER, - _primitives.Float.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT, - _primitives.Boolean.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN, - _primitives.Datetime.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME, - _primitives.Timedelta.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION, - _primitives.String.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING, + _primitives.Integer.to_flyte_literal_type(): flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER, + _primitives.Float.to_flyte_literal_type(): flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT, + _primitives.Boolean.to_flyte_literal_type(): flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN, + _primitives.Datetime.to_flyte_literal_type(): flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME, + _primitives.Timedelta.to_flyte_literal_type(): flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION, + _primitives.String.to_flyte_literal_type(): flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING, } def __init__(self, columns=None): @@ -425,7 +425,9 @@ def columns(self): :rtype: list[flytekit.models.types.SchemaType.SchemaColumn] """ return [ - _type_models.SchemaType.SchemaColumn(n, type(self)._LITERAL_TYPE_TO_PROTO_ENUM[v.to_flyte_literal_type()]) + flytekit.models.core.types.SchemaType.SchemaColumn( + n, type(self)._LITERAL_TYPE_TO_PROTO_ENUM[v.to_flyte_literal_type()] + ) for n, v in _six.iteritems(self.sdk_columns) ] @@ -436,22 +438,22 @@ def promote_from_model(cls, model): :rtype: SchemaType """ _PROTO_ENUM_TO_SDK_TYPE = { - _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER: _helpers.get_sdk_type_from_literal_type( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER: _helpers.get_sdk_type_from_literal_type( _primitives.Integer.to_flyte_literal_type() ), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT: _helpers.get_sdk_type_from_literal_type( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT: _helpers.get_sdk_type_from_literal_type( _primitives.Float.to_flyte_literal_type() ), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN: _helpers.get_sdk_type_from_literal_type( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN: _helpers.get_sdk_type_from_literal_type( _primitives.Boolean.to_flyte_literal_type() ), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME: _helpers.get_sdk_type_from_literal_type( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME: _helpers.get_sdk_type_from_literal_type( _primitives.Datetime.to_flyte_literal_type() ), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION: _helpers.get_sdk_type_from_literal_type( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION: _helpers.get_sdk_type_from_literal_type( _primitives.Timedelta.to_flyte_literal_type() ), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING: _helpers.get_sdk_type_from_literal_type( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING: _helpers.get_sdk_type_from_literal_type( _primitives.String.to_flyte_literal_type() ), } diff --git a/flytekit/common/types/primitives.py b/flytekit/common/types/primitives.py index 446140dc2c..9e6b1411db 100644 --- a/flytekit/common/types/primitives.py +++ b/flytekit/common/types/primitives.py @@ -8,10 +8,10 @@ from google.protobuf import struct_pb2 as _struct from pytimeparse import parse as _parse_duration_string +import flytekit.models.core.types from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types +from flytekit.models.core import literals as _literals class Integer(_base_sdk_types.FlyteSdkValue): @@ -54,9 +54,9 @@ def from_python_std(cls, t_value): @classmethod def to_flyte_literal_type(cls): """ - :rtype: flytekit.models.types.LiteralType + :rtype: flytekit.models.core.types.LiteralType """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.INTEGER) + return flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.INTEGER) @classmethod def promote_from_model(cls, literal_model): @@ -135,7 +135,7 @@ def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.FLOAT) + return flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.FLOAT) @classmethod def promote_from_model(cls, literal_model): @@ -215,7 +215,7 @@ def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.BOOLEAN) + return flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN) @classmethod def promote_from_model(cls, literal_model): @@ -294,7 +294,7 @@ def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.STRING) + return flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.STRING) @classmethod def promote_from_model(cls, literal_model): @@ -390,7 +390,7 @@ def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.DATETIME) + return flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.DATETIME) @classmethod def promote_from_model(cls, literal_model): @@ -470,7 +470,7 @@ def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.DURATION) + return flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.DURATION) @classmethod def promote_from_model(cls, literal_model): @@ -552,7 +552,9 @@ def to_flyte_literal_type(cls, metadata: typing.Dict = None): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT, metadata=metadata) + return flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.STRUCT, metadata=metadata + ) @classmethod def promote_from_model(cls, literal_model): diff --git a/flytekit/common/types/proto.py b/flytekit/common/types/proto.py index 2c3423ccfb..d4db6ac5e6 100644 --- a/flytekit/common/types/proto.py +++ b/flytekit/common/types/proto.py @@ -9,12 +9,12 @@ from google.protobuf.reflection import GeneratedProtocolMessageType from google.protobuf.struct_pb2 import Struct +import flytekit.models.core.types from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types from flytekit.models.common import FlyteIdlEntity, FlyteType -from flytekit.models.types import LiteralType +from flytekit.models.core import literals as _literals +from flytekit.models.core.types import LiteralType ProtobufT = Type[_proto_reflection.GeneratedProtocolMessageType] @@ -118,9 +118,11 @@ def from_python_std(cls, t_value): @classmethod def to_flyte_literal_type(cls): """ - :rtype: flytekit.models.types.LiteralType + :rtype: flytekit.models.core.types.LiteralType """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.BINARY, metadata={cls.PB_FIELD_KEY: cls.descriptor}) + return flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.BINARY, metadata={cls.PB_FIELD_KEY: cls.descriptor} + ) @classmethod def promote_from_model(cls, literal_model): @@ -242,7 +244,9 @@ def to_flyte_literal_type(cls) -> LiteralType: """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT, metadata={cls.PB_FIELD_KEY: cls.descriptor}) + return flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.STRUCT, metadata={cls.PB_FIELD_KEY: cls.descriptor} + ) @classmethod def promote_from_model(cls, literal_model): diff --git a/flytekit/common/types/schema.py b/flytekit/common/types/schema.py index eaf38d1c88..535bfe781e 100644 --- a/flytekit/common/types/schema.py +++ b/flytekit/common/types/schema.py @@ -1,8 +1,8 @@ +import flytekit.models.core.types from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types from flytekit.common.types.impl import schema as _schema_impl -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types +from flytekit.models.core import literals as _literals class SchemaInstantiator(_base_sdk_types.InstantiableType): @@ -118,9 +118,9 @@ def from_python_std(cls, t_value): @classmethod def to_flyte_literal_type(cls): """ - :rtype: flytekit.models.types.LiteralType + :rtype: flytekit.models.core.types.LiteralType """ - return _idl_types.LiteralType(schema=cls.schema_type) + return flytekit.models.core.types.LiteralType(schema=cls.schema_type) @classmethod def promote_from_model(cls, literal_model): @@ -179,7 +179,7 @@ class _Schema(Schema, metaclass=SchemaInstantiator): def schema_instantiator_from_proto(schema_type): """ - :param flytekit.models.types.SchemaType schema_type: + :param flytekit.models.core.types.SchemaType schema_type: :rtype: SchemaInstantiator """ diff --git a/flytekit/common/workflow.py b/flytekit/common/workflow.py index 3fc4f498fd..13ad93520e 100644 --- a/flytekit/common/workflow.py +++ b/flytekit/common/workflow.py @@ -1,6 +1,8 @@ import datetime as _datetime from typing import List +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan from flytekit.common import constants as _constants from flytekit.common import interface as _interface from flytekit.common import nodes as _nodes @@ -15,13 +17,13 @@ from flytekit.configuration import auth as _auth_config from flytekit.configuration import internal as _internal_config from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.models import common as _common_models -from flytekit.models import interface as _interface_models -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import literals as _literal_models -from flytekit.models import schedule as _schedule_models +from flytekit.models.admin import common as _common +from flytekit.models.admin import launch_plan as _launch_plan_models +from flytekit.models.admin import schedule as _schedule_models from flytekit.models.admin import workflow as _admin_workflow_model from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_models @@ -184,7 +186,7 @@ def promote_from_model(cls, base_model, sub_workflows=None, tasks=None): sub_workflows: Provide a list of WorkflowTemplate models (should be returned from Admin as part of the admin CompiledWorkflowClosure. Relevant sub-workflows should always be provided. - :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks: Same as above + :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.admin.task.TaskTemplate] tasks: Same as above but for tasks. If tasks are not provided relevant TaskTemplates will be fetched from Admin :rtype: SdkWorkflow """ @@ -269,7 +271,7 @@ def create_launch_plan(self, *args, **kwargs): if not (assumable_iam_role or kubernetes_service_account): raise _user_exceptions.FlyteValidationException("No assumable role or service account found") - auth_role = _common_models.AuthRole( + auth_role = flytekit.models.admin.common.AuthRole( assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, ) @@ -282,10 +284,10 @@ def create_launch_plan(self, *args, **kwargs): ), default_inputs=_interface_models.ParameterMap({}), fixed_inputs=_literal_models.LiteralMap(literals={}), - labels=_common_models.Labels({}), - annotations=_common_models.Annotations({}), + labels=_common.Labels({}), + annotations=_common.Annotations({}), auth_role=auth_role, - raw_output_data_config=_common_models.RawOutputDataConfig(""), + raw_output_data_config=_common.RawOutputDataConfig(""), ) @_exception_scopes.system_entry_point diff --git a/flytekit/common/workflow_execution.py b/flytekit/common/workflow_execution.py index 14695d0e68..c4d41580ed 100644 --- a/flytekit/common/workflow_execution.py +++ b/flytekit/common/workflow_execution.py @@ -13,9 +13,9 @@ from flytekit.common.types import helpers as _type_helpers from flytekit.engines.flyte import engine as _flyte_engine from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import execution as _execution_models -from flytekit.models import literals as _literal_models +from flytekit.models.admin import execution as _execution_models from flytekit.models.core import execution as _core_execution_models +from flytekit.models.core import literals as _literal_models class SdkWorkflowExecution( diff --git a/flytekit/contrib/notebook/tasks.py b/flytekit/contrib/notebook/tasks.py index 5836717574..2bf2285a92 100644 --- a/flytekit/contrib/notebook/tasks.py +++ b/flytekit/contrib/notebook/tasks.py @@ -9,6 +9,7 @@ from google.protobuf import json_format as _json_format from google.protobuf import text_format as _text_format +import flytekit.models.core.task from flytekit import __version__ from flytekit.bin import entrypoint as _entrypoint from flytekit.common import constants as _constants @@ -22,9 +23,11 @@ from flytekit.common.types import helpers as _type_helpers from flytekit.contrib.notebook.supported_types import notebook_types_map as _notebook_types_map from flytekit.engines import loader as _engine_loader -from flytekit.models import interface as _interface -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models +from flytekit.models.core import interface as _interface +from flytekit.models.core import literals as _literal_models +from flytekit.models.core.task import RuntimeMetadata as _runtimeMetadata +from flytekit.models.core.task import TaskMetadata as _task_matadata +from flytekit.models.plugins import task as _task_models from flytekit.plugins import papermill as _pm from flytekit.sdk.spark_types import SparkType as _spark_type from flytekit.sdk.types import Types as _Types @@ -120,10 +123,10 @@ def __init__( super(SdkNotebookTask, self).__init__( task_type, - _task_models.TaskMetadata( + _task_matadata( discoverable, - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + _runtimeMetadata( + _runtimeMetadata.RuntimeType.FLYTE_SDK, __version__, "notebook", ), @@ -496,7 +499,7 @@ def _get_container_definition(self, environment=None, **kwargs): return _spark_task.SdkRunnableSparkContainer( command=[], args=[], - resources=_task_models.Resources(limits=[], requests=[]), + resources=flytekit.models.core.task.Resources(limits=[], requests=[]), env=environment or {}, config={}, ) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 570e1c4d0c..3d34725827 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -23,6 +23,7 @@ from dataclasses import dataclass from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union +import flytekit.models.core.task from flytekit.common.tasks.sdk_runnable import ExecutionParameters from flytekit.core.context_manager import FlyteContext, FlyteContextManager, FlyteEntities, SerializationSettings from flytekit.core.interface import Interface, transform_interface_to_typed_interface @@ -39,13 +40,14 @@ from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger -from flytekit.models import dynamic_job as _dynamic_job -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_model +from flytekit.models.core import dynamic_job as _dynamic_job +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_model -from flytekit.models.interface import Variable -from flytekit.models.security import SecurityContext +from flytekit.models.core.interface import Variable +from flytekit.models.core.security import SecurityContext +from flytekit.models.core.task import RuntimeMetadata as _runtimeMetadata +from flytekit.models.core.task import TaskMetadata as _task_matadata def kwtypes(**kwargs) -> Dict[str, Type]: @@ -104,17 +106,15 @@ def __post_init__(self): def retry_strategy(self) -> _literal_models.RetryStrategy: return _literal_models.RetryStrategy(self.retries) - def to_taskmetadata_model(self) -> _task_model.TaskMetadata: + def to_taskmetadata_model(self) -> _task_matadata: """ Converts to _task_model.TaskMetadata """ from flytekit import __version__ - return _task_model.TaskMetadata( + return _task_matadata( discoverable=self.cache, - runtime=_task_model.RuntimeMetadata( - _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python" - ), + runtime=_runtimeMetadata(_runtimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"), timeout=self.timeout, retries=self.retry_strategy, interruptible=self.interruptible, @@ -277,19 +277,19 @@ def __call__(self, *args, **kwargs): def compile(self, ctx: FlyteContext, *args, **kwargs): raise Exception("not implemented") - def get_container(self, settings: SerializationSettings) -> _task_model.Container: + def get_container(self, settings: SerializationSettings) -> flytekit.models.core.task.Container: """ Returns the container definition (if any) that is used to run the task on hosted Flyte. """ return None - def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + def get_k8s_pod(self, settings: SerializationSettings) -> flytekit.models.core.task.K8sPod: """ Returns the kubernetes pod definition (if any) that is used to run the task on hosted Flyte. """ return None - def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]: + def get_sql(self, settings: SerializationSettings) -> Optional[flytekit.models.core.task.Sql]: """ Returns the Sql definition (if any) that is used to run the task on hosted Flyte. """ diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index a38d0e2ab1..c1980129d2 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -18,8 +18,8 @@ ) from flytekit.models.core import condition as _core_cond from flytekit.models.core import workflow as _core_wf -from flytekit.models.literals import Binding, BindingData, Literal, RetryStrategy -from flytekit.models.types import Error +from flytekit.models.core.literals import Binding, BindingData, Literal, RetryStrategy +from flytekit.models.core.types import Error class BranchNode(object): diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 75c6449c07..f9a59c7dea 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,12 +1,12 @@ from enum import Enum from typing import Any, Dict, List, Optional, Type +import flytekit.models.core.task from flytekit.common.tasks.raw_container import _get_container_definition from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.context_manager import SerializationSettings from flytekit.core.interface import Interface from flytekit.core.resources import Resources, ResourceSpec -from flytekit.models import task as _task_model class ContainerTask(PythonTask): @@ -17,17 +17,17 @@ class ContainerTask(PythonTask): """ class MetadataFormat(Enum): - JSON = _task_model.DataLoadingConfig.LITERALMAP_FORMAT_JSON - YAML = _task_model.DataLoadingConfig.LITERALMAP_FORMAT_YAML - PROTO = _task_model.DataLoadingConfig.LITERALMAP_FORMAT_PROTO + JSON = flytekit.models.core.task.DataLoadingConfig.LITERALMAP_FORMAT_JSON + YAML = flytekit.models.core.task.DataLoadingConfig.LITERALMAP_FORMAT_YAML + PROTO = flytekit.models.core.task.DataLoadingConfig.LITERALMAP_FORMAT_PROTO class IOStrategy(Enum): - DOWNLOAD_EAGER = _task_model.IOStrategy.DOWNLOAD_MODE_EAGER - DOWNLOAD_STREAM = _task_model.IOStrategy.DOWNLOAD_MODE_STREAM - DO_NOT_DOWNLOAD = _task_model.IOStrategy.DOWNLOAD_MODE_NO_DOWNLOAD - UPLOAD_EAGER = _task_model.IOStrategy.UPLOAD_MODE_EAGER - UPLOAD_ON_EXIT = _task_model.IOStrategy.UPLOAD_MODE_ON_EXIT - DO_NOT_UPLOAD = _task_model.IOStrategy.UPLOAD_MODE_NO_UPLOAD + DOWNLOAD_EAGER = flytekit.models.core.task.IOStrategy.DOWNLOAD_MODE_EAGER + DOWNLOAD_STREAM = flytekit.models.core.task.IOStrategy.DOWNLOAD_MODE_STREAM + DO_NOT_DOWNLOAD = flytekit.models.core.task.IOStrategy.DOWNLOAD_MODE_NO_DOWNLOAD + UPLOAD_EAGER = flytekit.models.core.task.IOStrategy.UPLOAD_MODE_EAGER + UPLOAD_ON_EXIT = flytekit.models.core.task.IOStrategy.UPLOAD_MODE_ON_EXIT + DO_NOT_UPLOAD = flytekit.models.core.task.IOStrategy.UPLOAD_MODE_NO_UPLOAD def __init__( self, @@ -80,13 +80,13 @@ def execute(self, **kwargs) -> Any: ) return None - def get_container(self, settings: SerializationSettings) -> _task_model.Container: + def get_container(self, settings: SerializationSettings) -> flytekit.models.core.task.Container: env = {**settings.env, **self.environment} if self.environment else settings.env return _get_container_definition( image=self._image, command=self._cmd, args=self._args, - data_loading_config=_task_model.DataLoadingConfig( + data_loading_config=flytekit.models.core.task.DataLoadingConfig( input_path=self._input_data_dir, output_path=self._output_data_dir, format=self._md_format.value, diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 842a4198b5..e56f7f091f 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -12,7 +12,7 @@ from flytekit.core.docstring import Docstring from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger -from flytekit.models import interface as _interface_models +from flytekit.models.core import interface as _interface_models class Interface(object): diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 481871977e..d018987792 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -3,15 +3,17 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Type +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan from flytekit.core import workflow as _annotated_workflow from flytekit.core.context_manager import FlyteContext, FlyteContextManager, FlyteEntities from flytekit.core.interface import Interface, transform_inputs_to_parameters, transform_signature_to_interface from flytekit.core.promise import create_and_link_node, translate_inputs_to_literals from flytekit.core.reference_entity import LaunchPlanReference, ReferenceEntity -from flytekit.models import common as _common_models -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models -from flytekit.models import schedule as _schedule_model +from flytekit.models.admin import common as _admin_common +from flytekit.models.admin import schedule as _schedule_model +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_model @@ -109,11 +111,11 @@ def create( default_inputs: Dict[str, Any] = None, fixed_inputs: Dict[str, Any] = None, schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - auth_role: _common_models.AuthRole = None, + notifications: List[_admin_common.Notification] = None, + labels: _admin_common.Labels = None, + annotations: _admin_common.Annotations = None, + raw_output_data_config: _admin_common.RawOutputDataConfig = None, + auth_role: flytekit.models.admin.common.AuthRole = None, max_parallelism: int = None, ) -> LaunchPlan: ctx = FlyteContextManager.current_context() @@ -175,11 +177,11 @@ def get_or_create( default_inputs: Dict[str, Any] = None, fixed_inputs: Dict[str, Any] = None, schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - auth_role: _common_models.AuthRole = None, + notifications: List[_admin_common.Notification] = None, + labels: _admin_common.Labels = None, + annotations: _admin_common.Annotations = None, + raw_output_data_config: _admin_common.RawOutputDataConfig = None, + auth_role: flytekit.models.admin.common.AuthRole = None, max_parallelism: int = None, ) -> LaunchPlan: """ @@ -275,11 +277,11 @@ def __init__( parameters: _interface_models.ParameterMap, fixed_inputs: _literal_models.LiteralMap, schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - auth_role: _common_models.AuthRole = None, + notifications: List[_admin_common.Notification] = None, + labels: _admin_common.Labels = None, + annotations: _admin_common.Annotations = None, + raw_output_data_config: _admin_common.RawOutputDataConfig = None, + auth_role: flytekit.models.admin.common.AuthRole = None, max_parallelism: int = None, ): self._name = name @@ -334,19 +336,19 @@ def schedule(self) -> Optional[_schedule_model.Schedule]: return self._schedule @property - def notifications(self) -> List[_common_models.Notification]: + def notifications(self) -> List[_admin_common.Notification]: return self._notifications @property - def labels(self) -> Optional[_common_models.Labels]: + def labels(self) -> Optional[_admin_common.Labels]: return self._labels @property - def annotations(self) -> Optional[_common_models.Annotations]: + def annotations(self) -> Optional[_admin_common.Annotations]: return self._annotations @property - def raw_output_data_config(self) -> Optional[_common_models.RawOutputDataConfig]: + def raw_output_data_config(self) -> Optional[_admin_common.RawOutputDataConfig]: return self._raw_output_data_config @property diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 9180224207..5e28c3305f 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -2,7 +2,7 @@ from diskcache import Cache -from flytekit.models.literals import LiteralMap +from flytekit.models.core.literals import LiteralMap # Location on the filesystem where serialized objects will be stored # TODO: read from config diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index fa48b474ca..b1ffe0fe8a 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -14,9 +14,9 @@ from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, SerializationSettings from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.python_function_task import PythonFunctionTask -from flytekit.models.array_job import ArrayJob -from flytekit.models.interface import Variable -from flytekit.models.task import Container, K8sPod, Sql +from flytekit.models.core.interface import Variable +from flytekit.models.core.task import Container, K8sPod, Sql +from flytekit.models.plugins.array_job import ArrayJob class MapPythonTask(PythonTask): diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 0567e7b64a..ffacefcdec 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -6,9 +6,9 @@ from flytekit.common.utils import _dnsify from flytekit.core.resources import Resources -from flytekit.models import literals as _literal_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_model -from flytekit.models.task import Resources as _resources_model +from flytekit.models.core.task import Resources as _resources_model class Node(object): diff --git a/flytekit/core/notification.py b/flytekit/core/notification.py index 763b7f417f..deba8f911f 100644 --- a/flytekit/core/notification.py +++ b/flytekit/core/notification.py @@ -17,7 +17,7 @@ """ from typing import List -from flytekit.models import common as _common_model +from flytekit.models.admin import common as _common_model from flytekit.models.core import execution as _execution_model diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 81163ea96f..475576afe4 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -7,6 +7,7 @@ from typing_extensions import Protocol +import flytekit.models.core.types from flytekit.common import constants as _common_constants from flytekit.common.exceptions import user as _user_exceptions from flytekit.core import context_manager as _flyte_context @@ -16,13 +17,11 @@ from flytekit.core.interface import Interface from flytekit.core.node import Node from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models -from flytekit.models import literals as _literals_models -from flytekit.models import types as _type_models -from flytekit.models import types as type_models +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models +from flytekit.models.core import literals as _literals_models from flytekit.models.core import workflow as _workflow_model -from flytekit.models.literals import Primitive +from flytekit.models.core.literals import Primitive def translate_inputs_to_literals( @@ -64,7 +63,7 @@ def my_wf(in1: int, in2: int) -> int: """ def extract_value( - ctx: FlyteContext, input_val: Any, val_type: type, flyte_literal_type: _type_models.LiteralType + ctx: FlyteContext, input_val: Any, val_type: type, flyte_literal_type: flytekit.models.core.types.LiteralType ) -> _literal_models.Literal: if isinstance(input_val, list): if flyte_literal_type.collection_type is None: @@ -80,11 +79,11 @@ def extract_value( elif isinstance(input_val, dict): if ( flyte_literal_type.map_value_type is None - and flyte_literal_type.simple != _type_models.SimpleType.STRUCT + and flyte_literal_type.simple != flytekit.models.core.types.SimpleType.STRUCT ): raise TypeError(f"Not a map type {flyte_literal_type} but got a map {input_val}") k_type, sub_type = DictTransformer.get_dict_types(val_type) # type: ignore - if flyte_literal_type.simple == _type_models.SimpleType.STRUCT: + if flyte_literal_type.simple == flytekit.models.core.types.SimpleType.STRUCT: return TypeEngine.to_literal(ctx, input_val, type(input_val), flyte_literal_type) else: literal_map = { @@ -549,7 +548,7 @@ def __rshift__(self, other: Any): def binding_data_from_python_std( ctx: _flyte_context.FlyteContext, - expected_literal_type: _type_models.LiteralType, + expected_literal_type: flytekit.models.core.types.LiteralType, t_value: typing.Any, t_value_type: type, ) -> _literals_models.BindingData: @@ -579,13 +578,13 @@ def binding_data_from_python_std( elif isinstance(t_value, dict): if ( expected_literal_type.map_value_type is None - and expected_literal_type.simple != _type_models.SimpleType.STRUCT + and expected_literal_type.simple != flytekit.models.core.types.SimpleType.STRUCT ): raise AssertionError( f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}" ) k_type, v_type = DictTransformer.get_dict_types(t_value_type) - if expected_literal_type.simple == _type_models.SimpleType.STRUCT: + if expected_literal_type.simple == flytekit.models.core.types.SimpleType.STRUCT: lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type) return _literals_models.BindingData(scalar=lit.scalar) else: @@ -613,7 +612,7 @@ def binding_data_from_python_std( def binding_from_python_std( ctx: _flyte_context.FlyteContext, var_name: str, - expected_literal_type: _type_models.LiteralType, + expected_literal_type: flytekit.models.core.types.LiteralType, t_value: typing.Any, t_value_type: type, ) -> _literals_models.Binding: @@ -690,7 +689,7 @@ def __repr__(self): raise AssertionError(f"Task {self._task_name} returns nothing, NoneType return cannot be used") -class NodeOutput(type_models.OutputReference): +class NodeOutput(flytekit.models.core.types.OutputReference): def __init__(self, node: Node, var: str): """ :param node: diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 689f5781dc..4ce4d3c28b 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -4,6 +4,7 @@ import re from typing import Callable, Dict, List, Optional, TypeVar +import flytekit.models.core.task from flytekit.common.tasks.raw_container import _get_container_definition from flytekit.core.base_task import PythonTask, TaskResolverMixin from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings @@ -11,8 +12,7 @@ from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance from flytekit.loggers import logger -from flytekit.models import task as _task_model -from flytekit.models.security import Secret, SecurityContext +from flytekit.models.core.security import Secret, SecurityContext T = TypeVar("T") @@ -149,7 +149,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]: """ return self._get_command_fn(settings) - def get_container(self, settings: SerializationSettings) -> _task_model.Container: + def get_container(self, settings: SerializationSettings) -> flytekit.models.core.task.Container: env = {**settings.env, **self.environment} if self.environment else settings.env return _get_container_definition( image=get_registerable_container_image(self.container_image, settings.image_config), diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index eaeb509d2e..b0237de9cc 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -5,6 +5,7 @@ from flyteidl.core import tasks_pb2 as _tasks_pb2 +import flytekit.models.core.task from flytekit.common import utils as common_utils from flytekit.common.tasks.raw_container import _get_container_definition from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin @@ -13,9 +14,9 @@ from flytekit.core.shim_task import ExecutableTemplateShimTask, ShimTaskExecutor from flytekit.core.tracker import TrackedInstance from flytekit.loggers import logger -from flytekit.models import task as _task_model from flytekit.models.core import identifier as identifier_models -from flytekit.models.security import Secret, SecurityContext +from flytekit.models.core.security import Secret, SecurityContext +from flytekit.models.core.task import TaskTemplate as _taskTemplate from flytekit.tools.module_loader import load_object_from_module TC = TypeVar("TC") @@ -130,7 +131,7 @@ def task_resolver(self) -> TaskTemplateResolver: return self._task_resolver @property - def task_template(self) -> Optional[_task_model.TaskTemplate]: + def task_template(self) -> Optional[_taskTemplate]: """ Override the base class implementation to serialize on first call. """ @@ -157,7 +158,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return container_args - def get_container(self, settings: SerializationSettings) -> _task_model.Container: + def get_container(self, settings: SerializationSettings) -> flytekit.models.core.task.Container: env = {**settings.env, **self.environment} if self.environment else settings.env return _get_container_definition( image=self.container_image, @@ -175,13 +176,13 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe memory_limit=self.resources.limits.mem, ) - def serialize_to_model(self, settings: SerializationSettings) -> _task_model.TaskTemplate: + def serialize_to_model(self, settings: SerializationSettings) -> _taskTemplate: # This doesn't get called from translator unfortunately. Will need to move the translator to use the model # objects directly first. # Note: This doesn't settle the issue of duplicate registrations. We'll need to figure that out somehow. # TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of # customized-container tasks have a familiar thing to work with. - obj = _task_model.TaskTemplate( + obj = _taskTemplate( identifier_models.Identifier( identifier_models.ResourceType.TASK, settings.project, settings.domain, self.name, settings.version ), @@ -233,7 +234,7 @@ def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: task_template_local_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb") # type: ignore ctx.file_access.get_data(loader_args[0], task_template_local_path) task_template_proto = common_utils.load_proto_from_file(_tasks_pb2.TaskTemplate, task_template_local_path) - task_template_model = _task_model.TaskTemplate.from_flyte_idl(task_template_proto) + task_template_model = _taskTemplate.from_flyte_idl(task_template_proto) executor_class = load_object_from_module(loader_args[1]) return ExecutableTemplateShimTask(task_template_model, executor_class) diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index addea248f4..bb7e1d1c77 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -34,10 +34,10 @@ WorkflowMetadataDefaults, ) from flytekit.loggers import logger -from flytekit.models import dynamic_job as _dynamic_job -from flytekit.models import literals as _literal_models -from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.admin.task import TaskSpec as _taskSpec +from flytekit.models.core import dynamic_job as _dynamic_job +from flytekit.models.core import literals as _literal_models T = TypeVar("T") @@ -222,9 +222,9 @@ def compile_into_workflow( if isinstance(entity, ReferenceTask): raise Exception("Reference tasks are currently unsupported within dynamic tasks") - if not isinstance(model, task_models.TaskSpec): + if not isinstance(model, _taskSpec): raise TypeError( - f"Unexpected type for serialized form of task. Expected {task_models.TaskSpec}, but got {type(model)}" + f"Unexpected type for serialized form of task. Expected {_taskSpec}, but got {type(model)}" ) # Store the valid task template so that we can pass it to the diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 22090838ff..c9e0d1e440 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -15,9 +15,9 @@ ) from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_model diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 0c5fe786ae..8c1ef1520b 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -9,7 +9,7 @@ import croniter as _croniter -from flytekit.models import schedule as _schedule_models +from flytekit.models.admin import schedule as _schedule_models # Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass. diff --git a/flytekit/core/shim_task.py b/flytekit/core/shim_task.py index 637cfe4b47..4645387742 100644 --- a/flytekit/core/shim_task.py +++ b/flytekit/core/shim_task.py @@ -2,12 +2,12 @@ from typing import Any, Generic, Type, TypeVar, Union +import flytekit.models.core.task from flytekit import ExecutionParameters, FlyteContext, FlyteContextManager, logger from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine -from flytekit.models import dynamic_job as _dynamic_job -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_model +from flytekit.models.core import dynamic_job as _dynamic_job +from flytekit.models.core import literals as _literal_models class ExecutableTemplateShimTask(object): @@ -34,14 +34,16 @@ class ExecutableTemplateShimTask(object): that the ``entrypoint.py`` can execute, even though this class doesn't inherit from ``PythonTask``. """ - def __init__(self, tt: _task_model.TaskTemplate, executor_type: Type[ShimTaskExecutor], *args, **kwargs): + def __init__( + self, tt: flytekit.models.core.task.TaskTemplate, executor_type: Type[ShimTaskExecutor], *args, **kwargs + ): self._executor_type = executor_type self._executor = executor_type() self._task_template = tt super().__init__(*args, **kwargs) @property - def task_template(self) -> _task_model.TaskTemplate: + def task_template(self) -> flytekit.models.core.task.TaskTemplate: return self._task_template @property @@ -149,7 +151,7 @@ def dispatch_execute( class ShimTaskExecutor(TrackedInstance, Generic[T]): - def execute_from_model(self, tt: _task_model.TaskTemplate, **kwargs) -> Any: + def execute_from_model(self, tt: flytekit.models.core.task.TaskTemplate, **kwargs) -> Any: """ This function must be overridden and is where all the business logic for running a task should live. Keep in mind that you're only working with the ``TaskTemplate``. You won't have access to any information in the task diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 24fe283399..5b3eff4042 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -7,7 +7,7 @@ from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources -from flytekit.models.security import Secret +from flytekit.models.core.security import Secret class TaskPlugins(object): diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 39a7be762a..38b7668975 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -19,16 +19,16 @@ from google.protobuf.struct_pb2 import Struct from marshmallow_jsonschema import JSONSchema +import flytekit.models.core.types from flytekit.common.exceptions import user as user_exceptions from flytekit.common.types import primitives as _primitives from flytekit.core.context_manager import FlyteContext from flytekit.core.type_helpers import load_type_from_tag from flytekit.loggers import logger -from flytekit.models import interface as _interface_models -from flytekit.models import types as _type_models +from flytekit.models.core import interface as _interface_models from flytekit.models.core import types as _core_types -from flytekit.models.literals import Literal, LiteralCollection, LiteralMap, Primitive, Scalar -from flytekit.models.types import LiteralType, SimpleType +from flytekit.models.core.literals import Literal, LiteralCollection, LiteralMap, Primitive, Scalar +from flytekit.models.core.types import LiteralType, SimpleType T = typing.TypeVar("T") DEFINITIONS = "definitions" @@ -576,7 +576,7 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: """ try: sub_type = TypeEngine.to_literal_type(self.get_sub_type(t)) - return _type_models.LiteralType(collection_type=sub_type) + return flytekit.models.core.types.LiteralType(collection_type=sub_type) except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") @@ -631,7 +631,7 @@ def get_literal_type(self, t: Type[dict]) -> LiteralType: if tp[0] == str: try: sub_type = TypeEngine.to_literal_type(tp[1]) - return _type_models.LiteralType(map_value_type=sub_type) + return flytekit.models.core.types.LiteralType(map_value_type=sub_type) except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") return _primitives.Generic.to_flyte_literal_type() @@ -699,7 +699,7 @@ def _blob_type(self) -> _core_types.BlobType: ) def get_literal_type(self, t: typing.TextIO) -> LiteralType: - return _type_models.LiteralType( + return flytekit.models.core.types.LiteralType( blob=self._blob_type(), ) @@ -733,7 +733,7 @@ def _blob_type(self) -> _core_types.BlobType: ) def get_literal_type(self, t: Type[typing.BinaryIO]) -> LiteralType: - return _type_models.LiteralType( + return flytekit.models.core.types.LiteralType( blob=self._blob_type(), ) @@ -904,7 +904,7 @@ def _register_default_type_transformers(): SimpleTransformer( "none", None, - _type_models.LiteralType(simple=_type_models.SimpleType.NONE), + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.NONE), lambda x: None, lambda x: None, ) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index a14be3d0a5..c6ea527da7 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -35,8 +35,8 @@ from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import workflow as _workflow_model GLOBAL_START_NODE = Node( diff --git a/flytekit/engines/common.py b/flytekit/engines/common.py index ed1dfd9d37..024cbb2608 100644 --- a/flytekit/engines/common.py +++ b/flytekit/engines/common.py @@ -219,10 +219,10 @@ def launch( :param Text domain: :param Text name: :param flytekit.models.literals.LiteralMap inputs: The inputs to pass - :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the + :param list[flytekit.models.admin.common.Notification] notification_overrides: If specified, override the notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: + :param flytekit.models.admin.common.Labels label_overrides: + :param flytekit.models.admin.common.Annotations annotation_overrides: :rtype: flytekit.models.execution.Execution """ pass @@ -283,11 +283,11 @@ def launch( :param Text domain: :param Text name: :param flytekit.models.literals.LiteralMap inputs: The inputs to pass - :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the + :param list[flytekit.models.admin.common.Notification] notification_overrides: If specified, override the notifications. - :param flytekit.models.common.Labels label_overrides: - :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: + :param flytekit.models.admin.common.Labels label_overrides: + :param flytekit.models.admin.common.Annotations annotation_overrides: + :param flytekit.models.admin.common.AuthRole auth_role: :rtype: flytekit.models.execution.Execution """ pass @@ -351,7 +351,7 @@ def fetch_task(self, task_id): """ :param flytekit.models.core.identifier.Identifier task_id: This identifier should have a resource type of kind Task. - :rtype: flytekit.models.task.Task + :rtype: flytekit.models.admin.task.Task """ pass @@ -359,8 +359,8 @@ def fetch_task(self, task_id): def fetch_latest_task(self, named_task): """ Fetches the latest task - :param flytekit.models.common.NamedEntityIdentifier named_task: NamedEntityIdentifier to fetch - :rtype: flytekit.models.task.Task + :param flytekit.models.admin.common.NamedEntityIdentifier named_task: NamedEntityIdentifier to fetch + :rtype: flytekit.models.admin.task.Task """ pass diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index 7af35e20b9..6e3ad7f9bd 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -8,6 +8,8 @@ from flyteidl.core import literals_pb2 as _literals_pb2 import flytekit +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions @@ -22,14 +24,14 @@ from flytekit.engines import common as _common_engine from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.interfaces.stats.taggable import get_stats as _get_stats -from flytekit.models import common as _common_models -from flytekit.models import execution as _execution_models -from flytekit.models import literals as _literals -from flytekit.models import task as _task_models from flytekit.models.admin import common as _common +from flytekit.models.admin import execution as _execution_models from flytekit.models.admin import workflow as _workflow_model +from flytekit.models.admin.common import NamedEntityIdentifier as _namedEntityIdentifier +from flytekit.models.admin.task import TaskSpec as _taskSpec from flytekit.models.core import errors as _error_models from flytekit.models.core import identifier as _identifier +from flytekit.models.core import literals as _literals class _FlyteClientManager(object): @@ -160,9 +162,7 @@ def fetch_launch_plan(self, launch_plan_id): _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.get_launch_plan(launch_plan_id) else: - named_entity_id = _common_models.NamedEntityIdentifier( - launch_plan_id.project, launch_plan_id.domain, launch_plan_id.name - ) + named_entity_id = _namedEntityIdentifier(launch_plan_id.project, launch_plan_id.domain, launch_plan_id.name) return _FlyteClientManager( _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.get_active_launch_plan(named_entity_id) @@ -317,7 +317,7 @@ class FlyteTask(_common_engine.BaseTaskExecutor): def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client try: - client.create_task(identifier, _task_models.TaskSpec(self.sdk_task)) + client.create_task(identifier, _taskSpec(self.sdk_task)) except _user_exceptions.FlyteEntityAlreadyExistsException: pass @@ -426,7 +426,7 @@ def launch( notifications. :param flytekit.models.common.Labels label_overrides: :param flytekit.models.common.Annotations annotation_overrides: - :param flytekit.models.common.AuthRole auth_role: + :param flytekit.models.admin.common.AuthRole auth_role: :rtype: flytekit.models.execution.Execution """ disable_all = notification_overrides == [] @@ -446,7 +446,7 @@ def launch( "Please update your config to use `assumable_iam_role` instead" ) assumable_iam_role = _sdk_config.ROLE.get() - auth_role = _common_models.AuthRole( + auth_role = flytekit.models.admin.common.AuthRole( assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, ) diff --git a/flytekit/engines/unit/engine.py b/flytekit/engines/unit/engine.py index 2e3c4e187b..c2366c9fbf 100644 --- a/flytekit/engines/unit/engine.py +++ b/flytekit/engines/unit/engine.py @@ -16,10 +16,10 @@ from flytekit.engines import common as _common_engine from flytekit.engines.unit.mock_stats import MockStats from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import array_job as _array_job -from flytekit.models import literals as _literals -from flytekit.models import qubole as _qubole_models +from flytekit.models.core import literals as _literals from flytekit.models.core.identifier import WorkflowExecutionIdentifier +from flytekit.models.plugins import array_job as _array_job +from flytekit.models.plugins import qubole as _qubole_models class UnitTestEngineFactory(_common_engine.BaseExecutionEngineFactory): diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index e4a803f50a..354b0e7792 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -8,12 +8,12 @@ import pandas as pd +import flytekit.models.core.task from flytekit import FlyteContext, kwtypes from flytekit.core.base_sql_task import SQLTask from flytekit.core.context_manager import SerializationSettings from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor -from flytekit.models import task as task_models from flytekit.types.schema import FlyteSchema @@ -111,7 +111,7 @@ def get_custom(self, settings: SerializationSettings) -> typing.Dict[str, typing class SQLite3TaskExecutor(ShimTaskExecutor[SQLite3Task]): - def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> typing.Any: + def execute_from_model(self, tt: flytekit.models.core.task.TaskTemplate, **kwargs) -> typing.Any: with tempfile.TemporaryDirectory() as temp_dir: ctx = FlyteContext.current_context() file_ext = os.path.basename(tt.custom["uri"]) diff --git a/flytekit/models/admin/common.py b/flytekit/models/admin/common.py index 3baba54e24..15ba6e7e5e 100644 --- a/flytekit/models/admin/common.py +++ b/flytekit/models/admin/common.py @@ -1,9 +1,11 @@ +import six as _six from flyteidl.admin import common_pb2 as _common_pb2 -from flytekit.models import common as _common +from flytekit.models import common as _common_models +from flytekit.models.common import FlyteIdlEntity -class Sort(_common.FlyteIdlEntity): +class Sort(_common_models.FlyteIdlEntity): class Direction(object): DESCENDING = _common_pb2.Sort.DESCENDING ASCENDING = _common_pb2.Sort.ASCENDING @@ -68,3 +70,447 @@ def from_python_std(cls, text): "start with 'asc(' or 'desc'.".format(text) ) return cls(key=key, direction=direction) + + +class EmailNotification(_common_models.FlyteIdlEntity): + def __init__(self, recipients_email): + """ + :param list[Text] recipients_email: + """ + self._recipients_email = recipients_email + + @property + def recipients_email(self): + """ + :rtype: list[Text] + """ + return self._recipients_email + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.common_pb2.EmailNotification + """ + return _common_pb2.EmailNotification(recipients_email=self.recipients_email) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.common_pb2.EmailNotification pb2_object: + :rtype: EmailNotification + """ + return cls(pb2_object.recipients_email) + + +class SlackNotification(_common_models.FlyteIdlEntity): + def __init__(self, recipients_email): + """ + :param list[Text] recipients_email: + """ + self._recipients_email = recipients_email + + @property + def recipients_email(self): + """ + :rtype: list[Text] + """ + return self._recipients_email + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.common_pb2.SlackNotification + """ + return _common_pb2.SlackNotification(recipients_email=self.recipients_email) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.common_pb2.SlackNotification pb2_object: + :rtype: EmailNotification + """ + return cls(pb2_object.recipients_email) + + +class PagerDutyNotification(_common_models.FlyteIdlEntity): + def __init__(self, recipients_email): + """ + :param list[Text] recipients_email: + """ + self._recipients_email = recipients_email + + @property + def recipients_email(self): + """ + :rtype: list[Text] + """ + return self._recipients_email + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.common_pb2.PagerDutyNotification + """ + return _common_pb2.PagerDutyNotification(recipients_email=self.recipients_email) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.common_pb2.PagerDutyNotification pb2_object: + :rtype: EmailNotification + """ + return cls(pb2_object.recipients_email) + + +class Notification(_common_models.FlyteIdlEntity): + def __init__( + self, + phases, + email: EmailNotification = None, + pager_duty: PagerDutyNotification = None, + slack: SlackNotification = None, + ): + """ + Represents a structure for notifications based on execution status. + :param list[int] phases: A list of phases to which users can associate the notifications. + :param EmailNotification email: [Optional] Specify this for an email notification. + :param PagerDutyNotification email: [Optional] Specify this for a PagerDuty notification. + :param SlackNotification email: [Optional] Specify this for a Slack notification. + """ + self._phases = phases + self._email = email + self._pager_duty = pager_duty + self._slack = slack + + @property + def phases(self): + """ + A list of phases to which users can associate the notifications. + :rtype: list[int] + """ + return self._phases + + @property + def email(self): + """ + :rtype: EmailNotification + """ + return self._email + + @property + def pager_duty(self): + """ + :rtype: PagerDutyNotification + """ + return self._pager_duty + + @property + def slack(self): + """ + :rtype: SlackNotification + """ + return self._slack + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.common_pb2.Notification + """ + return _common_pb2.Notification( + phases=self.phases, + email=self.email.to_flyte_idl() if self.email else None, + pager_duty=self.pager_duty.to_flyte_idl() if self.pager_duty else None, + slack=self.slack.to_flyte_idl() if self.slack else None, + ) + + @classmethod + def from_flyte_idl(cls, p): + """ + :param flyteidl.admin.common_pb2.Notification p: + :rtype: Notification + """ + return cls( + p.phases, + email=EmailNotification.from_flyte_idl(p.email) if p.HasField("email") else None, + pager_duty=PagerDutyNotification.from_flyte_idl(p.pager_duty) if p.HasField("pager_duty") else None, + slack=SlackNotification.from_flyte_idl(p.slack) if p.HasField("slack") else None, + ) + + +class Labels(_common_models.FlyteIdlEntity): + def __init__(self, values): + """ + Label values to be applied to a workflow execution resource. + + :param dict[Text, Text] values: + """ + self._values = values + + @property + def values(self): + return self._values + + def to_flyte_idl(self): + """ + :rtype: dict[Text, Text] + """ + return _common_pb2.Labels(values={k: v for k, v in _six.iteritems(self.values)}) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.common_pb2.Labels pb2_object: + :rtype: Labels + """ + return cls({k: v for k, v in _six.iteritems(pb2_object.values)}) + + +class Annotations(_common_models.FlyteIdlEntity): + def __init__(self, values): + """ + Annotation values to be applied to a workflow execution resource. + + :param dict[Text, Text] values: + """ + self._values = values + + @property + def values(self): + return self._values + + def to_flyte_idl(self): + """ + :rtype: _common_pb2.Annotations + """ + return _common_pb2.Annotations(values={k: v for k, v in _six.iteritems(self.values)}) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.common_pb2.Annotations pb2_object: + :rtype: Annotations + """ + return cls({k: v for k, v in _six.iteritems(pb2_object.values)}) + + +class UrlBlob(_common_models.FlyteIdlEntity): + def __init__(self, url, bytes): + """ + :param Text url: + :param int bytes: + """ + self._url = url + self._bytes = bytes + + @property + def url(self): + """ + :rtype: Text + """ + return self._url + + @property + def bytes(self): + """ + :rtype: int + """ + return self._bytes + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.common_pb2.UrlBlob + """ + return _common_pb2.UrlBlob(url=self.url, bytes=self.bytes) + + @classmethod + def from_flyte_idl(cls, pb): + """ + :param flyteidl.admin.common_pb2.UrlBlob pb: + :rtype: UrlBlob + """ + return cls(url=pb.url, bytes=pb.bytes) + + +class RawOutputDataConfig(_common_models.FlyteIdlEntity): + def __init__(self, output_location_prefix): + """ + :param Text output_location_prefix: Location of offloaded data for things like S3, etc. + """ + self._output_location_prefix = output_location_prefix + + @property + def output_location_prefix(self): + return self._output_location_prefix + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.common_pb2.Auth + """ + return _common_pb2.RawOutputDataConfig(output_location_prefix=self.output_location_prefix) + + @classmethod + def from_flyte_idl(cls, pb2): + return cls(output_location_prefix=pb2.output_location_prefix) + + +class NamedEntityState(object): + ACTIVE = _common_pb2.NAMED_ENTITY_ACTIVE + ARCHIVED = _common_pb2.NAMED_ENTITY_ARCHIVED + + @classmethod + def enum_to_string(cls, val): + """ + :param int val: + :rtype: Text + """ + if val == cls.ACTIVE: + return "ACTIVE" + elif val == cls.ARCHIVED: + return "ARCHIVED" + else: + return "" + + +class NamedEntityIdentifier(_common_models.FlyteIdlEntity): + def __init__(self, project, domain, name): + """ + :param Text project: + :param Text domain: + :param Text name: + """ + self._project = project + self._domain = domain + self._name = name + + @property + def project(self): + """ + :rtype: Text + """ + return self._project + + @property + def domain(self): + """ + :rtype: Text + """ + return self._domain + + @property + def name(self): + """ + :rtype: Text + """ + return self._name + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifier + """ + return _common_pb2.NamedEntityIdentifier( + project=self.project, + domain=self.domain, + name=self.name, + ) + + @classmethod + def from_flyte_idl(cls, p): + """ + :param flyteidl.core.common_pb2.NamedEntityIdentifier p: + :rtype: Identifier + """ + return cls( + project=p.project, + domain=p.domain, + name=p.name, + ) + + +class NamedEntityMetadata(_common_models.FlyteIdlEntity): + def __init__(self, description, state): + """ + + :param Text description: + :param int state: enum value from NamedEntityState + """ + self._description = description + self._state = state + + @property + def description(self): + """ + :rtype: Text + """ + return self._description + + @property + def state(self): + """ + enum value from NamedEntityState + :rtype: int + """ + return self._state + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.common_pb2.NamedEntityMetadata + """ + return _common_pb2.NamedEntityMetadata( + description=self.description, + state=self.state, + ) + + @classmethod + def from_flyte_idl(cls, p): + """ + :param flyteidl.core.common_pb2.NamedEntityMetadata p: + :rtype: Identifier + """ + return cls( + description=p.description, + state=p.state, + ) + + +class AuthRole(FlyteIdlEntity): + def __init__(self, assumable_iam_role=None, kubernetes_service_account=None): + """ + At most one of assumable_iam_role or kubernetes_service_account can be set. + :param Text assumable_iam_role: IAM identity with set permissions policies. + :param Text kubernetes_service_account: Provides an identity for workflow execution resources. Flyte deployment + administrators are responsible for handling permissions as they relate to the service account. + """ + self._assumable_iam_role = assumable_iam_role + self._kubernetes_service_account = kubernetes_service_account + + @property + def assumable_iam_role(self): + """ + The IAM role to execute the workflow with + :rtype: Text + """ + return self._assumable_iam_role + + @property + def kubernetes_service_account(self): + """ + The kubernetes service account to execute the workflow with + :rtype: Text + """ + return self._kubernetes_service_account + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.common.AuthRole + """ + return _common_pb2.AuthRole( + assumable_iam_role=self.assumable_iam_role if self.assumable_iam_role else None, + kubernetes_service_account=self.kubernetes_service_account if self.kubernetes_service_account else None, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.common.AuthRole pb2_object: + :rtype: AuthRole + """ + return cls( + assumable_iam_role=pb2_object.assumable_iam_role, + kubernetes_service_account=pb2_object.kubernetes_service_account, + ) diff --git a/flytekit/models/execution.py b/flytekit/models/admin/execution.py similarity index 87% rename from flytekit/models/execution.py rename to flytekit/models/admin/execution.py index 042b3ad909..b3a5c6ea66 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/admin/execution.py @@ -3,10 +3,13 @@ import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 import pytz as _pytz +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan from flytekit.models import common as _common_models -from flytekit.models import literals as _literals_models +from flytekit.models.admin import common as _admin_common from flytekit.models.core import execution as _core_execution from flytekit.models.core import identifier as _identifier +from flytekit.models.core import literals as _literals_models class ExecutionMetadata(_common_models.FlyteIdlEntity): @@ -86,9 +89,9 @@ def __init__( :param ExecutionMetadata metadata: The metadata to be associated with this execution :param NotificationList notifications: List of notifications for this execution. :param bool disable_all: If true, all notifications should be disabled. - :param flytekit.models.common.Labels labels: Labels to apply to the execution. - :param flytekit.models.common.Annotations annotations: Annotations to apply to the execution - :param flytekit.models.common.AuthRole auth_role: The authorization method with which to execute the workflow. + :param flytekit.models.admin.common.Labels labels: Labels to apply to the execution. + :param flytekit.models.admin.common.Annotations annotations: Annotations to apply to the execution + :param flytekit.models.admin.common.AuthRole auth_role: The authorization method with which to execute the workflow. :param max_parallelism int: Controls the maximum number of tasknodes that can be run in parallel for the entire workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and parallelism/concurrency of MapTasks is independent from this. @@ -98,9 +101,9 @@ def __init__( self._metadata = metadata self._notifications = notifications self._disable_all = disable_all - self._labels = labels or _common_models.Labels({}) - self._annotations = annotations or _common_models.Annotations({}) - self._auth_role = auth_role or _common_models.AuthRole() + self._labels = labels or _admin_common.Labels({}) + self._annotations = annotations or _admin_common.Annotations({}) + self._auth_role = auth_role or flytekit.models.admin.common.AuthRole() self._max_parallelism = max_parallelism @property @@ -135,21 +138,21 @@ def disable_all(self): @property def labels(self): """ - :rtype: flytekit.models.common.Labels + :rtype: flytekit.models.admin.common.Labels """ return self._labels @property def annotations(self): """ - :rtype: flytekit.models.common.Annotations + :rtype: flytekit.models.admin.common.Annotations """ return self._annotations @property def auth_role(self): """ - :rtype: flytekit.models.common.AuthRole + :rtype: flytekit.models.admin.common.AuthRole """ return self._auth_role @@ -183,9 +186,9 @@ def from_flyte_idl(cls, p): metadata=ExecutionMetadata.from_flyte_idl(p.metadata), notifications=NotificationList.from_flyte_idl(p.notifications) if p.HasField("notifications") else None, disable_all=p.disable_all if p.HasField("disable_all") else None, - labels=_common_models.Labels.from_flyte_idl(p.labels), - annotations=_common_models.Annotations.from_flyte_idl(p.annotations), - auth_role=_common_models.AuthRole.from_flyte_idl(p.auth_role), + labels=_admin_common.Labels.from_flyte_idl(p.labels), + annotations=_admin_common.Annotations.from_flyte_idl(p.annotations), + auth_role=flytekit.models.admin.common.AuthRole.from_flyte_idl(p.auth_role), max_parallelism=p.max_parallelism, ) @@ -367,14 +370,14 @@ def from_flyte_idl(cls, pb2_object): class NotificationList(_common_models.FlyteIdlEntity): def __init__(self, notifications): """ - :param list[flytekit.models.common.Notification] notifications: A simple list of notifications. + :param list[flytekit.models.admin.common.Notification] notifications: A simple list of notifications. """ self._notifications = notifications @property def notifications(self): """ - :rtype: list[flytekit.models.common.Notification] + :rtype: list[flytekit.models.admin.common.Notification] """ return self._notifications @@ -390,7 +393,7 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.execution_pb2.NotificationList pb2_object: :rtype: NotificationList """ - return cls([_common_models.Notification.from_flyte_idl(p) for p in pb2_object.notifications]) + return cls([_admin_common.Notification.from_flyte_idl(p) for p in pb2_object.notifications]) class _CommonDataResponse(_common_models.FlyteIdlEntity): @@ -401,8 +404,8 @@ class _CommonDataResponse(_common_models.FlyteIdlEntity): def __init__(self, inputs, outputs, full_inputs, full_outputs): """ - :param _common_models.UrlBlob inputs: - :param _common_models.UrlBlob outputs: + :param _admin_common.UrlBlob inputs: + :param _admin_common.UrlBlob outputs: :param _literals_pb2.LiteralMap full_inputs: :param _literals_pb2.LiteralMap full_outputs: """ @@ -414,14 +417,14 @@ def __init__(self, inputs, outputs, full_inputs, full_outputs): @property def inputs(self): """ - :rtype: _common_models.UrlBlob + :rtype: _admin_common.UrlBlob """ return self._inputs @property def outputs(self): """ - :rtype: _common_models.UrlBlob + :rtype: _admin_common.UrlBlob """ return self._outputs @@ -448,8 +451,8 @@ def from_flyte_idl(cls, pb2_object): :rtype: WorkflowExecutionGetDataResponse """ return cls( - inputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.inputs), - outputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.outputs), + inputs=_admin_common.UrlBlob.from_flyte_idl(pb2_object.inputs), + outputs=_admin_common.UrlBlob.from_flyte_idl(pb2_object.outputs), full_inputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_inputs), full_outputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_outputs), ) @@ -474,8 +477,8 @@ def from_flyte_idl(cls, pb2_object): :rtype: TaskExecutionGetDataResponse """ return cls( - inputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.inputs), - outputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.outputs), + inputs=_admin_common.UrlBlob.from_flyte_idl(pb2_object.inputs), + outputs=_admin_common.UrlBlob.from_flyte_idl(pb2_object.outputs), full_inputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_inputs), full_outputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_outputs), ) @@ -500,8 +503,8 @@ def from_flyte_idl(cls, pb2_object): :rtype: NodeExecutionGetDataResponse """ return cls( - inputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.inputs), - outputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.outputs), + inputs=_admin_common.UrlBlob.from_flyte_idl(pb2_object.inputs), + outputs=_admin_common.UrlBlob.from_flyte_idl(pb2_object.outputs), full_inputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_inputs), full_outputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_outputs), ) diff --git a/flytekit/models/launch_plan.py b/flytekit/models/admin/launch_plan.py similarity index 74% rename from flytekit/models/launch_plan.py rename to flytekit/models/admin/launch_plan.py index 9c508166c5..7bc6f8265e 100644 --- a/flytekit/models/launch_plan.py +++ b/flytekit/models/admin/launch_plan.py @@ -1,10 +1,12 @@ from flyteidl.admin import launch_plan_pb2 as _launch_plan from flytekit.models import common as _common -from flytekit.models import interface as _interface -from flytekit.models import literals as _literals -from flytekit.models import schedule as _schedule +from flytekit.models.admin import common as _admin_common +from flytekit.models.admin import schedule as _schedule +from flytekit.models.admin.common import AuthRole from flytekit.models.core import identifier as _identifier +from flytekit.models.core import interface as _interface +from flytekit.models.core import literals as _literals class LaunchPlanMetadata(_common.FlyteIdlEntity): @@ -12,7 +14,7 @@ def __init__(self, schedule, notifications): """ :param flytekit.models.schedule.Schedule schedule: Schedule to execute the Launch Plan - :param list[flytekit.models.common.Notification] notifications: List of notifications based on + :param list[flytekit.models.admin.common.Notification] notifications: List of notifications based on execution status transitions """ self._schedule = schedule @@ -30,7 +32,7 @@ def schedule(self): def notifications(self): """ List of notifications based on Execution status transitions - :rtype: list[flytekit.models.common.Notification] + :rtype: list[flytekit.models.admin.common.Notification] """ return self._notifications @@ -54,56 +56,7 @@ def from_flyte_idl(cls, pb2_object): schedule=_schedule.Schedule.from_flyte_idl(pb2_object.schedule) if pb2_object.HasField("schedule") else None, - notifications=[_common.Notification.from_flyte_idl(n) for n in pb2_object.notifications], - ) - - -class Auth(_common.FlyteIdlEntity): - def __init__(self, assumable_iam_role=None, kubernetes_service_account=None): - """ - DEPRECATED. Do not use. Use flytekit.models.common.AuthRole instead - At most one of assumable_iam_role or kubernetes_service_account can be set. - :param Text assumable_iam_role: IAM identity with set permissions policies. - :param Text kubernetes_service_account: Provides an identity for workflow execution resources. Flyte deployment - administrators are responsible for handling permissions as they relate to the service account. - """ - self._assumable_iam_role = assumable_iam_role - self._kubernetes_service_account = kubernetes_service_account - - @property - def assumable_iam_role(self): - """ - The IAM role to execute the workflow with - :rtype: Text - """ - return self._assumable_iam_role - - @property - def kubernetes_service_account(self): - """ - The kubernetes service account to execute the workflow with - :rtype: Text - """ - return self._kubernetes_service_account - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.launch_plan_pb2.Auth - """ - return _launch_plan.Auth( - assumable_iam_role=self.assumable_iam_role if self.assumable_iam_role else None, - kubernetes_service_account=self.kubernetes_service_account if self.kubernetes_service_account else None, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.launch_plan_pb2.Auth pb2_object: - :rtype: Auth - """ - return cls( - assumable_iam_role=pb2_object.assumable_iam_role, - kubernetes_service_account=pb2_object.kubernetes_service_account, + notifications=[_admin_common.Notification.from_flyte_idl(n) for n in pb2_object.notifications], ) @@ -114,10 +67,10 @@ def __init__( entity_metadata, default_inputs, fixed_inputs, - labels: _common.Labels, - annotations: _common.Annotations, - auth_role: _common.AuthRole, - raw_output_data_config: _common.RawOutputDataConfig, + labels: _admin_common.Labels, + annotations: _admin_common.Annotations, + auth_role: AuthRole, + raw_output_data_config: _admin_common.RawOutputDataConfig, max_parallelism=None, ): """ @@ -129,10 +82,10 @@ def __init__( :param flytekit.models.literals.LiteralMap fixed_inputs: Fixed, non-overridable inputs for the Launch Plan :param flytekit.models.common.Labels: Any custom kubernetes labels to apply to workflows executed by this launch plan. - :param flytekit.models.common.Annotations annotations: + :param flytekit.models.admin.common.Annotations annotations: Any custom kubernetes annotations to apply to workflows executed by this launch plan. - :param flytekit.models.common.AuthRole auth_role: The auth method with which to execute the workflow. - :param flytekit.models.common.RawOutputDataConfig raw_output_data_config: Value for where to store offloaded + :param flytekit.models.admin.common.AuthRole auth_role: The auth method with which to execute the workflow. + :param flytekit.models.admin.common.RawOutputDataConfig raw_output_data_config: Value for where to store offloaded data like Blobs and Schemas. :param max_parallelism int: Controls the maximum number of tasknodes that can be run in parallel for the entire workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and @@ -180,18 +133,18 @@ def fixed_inputs(self): return self._fixed_inputs @property - def labels(self) -> _common.Labels: + def labels(self) -> _admin_common.Labels: """ The labels to execute the workflow with - :rtype: flytekit.models.common.Labels + :rtype: flytekit.models.admin.common.Labels """ return self._labels @property - def annotations(self) -> _common.Annotations: + def annotations(self) -> _admin_common.Annotations: """ The annotations to execute the workflow with - :rtype: flytekit.models.common.Annotations + :rtype: flytekit.models.admin.common.Annotations """ return self._annotations @@ -199,7 +152,7 @@ def annotations(self) -> _common.Annotations: def auth_role(self): """ The authorization method with which to execute the workflow. - :rtype: flytekit.models.common.AuthRole + :rtype: flytekit.models.admin.common.AuthRole """ return self._auth_role @@ -207,7 +160,7 @@ def auth_role(self): def raw_output_data_config(self): """ Where to store offloaded data like Blobs and Schemas - :rtype: flytekit.models.common.RawOutputDataConfig + :rtype: flytekit.models.admin.common.RawOutputDataConfig """ return self._raw_output_data_config @@ -240,23 +193,23 @@ def from_flyte_idl(cls, pb2): auth_role = None # First check the newer field, auth_role. if pb2.auth_role is not None and (pb2.auth_role.assumable_iam_role or pb2.auth_role.kubernetes_service_account): - auth_role = _common.AuthRole.from_flyte_idl(pb2.auth_role) + auth_role = AuthRole.from_flyte_idl(pb2.auth_role) # Fallback to the deprecated field. elif pb2.auth is not None: if pb2.auth.assumable_iam_role: - auth_role = _common.AuthRole(assumable_iam_role=pb2.auth.assumable_iam_role) + auth_role = AuthRole(assumable_iam_role=pb2.auth.assumable_iam_role) else: - auth_role = _common.AuthRole(assumable_iam_role=pb2.auth.kubernetes_service_account) + auth_role = AuthRole(assumable_iam_role=pb2.auth.kubernetes_service_account) return cls( workflow_id=_identifier.Identifier.from_flyte_idl(pb2.workflow_id), entity_metadata=LaunchPlanMetadata.from_flyte_idl(pb2.entity_metadata), default_inputs=_interface.ParameterMap.from_flyte_idl(pb2.default_inputs), fixed_inputs=_literals.LiteralMap.from_flyte_idl(pb2.fixed_inputs), - labels=_common.Labels.from_flyte_idl(pb2.labels), - annotations=_common.Annotations.from_flyte_idl(pb2.annotations), + labels=_admin_common.Labels.from_flyte_idl(pb2.labels), + annotations=_admin_common.Annotations.from_flyte_idl(pb2.annotations), auth_role=auth_role, - raw_output_data_config=_common.RawOutputDataConfig.from_flyte_idl(pb2.raw_output_data_config), + raw_output_data_config=_admin_common.RawOutputDataConfig.from_flyte_idl(pb2.raw_output_data_config), max_parallelism=pb2.max_parallelism, ) diff --git a/flytekit/models/matchable_resource.py b/flytekit/models/admin/matchable_resource.py similarity index 100% rename from flytekit/models/matchable_resource.py rename to flytekit/models/admin/matchable_resource.py diff --git a/flytekit/models/node_execution.py b/flytekit/models/admin/node_execution.py similarity index 100% rename from flytekit/models/node_execution.py rename to flytekit/models/admin/node_execution.py diff --git a/flytekit/models/project.py b/flytekit/models/admin/project.py similarity index 100% rename from flytekit/models/project.py rename to flytekit/models/admin/project.py diff --git a/flytekit/models/schedule.py b/flytekit/models/admin/schedule.py similarity index 100% rename from flytekit/models/schedule.py rename to flytekit/models/admin/schedule.py diff --git a/flytekit/models/admin/task.py b/flytekit/models/admin/task.py new file mode 100644 index 0000000000..e22e8bddbb --- /dev/null +++ b/flytekit/models/admin/task.py @@ -0,0 +1,110 @@ +from flyteidl.admin import task_pb2 as _admin_task + +from flytekit.models import common as _common +from flytekit.models.core import identifier as _identifier +from flytekit.models.core.compiler import CompiledTask as _compiledTask +from flytekit.models.core.task import TaskTemplate + + +class TaskSpec(_common.FlyteIdlEntity): + def __init__(self, template): + """ + :param flytekit.models.core.task.TaskTemplate template: + """ + self._template = template + + @property + def template(self): + """ + :rtype: flytekit.models.core.task.TaskTemplate + """ + return self._template + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.tasks_pb2.TaskSpec + """ + return _admin_task.TaskSpec(template=self.template.to_flyte_idl()) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.tasks_pb2.TaskSpec pb2_object: + :rtype: TaskSpec + """ + return cls(TaskTemplate.from_flyte_idl(pb2_object.template)) + + +class TaskClosure(_common.FlyteIdlEntity): + def __init__(self, compiled_task): + """ + :param flytekit.models.core.compiler.CompiledTask compiled_task: + """ + self._compiled_task = compiled_task + + @property + def compiled_task(self): + """ + :rtype: flytekit.models.core.compiler.CompiledTask + """ + return self._compiled_task + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.task_pb2.TaskClosure + """ + return _admin_task.TaskClosure(compiled_task=self.compiled_task.to_flyte_idl()) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.task_pb2.TaskClosure pb2_object: + :rtype: TaskClosure + """ + return cls(compiled_task=_compiledTask.from_flyte_idl(pb2_object.compiled_task)) + + +class Task(_common.FlyteIdlEntity): + def __init__(self, id, closure): + """ + :param flytekit.models.core.identifier.Identifier id: The (project, domain, name) identifier for this task. + :param TaskClosure closure: The closure for the underlying workload. + """ + self._id = id + self._closure = closure + + @property + def id(self): + """ + The (project, domain, name, version) identifier for this task. + :rtype: flytekit.models.core.identifier.Identifier + """ + return self._id + + @property + def closure(self): + """ + The closure for the underlying workload. + :rtype: TaskClosure + """ + return self._closure + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin.task_pb2.Task + """ + return _admin_task.Task( + closure=self.closure.to_flyte_idl(), + id=self.id.to_flyte_idl(), + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.task_pb2.Task pb2_object: + :rtype: TaskDefinition + """ + return cls( + closure=TaskClosure.from_flyte_idl(pb2_object.closure), + id=_identifier.Identifier.from_flyte_idl(pb2_object.id), + ) diff --git a/flytekit/models/common.py b/flytekit/models/common.py index 139c3f0d2d..63f62a44b8 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -2,7 +2,6 @@ import json as _json import six as _six -from flyteidl.admin import common_pb2 as _common_pb2 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct @@ -99,382 +98,3 @@ def to_dict(self): :rtype: dict[Text, T] """ pass - - -class NamedEntityIdentifier(FlyteIdlEntity): - def __init__(self, project, domain, name=None): - """ - :param Text project: The name of the project in which this entity lives. - :param Text domain: The name of the domain within the project. - :param Text name: [Optional] The name of the entity within the namespace of the project and domain. - """ - self._project = project - self._domain = domain - self._name = name - - @property - def project(self): - """ - The name of the project in which this entity lives. - :rtype: Text - """ - return self._project - - @property - def domain(self): - """ - The name of the domain within the project. - :rtype: Text - """ - return self._domain - - @property - def name(self): - """ - The name of the entity within the namespace of the project and domain. - :rtype: Text - """ - return self._name - - def to_flyte_idl(self): - """ - Stores object to a Flyte-IDL defined protobuf. - :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifier - """ - - # We use the kwarg constructor of the protobuf and setting name=None is equivalent to not setting it at all - return _common_pb2.NamedEntityIdentifier(project=self.project, domain=self.domain, name=self.name) - - @classmethod - def from_flyte_idl(cls, idl_object): - """ - :param flyteidl.admin.common_pb2.NamedEntityIdentifier idl_object: - :rtype: NamedEntityIdentifier - """ - return cls(idl_object.project, idl_object.domain, idl_object.name) - - -class EmailNotification(FlyteIdlEntity): - def __init__(self, recipients_email): - """ - :param list[Text] recipients_email: - """ - self._recipients_email = recipients_email - - @property - def recipients_email(self): - """ - :rtype: list[Text] - """ - return self._recipients_email - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.common_pb2.EmailNotification - """ - return _common_pb2.EmailNotification(recipients_email=self.recipients_email) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.common_pb2.EmailNotification pb2_object: - :rtype: EmailNotification - """ - return cls(pb2_object.recipients_email) - - -class SlackNotification(FlyteIdlEntity): - def __init__(self, recipients_email): - """ - :param list[Text] recipients_email: - """ - self._recipients_email = recipients_email - - @property - def recipients_email(self): - """ - :rtype: list[Text] - """ - return self._recipients_email - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.common_pb2.SlackNotification - """ - return _common_pb2.SlackNotification(recipients_email=self.recipients_email) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.common_pb2.SlackNotification pb2_object: - :rtype: EmailNotification - """ - return cls(pb2_object.recipients_email) - - -class PagerDutyNotification(FlyteIdlEntity): - def __init__(self, recipients_email): - """ - :param list[Text] recipients_email: - """ - self._recipients_email = recipients_email - - @property - def recipients_email(self): - """ - :rtype: list[Text] - """ - return self._recipients_email - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.common_pb2.PagerDutyNotification - """ - return _common_pb2.PagerDutyNotification(recipients_email=self.recipients_email) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.common_pb2.PagerDutyNotification pb2_object: - :rtype: EmailNotification - """ - return cls(pb2_object.recipients_email) - - -class Notification(FlyteIdlEntity): - def __init__( - self, - phases, - email: EmailNotification = None, - pager_duty: PagerDutyNotification = None, - slack: SlackNotification = None, - ): - """ - Represents a structure for notifications based on execution status. - - :param list[int] phases: A list of phases to which users can associate the notifications. - :param EmailNotification email: [Optional] Specify this for an email notification. - :param PagerDutyNotification email: [Optional] Specify this for a PagerDuty notification. - :param SlackNotification email: [Optional] Specify this for a Slack notification. - """ - self._phases = phases - self._email = email - self._pager_duty = pager_duty - self._slack = slack - - @property - def phases(self): - """ - A list of phases to which users can associate the notifications. - :rtype: list[int] - """ - return self._phases - - @property - def email(self): - """ - :rtype: EmailNotification - """ - return self._email - - @property - def pager_duty(self): - """ - :rtype: PagerDutyNotification - """ - return self._pager_duty - - @property - def slack(self): - """ - :rtype: SlackNotification - """ - return self._slack - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.common_pb2.Notification - """ - return _common_pb2.Notification( - phases=self.phases, - email=self.email.to_flyte_idl() if self.email else None, - pager_duty=self.pager_duty.to_flyte_idl() if self.pager_duty else None, - slack=self.slack.to_flyte_idl() if self.slack else None, - ) - - @classmethod - def from_flyte_idl(cls, p): - """ - :param flyteidl.admin.common_pb2.Notification p: - :rtype: Notification - """ - return cls( - p.phases, - email=EmailNotification.from_flyte_idl(p.email) if p.HasField("email") else None, - pager_duty=PagerDutyNotification.from_flyte_idl(p.pager_duty) if p.HasField("pager_duty") else None, - slack=SlackNotification.from_flyte_idl(p.slack) if p.HasField("slack") else None, - ) - - -class Labels(FlyteIdlEntity): - def __init__(self, values): - """ - Label values to be applied to a workflow execution resource. - - :param dict[Text, Text] values: - """ - self._values = values - - @property - def values(self): - return self._values - - def to_flyte_idl(self): - """ - :rtype: dict[Text, Text] - """ - return _common_pb2.Labels(values={k: v for k, v in _six.iteritems(self.values)}) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.common_pb2.Labels pb2_object: - :rtype: Labels - """ - return cls({k: v for k, v in _six.iteritems(pb2_object.values)}) - - -class Annotations(FlyteIdlEntity): - def __init__(self, values): - """ - Annotation values to be applied to a workflow execution resource. - - :param dict[Text, Text] values: - """ - self._values = values - - @property - def values(self): - return self._values - - def to_flyte_idl(self): - """ - :rtype: _common_pb2.Annotations - """ - return _common_pb2.Annotations(values={k: v for k, v in _six.iteritems(self.values)}) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.common_pb2.Annotations pb2_object: - :rtype: Annotations - """ - return cls({k: v for k, v in _six.iteritems(pb2_object.values)}) - - -class UrlBlob(FlyteIdlEntity): - def __init__(self, url, bytes): - """ - :param Text url: - :param int bytes: - """ - self._url = url - self._bytes = bytes - - @property - def url(self): - """ - :rtype: Text - """ - return self._url - - @property - def bytes(self): - """ - :rtype: int - """ - return self._bytes - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.common_pb2.UrlBlob - """ - return _common_pb2.UrlBlob(url=self.url, bytes=self.bytes) - - @classmethod - def from_flyte_idl(cls, pb): - """ - :param flyteidl.admin.common_pb2.UrlBlob pb: - :rtype: UrlBlob - """ - return cls(url=pb.url, bytes=pb.bytes) - - -class AuthRole(FlyteIdlEntity): - def __init__(self, assumable_iam_role=None, kubernetes_service_account=None): - """ - At most one of assumable_iam_role or kubernetes_service_account can be set. - :param Text assumable_iam_role: IAM identity with set permissions policies. - :param Text kubernetes_service_account: Provides an identity for workflow execution resources. Flyte deployment - administrators are responsible for handling permissions as they relate to the service account. - """ - self._assumable_iam_role = assumable_iam_role - self._kubernetes_service_account = kubernetes_service_account - - @property - def assumable_iam_role(self): - """ - The IAM role to execute the workflow with - :rtype: Text - """ - return self._assumable_iam_role - - @property - def kubernetes_service_account(self): - """ - The kubernetes service account to execute the workflow with - :rtype: Text - """ - return self._kubernetes_service_account - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.launch_plan_pb2.Auth - """ - return _common_pb2.AuthRole( - assumable_iam_role=self.assumable_iam_role if self.assumable_iam_role else None, - kubernetes_service_account=self.kubernetes_service_account if self.kubernetes_service_account else None, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.launch_plan_pb2.Auth pb2_object: - :rtype: Auth - """ - return cls( - assumable_iam_role=pb2_object.assumable_iam_role, - kubernetes_service_account=pb2_object.kubernetes_service_account, - ) - - -class RawOutputDataConfig(FlyteIdlEntity): - def __init__(self, output_location_prefix): - """ - :param Text output_location_prefix: Location of offloaded data for things like S3, etc. - """ - self._output_location_prefix = output_location_prefix - - @property - def output_location_prefix(self): - return self._output_location_prefix - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.common_pb2.Auth - """ - return _common_pb2.RawOutputDataConfig(output_location_prefix=self.output_location_prefix) - - @classmethod - def from_flyte_idl(cls, pb2): - return cls(output_location_prefix=pb2.output_location_prefix) diff --git a/flytekit/models/core/compiler.py b/flytekit/models/core/compiler.py index 3246ee22b3..d50a6d8478 100644 --- a/flytekit/models/core/compiler.py +++ b/flytekit/models/core/compiler.py @@ -1,6 +1,8 @@ import six as _six from flyteidl.core import compiler_pb2 as _compiler_pb2 +import flytekit.models.admin.task +import flytekit.models.core.task from flytekit.models import common as _common from flytekit.models.core import workflow as _core_workflow_models @@ -121,35 +123,34 @@ def from_flyte_idl(cls, p): ) -# TODO: properly sort out the model code and remove one of these duplicate CompiledTasks class CompiledTask(_common.FlyteIdlEntity): def __init__(self, template): """ - :param TODO template: + :param flyteidl.core.CompiledTask.template template: """ self._template = template @property def template(self): """ - :rtype: TODO + :rtype: template """ return self._template def to_flyte_idl(self): """ - :rtype: flyteidl.core.compiler_pb2.CompiledTask + :rtype: flyteidl.core.CompiledTask """ - return _compiler_pb2.CompiledTask(template=self.template) # TODO: .to_flyte_idl() + return _compiler_pb2.CompiledTask(template=self.template.to_flyte_idl()) @classmethod def from_flyte_idl(cls, p): """ - :param flyteidl.core.compiler_pb2.CompiledTask p: + :param flyteidl.core.CompiledTask p: :rtype: CompiledTask """ # TODO: Refactor task so we don't have cyclical import - return cls(None) + return cls(template=flytekit.models.core.task.TaskTemplate.from_flyte_idl(p.template)) class CompiledWorkflowClosure(_common.FlyteIdlEntity): @@ -200,12 +201,9 @@ def from_flyte_idl(cls, p): :param flyteidl.core.compiler_pb2.CompiledWorkflowClosure p: :rtype: CompiledWorkflowClosure """ - # This import is here to prevent a circular dependency issue. - # TODO: properly sort out the model code and remove the duplicate CompiledTask - from flytekit.models.task import CompiledTask as _CompiledTask return cls( primary=CompiledWorkflow.from_flyte_idl(p.primary), sub_workflows=[CompiledWorkflow.from_flyte_idl(s) for s in p.sub_workflows], - tasks=[_CompiledTask.from_flyte_idl(t) for t in p.tasks], + tasks=[CompiledTask.from_flyte_idl(t) for t in p.tasks], ) diff --git a/flytekit/models/core/condition.py b/flytekit/models/core/condition.py index 845b3b4f79..941c5d9159 100644 --- a/flytekit/models/core/condition.py +++ b/flytekit/models/core/condition.py @@ -1,7 +1,7 @@ from flyteidl.core import condition_pb2 as _condition from flytekit.models import common as _common -from flytekit.models import literals as _literals +from flytekit.models.core import literals as _literals class ComparisonExpression(_common.FlyteIdlEntity): diff --git a/flytekit/models/dynamic_job.py b/flytekit/models/core/dynamic_job.py similarity index 89% rename from flytekit/models/dynamic_job.py rename to flytekit/models/core/dynamic_job.py index 44a985a5e7..a475c157bb 100644 --- a/flytekit/models/dynamic_job.py +++ b/flytekit/models/core/dynamic_job.py @@ -1,8 +1,8 @@ from flyteidl.core import dynamic_job_pb2 as _dynamic_job +import flytekit.models.core.task from flytekit.models import common as _common -from flytekit.models import literals as _literals -from flytekit.models import task as _task +from flytekit.models.core import literals as _literals from flytekit.models.core import workflow as _workflow @@ -11,7 +11,7 @@ def __init__(self, tasks, nodes, min_successes, outputs, subworkflows): """ Initializes a new FutureTaskDocument. - :param list[flytekit.models.task.TaskTemplate] tasks: A collection of unique tasks to execute. + :param list[flytekit.models.core.task.TaskTemplate] tasks: A collection of unique tasks to execute. :param list[flytekit.models.core.workflow.Node] nodes: A collection of task nodes. :param int min_successes: An absolute number of the minimum number of successful completions of subtasks. As soon as this criteria is met, the future job will be marked as successful and outputs will be computed. @@ -32,7 +32,7 @@ def __init__(self, tasks, nodes, min_successes, outputs, subworkflows): def tasks(self): """ A collection of tasks to execute. - :rtype: list[_task.TaskTemplate] + :rtype: list[flytekit.models.core.task.TaskTemplate] """ return self._tasks @@ -89,7 +89,9 @@ def from_flyte_idl(cls, pb2_object): :return: DynamicJobSpec """ return cls( - tasks=[_task.TaskTemplate.from_flyte_idl(task) for task in pb2_object.tasks] if pb2_object.tasks else None, + tasks=[flytekit.models.core.task.TaskTemplate.from_flyte_idl(task) for task in pb2_object.tasks] + if pb2_object.tasks + else None, nodes=[_workflow.Node.from_flyte_idl(n) for n in pb2_object.nodes], min_successes=pb2_object.min_successes, outputs=[_literals.Binding.from_flyte_idl(output) for output in pb2_object.outputs] diff --git a/flytekit/models/interface.py b/flytekit/models/core/interface.py similarity index 95% rename from flytekit/models/interface.py rename to flytekit/models/core/interface.py index 5364d39c3e..56ad82e465 100644 --- a/flytekit/models/interface.py +++ b/flytekit/models/core/interface.py @@ -3,15 +3,15 @@ import six as _six from flyteidl.core import interface_pb2 as _interface_pb2 +import flytekit.models.core.types from flytekit.models import common as _common -from flytekit.models import literals as _literals -from flytekit.models import types as _types +from flytekit.models.core import literals as _literals class Variable(_common.FlyteIdlEntity): def __init__(self, type, description): """ - :param flytekit.models.types.LiteralType type: This describes the type of value that must be provided to + :param flytekit.models.core.types.LiteralType type: This describes the type of value that must be provided to satisfy this variable. :param Text description: This is a help string that can provide context for what this variable means in relation to a task or workflow. @@ -23,7 +23,7 @@ def __init__(self, type, description): def type(self): """ This describes the type of value that must be provided to satisfy this variable. - :rtype: flytekit.models.types.LiteralType + :rtype: flytekit.models.core.types.LiteralType """ return self._type @@ -48,7 +48,7 @@ def from_flyte_idl(cls, variable_proto): :rtype: Variable """ return cls( - type=_types.LiteralType.from_flyte_idl(variable_proto.type), + type=flytekit.models.core.types.LiteralType.from_flyte_idl(variable_proto.type), description=variable_proto.description, ) diff --git a/flytekit/models/literals.py b/flytekit/models/core/literals.py similarity index 98% rename from flytekit/models/literals.py rename to flytekit/models/core/literals.py index 684fef95ba..ef197d5b65 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/core/literals.py @@ -7,8 +7,8 @@ from flytekit.common.exceptions import user as _user_exceptions from flytekit.models import common as _common from flytekit.models.core import types as _core_types -from flytekit.models.types import OutputReference as _OutputReference -from flytekit.models.types import SchemaType as _SchemaType +from flytekit.models.core.types import OutputReference as _OutputReference +from flytekit.models.core.types import SchemaType as _SchemaType class RetryStrategy(_common.FlyteIdlEntity): @@ -361,7 +361,7 @@ def __init__(self, scalar=None, collection=None, promise=None, map=None): :param Scalar scalar: [Optional] A simple scalar value. :param BindingDataCollection collection: [Optional] A collection of binding data. This allows nesting of binding data to any number of levels. - :param flytekit.models.types.OutputReference promise: [Optional] References an output promised by another node. + :param flytekit.models.core.types.OutputReference promise: [Optional] References an output promised by another node. :param BindingDataMap map: [Optional] A map of bindings. The key is always a string. """ self._scalar = scalar @@ -389,7 +389,7 @@ def collection(self): def promise(self): """ [Optional] References an output promised by another node. - :rtype: flytekit.models.types.OutputReference + :rtype: flytekit.models.core.types.OutputReference """ return self._promise @@ -505,7 +505,7 @@ def __init__(self, uri, type): A strongly typed schema that defines the interface of data retrieved from the underlying storage medium. :param Text uri: - :param flytekit.models.types.SchemaType type: + :param flytekit.models.core.types.SchemaType type: """ self._uri = uri self._type = type @@ -520,7 +520,7 @@ def uri(self): @property def type(self): """ - :rtype: flytekit.models.types.SchemaType + :rtype: flytekit.models.core.types.SchemaType """ return self._type diff --git a/flytekit/models/security.py b/flytekit/models/core/security.py similarity index 100% rename from flytekit/models/security.py rename to flytekit/models/core/security.py diff --git a/flytekit/models/task.py b/flytekit/models/core/task.py similarity index 73% rename from flytekit/models/task.py rename to flytekit/models/core/task.py index f12ce505d1..35dcdf9295 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/core/task.py @@ -2,24 +2,17 @@ import typing import six as _six -from flyteidl.admin import task_pb2 as _admin_task -from flyteidl.core import compiler_pb2 as _compiler from flyteidl.core import literals_pb2 as _literals_pb2 from flyteidl.core import tasks_pb2 as _core_task -from flyteidl.plugins import pytorch_pb2 as _pytorch_task -from flyteidl.plugins import spark_pb2 as _spark_task -from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct -from flytekit.common.exceptions import user as _user_exceptions from flytekit.models import common as _common -from flytekit.models import interface as _interface -from flytekit.models import literals as _literals -from flytekit.models import security as _sec from flytekit.models.core import identifier as _identifier +from flytekit.models.core import interface as _interface +from flytekit.models.core import literals as _literals +from flytekit.models.core import security as _sec from flytekit.plugins import flyteidl as _lazy_flyteidl -from flytekit.sdk.spark_types import SparkType as _spark_type class Resources(_common.FlyteIdlEntity): @@ -117,633 +110,6 @@ def from_flyte_idl(cls, pb2_object): ) -class RuntimeMetadata(_common.FlyteIdlEntity): - class RuntimeType(object): - OTHER = 0 - FLYTE_SDK = 1 - - def __init__(self, type, version, flavor): - """ - :param int type: Enum type from RuntimeMetadata.RuntimeType - :param Text version: Version string for SDK version. Can be used for metrics or managing breaking changes in - Admin or Propeller - :param Text flavor: Optional extra information about runtime environment (e.g. Python, GoLang, etc.) - """ - self._type = type - self._version = version - self._flavor = flavor - - @property - def type(self): - """ - Enum type from RuntimeMetadata.RuntimeType - :rtype: int - """ - return self._type - - @property - def version(self): - """ - Version string for SDK version. Can be used for metrics or managing breaking changes in Admin or Propeller - :rtype: Text - """ - return self._version - - @property - def flavor(self): - """ - Optional extra information about runtime environment (e.g. Python, GoLang, etc.) - :rtype: Text - """ - return self._flavor - - def to_flyte_idl(self): - """ - :rtype: flyteidl.core.tasks_pb2.RuntimeMetadata - """ - return _core_task.RuntimeMetadata(type=self.type, version=self.version, flavor=self.flavor) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.core.tasks_pb2.RuntimeMetadata pb2_object: - :rtype: RuntimeMetadata - """ - return cls(type=pb2_object.type, version=pb2_object.version, flavor=pb2_object.flavor) - - -class TaskMetadata(_common.FlyteIdlEntity): - def __init__( - self, - discoverable, - runtime, - timeout, - retries, - interruptible, - discovery_version, - deprecated_error_message, - ): - """ - Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, - and retries. - - :param bool discoverable: Whether or not the outputs of this task should be cached for discovery. - :param RuntimeMetadata runtime: Metadata describing the runtime environment for this task. - :param datetime.timedelta timeout: The amount of time to wait before timing out. This includes queuing and - scheduler latency. - :param bool interruptible: Whether or not the task is interruptible. - :param flytekit.models.literals.RetryStrategy retries: Retry strategy for this task. 0 retries means only - try once. - :param Text discovery_version: This is the version used to create a logical version for data in the cache. - This is only used when `discoverable` is true. Data is considered discoverable if: the inputs to a given - task are the same and the discovery_version is also the same. - :param Text deprecated: This string can be used to mark the task as deprecated. Consumers of the task will - receive deprecation warnings. - """ - self._discoverable = discoverable - self._runtime = runtime - self._timeout = timeout - self._interruptible = interruptible - self._retries = retries - self._discovery_version = discovery_version - self._deprecated_error_message = deprecated_error_message - - @property - def discoverable(self): - """ - Whether or not the outputs of this task should be cached for discovery. - :rtype: bool - """ - return self._discoverable - - @property - def runtime(self): - """ - Metadata describing the runtime environment for this task. - :rtype: RuntimeMetadata - """ - return self._runtime - - @property - def retries(self): - """ - Retry strategy for this task. 0 retries means only try once. - :rtype: flytekit.models.literals.RetryStrategy - """ - return self._retries - - @property - def timeout(self): - """ - The amount of time to wait before timing out. This includes queuing and scheduler latency. - :rtype: datetime.timedelta - """ - return self._timeout - - @property - def interruptible(self): - """ - Whether or not the task is interruptible. - :rtype: bool - """ - return self._interruptible - - @property - def discovery_version(self): - """ - This is the version used to create a logical version for data in the cache. - This is only used when `discoverable` is true. Data is considered discoverable if: the inputs to a given - task are the same and the discovery_version is also the same. - :rtype: Text - """ - return self._discovery_version - - @property - def deprecated_error_message(self): - """ - This string can be used to mark the task as deprecated. Consumers of the task will receive deprecation - warnings. - :rtype: Text - """ - return self._deprecated_error_message - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.task_pb2.TaskMetadata - """ - tm = _core_task.TaskMetadata( - discoverable=self.discoverable, - runtime=self.runtime.to_flyte_idl(), - retries=self.retries.to_flyte_idl(), - interruptible=self.interruptible, - discovery_version=self.discovery_version, - deprecated_error_message=self.deprecated_error_message, - ) - if self.timeout: - tm.timeout.FromTimedelta(self.timeout) - return tm - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.core.task_pb2.TaskMetadata pb2_object: - :rtype: TaskMetadata - """ - return cls( - discoverable=pb2_object.discoverable, - runtime=RuntimeMetadata.from_flyte_idl(pb2_object.runtime), - timeout=pb2_object.timeout.ToTimedelta(), - interruptible=pb2_object.interruptible if pb2_object.HasField("interruptible") else None, - retries=_literals.RetryStrategy.from_flyte_idl(pb2_object.retries), - discovery_version=pb2_object.discovery_version, - deprecated_error_message=pb2_object.deprecated_error_message, - ) - - -class TaskTemplate(_common.FlyteIdlEntity): - def __init__( - self, - id, - type, - metadata, - interface, - custom, - container=None, - task_type_version=0, - security_context=None, - config=None, - k8s_pod=None, - sql=None, - ): - """ - A task template represents the full set of information necessary to perform a unit of work in the Flyte system. - It contains the metadata about what inputs and outputs are consumed or produced. It also contains the metadata - necessary for Flyte Propeller to do the appropriate work. - - :param flytekit.models.core.identifier.Identifier id: This is generated by the system and uniquely identifies - the task. - :param Text type: This is used to define additional extensions for use by Propeller or SDK. - :param TaskMetadata metadata: This contains information needed at runtime to determine behavior such as - whether or not outputs are discoverable, timeouts, and retries. - :param flytekit.models.interface.TypedInterface interface: The interface definition for this task. - :param dict[Text, T] custom: Dictionary that must be serializable to a protobuf Struct for custom task plugins. - :param Container container: Provides the necessary entrypoint information for execution. For instance, - a Container might be specified with the necessary command line arguments. - :param int task_type_version: Specific version of this task type used by plugins to potentially modify - execution behavior or serialization. - :param dict[str, str] config: For plugin tasks this represents additional configuration information to be used - in tandem with the custom. - :param dict[str, str] config: For plugin tasks this represents additional configuration information to be used - in tandem with the custom. - :param K8sPod k8s_pod: Alternative to the container used to execute this task. - :param Sql sql: This is used to execute query in FlytePropeller instead of running container or k8s_pod. - """ - if ( - (container is not None and k8s_pod is not None) - or (container is not None and sql is not None) - or (k8s_pod is not None and sql is not None) - ): - raise ValueError("At most one of container, k8s_pod or sql can be set") - self._id = id - self._type = type - self._metadata = metadata - self._interface = interface - self._custom = custom - self._container = container - self._task_type_version = task_type_version - self._config = config - self._security_context = security_context - self._k8s_pod = k8s_pod - self._sql = sql - - @property - def id(self): - """ - This is generated by the system and uniquely identifies the task. - :rtype: flytekit.models.core.identifier.Identifier - """ - return self._id - - @property - def type(self): - """ - This is used to identify additional extensions for use by Propeller or SDK. - :rtype: Text - """ - return self._type - - @property - def metadata(self): - """ - This contains information needed at runtime to determine behavior such as whether or not outputs are - discoverable, timeouts, and retries. - :rtype: TaskMetadata - """ - return self._metadata - - @property - def interface(self): - """ - The interface definition for this task. - :rtype: flytekit.models.interface.TypedInterface - """ - return self._interface - - @property - def custom(self): - """ - Arbitrary dictionary containing metadata for custom plugins. - :rtype: dict[Text, T] - """ - return self._custom - - @property - def task_type_version(self): - return self._task_type_version - - @property - def container(self): - """ - If not None, the target of execution should be a container. - :rtype: Container - """ - return self._container - - @property - def config(self): - """ - Arbitrary dictionary containing metadata for parsing and handling custom plugins. - :rtype: dict[Text, T] - """ - return self._config - - @property - def security_context(self): - return self._security_context - - @property - def k8s_pod(self): - return self._k8s_pod - - @property - def sql(self): - return self._sql - - def to_flyte_idl(self): - """ - :rtype: flyteidl.core.tasks_pb2.TaskTemplate - """ - task_template = _core_task.TaskTemplate( - id=self.id.to_flyte_idl(), - type=self.type, - metadata=self.metadata.to_flyte_idl(), - interface=self.interface.to_flyte_idl(), - custom=_json_format.Parse(_json.dumps(self.custom), _struct.Struct()) if self.custom else None, - container=self.container.to_flyte_idl() if self.container else None, - task_type_version=self.task_type_version, - security_context=self.security_context.to_flyte_idl() if self.security_context else None, - config={k: v for k, v in self.config.items()} if self.config is not None else None, - k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None, - sql=self.sql.to_flyte_idl() if self.sql else None, - ) - return task_template - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.core.tasks_pb2.TaskTemplate pb2_object: - :rtype: TaskTemplate - """ - return cls( - id=_identifier.Identifier.from_flyte_idl(pb2_object.id), - type=pb2_object.type, - metadata=TaskMetadata.from_flyte_idl(pb2_object.metadata), - interface=_interface.TypedInterface.from_flyte_idl(pb2_object.interface), - custom=_json_format.MessageToDict(pb2_object.custom) if pb2_object else None, - container=Container.from_flyte_idl(pb2_object.container) if pb2_object.HasField("container") else None, - task_type_version=pb2_object.task_type_version, - security_context=_sec.SecurityContext.from_flyte_idl(pb2_object.security_context) - if pb2_object.security_context - else None, - config={k: v for k, v in pb2_object.config.items()} if pb2_object.config is not None else None, - k8s_pod=K8sPod.from_flyte_idl(pb2_object.k8s_pod) if pb2_object.HasField("k8s_pod") else None, - sql=Sql.from_flyte_idl(pb2_object.sql) if pb2_object.HasField("sql") else None, - ) - - -class TaskSpec(_common.FlyteIdlEntity): - def __init__(self, template): - """ - :param TaskTemplate template: - """ - self._template = template - - @property - def template(self): - """ - :rtype: TaskTemplate - """ - return self._template - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.tasks_pb2.TaskSpec - """ - return _admin_task.TaskSpec(template=self.template.to_flyte_idl()) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.tasks_pb2.TaskSpec pb2_object: - :rtype: TaskSpec - """ - return cls(TaskTemplate.from_flyte_idl(pb2_object.template)) - - -class Task(_common.FlyteIdlEntity): - def __init__(self, id, closure): - """ - :param flytekit.models.core.identifier.Identifier id: The (project, domain, name) identifier for this task. - :param TaskClosure closure: The closure for the underlying workload. - """ - self._id = id - self._closure = closure - - @property - def id(self): - """ - The (project, domain, name, version) identifier for this task. - :rtype: flytekit.models.core.identifier.Identifier - """ - return self._id - - @property - def closure(self): - """ - The closure for the underlying workload. - :rtype: TaskClosure - """ - return self._closure - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.task_pb2.Task - """ - return _admin_task.Task( - closure=self.closure.to_flyte_idl(), - id=self.id.to_flyte_idl(), - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.task_pb2.Task pb2_object: - :rtype: TaskDefinition - """ - return cls( - closure=TaskClosure.from_flyte_idl(pb2_object.closure), - id=_identifier.Identifier.from_flyte_idl(pb2_object.id), - ) - - -class TaskClosure(_common.FlyteIdlEntity): - def __init__(self, compiled_task): - """ - :param CompiledTask compiled_task: - """ - self._compiled_task = compiled_task - - @property - def compiled_task(self): - """ - :rtype: CompiledTask - """ - return self._compiled_task - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.task_pb2.TaskClosure - """ - return _admin_task.TaskClosure(compiled_task=self.compiled_task.to_flyte_idl()) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.admin.task_pb2.TaskClosure pb2_object: - :rtype: TaskClosure - """ - return cls(compiled_task=CompiledTask.from_flyte_idl(pb2_object.compiled_task)) - - -class CompiledTask(_common.FlyteIdlEntity): - def __init__(self, template): - """ - :param TaskTemplate template: - """ - self._template = template - - @property - def template(self): - """ - :rtype: TaskTemplate - """ - return self._template - - def to_flyte_idl(self): - """ - :rtype: flyteidl.core.compiler_pb2.CompiledTask - """ - return _compiler.CompiledTask(template=self.template.to_flyte_idl()) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.core.compiler_pb2.CompiledTask pb2_object: - :rtype: CompiledTask - """ - return cls(template=TaskTemplate.from_flyte_idl(pb2_object.template)) - - -class SparkJob(_common.FlyteIdlEntity): - def __init__( - self, - spark_type, - application_file, - main_class, - spark_conf, - hadoop_conf, - executor_path, - ): - """ - This defines a SparkJob target. It will execute the appropriate SparkJob. - - :param application_file: The main application file to execute. - :param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job. - :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job. - """ - self._application_file = application_file - self._spark_type = spark_type - self._main_class = main_class - self._executor_path = executor_path - self._spark_conf = spark_conf - self._hadoop_conf = hadoop_conf - - def with_overrides( - self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None - ) -> "SparkJob": - if not new_spark_conf: - new_spark_conf = self.spark_conf - - if not new_hadoop_conf: - new_hadoop_conf = self.hadoop_conf - - return SparkJob( - spark_type=self.spark_type, - application_file=self.application_file, - main_class=self.main_class, - spark_conf=new_spark_conf, - hadoop_conf=new_hadoop_conf, - executor_path=self.executor_path, - ) - - @property - def main_class(self): - """ - The main class to execute - :rtype: Text - """ - return self._main_class - - @property - def spark_type(self): - """ - Spark Job Type - :rtype: Text - """ - return self._spark_type - - @property - def application_file(self): - """ - The main application file to execute - :rtype: Text - """ - return self._application_file - - @property - def executor_path(self): - """ - The python executable to use - :rtype: Text - """ - return self._executor_path - - @property - def spark_conf(self): - """ - A definition of key-value pairs for spark config for the job. - :rtype: dict[Text, Text] - """ - return self._spark_conf - - @property - def hadoop_conf(self): - """ - A definition of key-value pairs for hadoop config for the job. - :rtype: dict[Text, Text] - """ - return self._hadoop_conf - - def to_flyte_idl(self): - """ - :rtype: flyteidl.plugins.spark_pb2.SparkJob - """ - - if self.spark_type == _spark_type.PYTHON: - application_type = _spark_task.SparkApplication.PYTHON - elif self.spark_type == _spark_type.JAVA: - application_type = _spark_task.SparkApplication.JAVA - elif self.spark_type == _spark_type.SCALA: - application_type = _spark_task.SparkApplication.SCALA - elif self.spark_type == _spark_type.R: - application_type = _spark_task.SparkApplication.R - else: - raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified") - - return _spark_task.SparkJob( - applicationType=application_type, - mainApplicationFile=self.application_file, - mainClass=self.main_class, - executorPath=self.executor_path, - sparkConf=self.spark_conf, - hadoopConf=self.hadoop_conf, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.plugins.spark_pb2.SparkJob pb2_object: - :rtype: SparkJob - """ - - application_type = _spark_type.PYTHON - if pb2_object.type == _spark_task.SparkApplication.JAVA: - application_type = _spark_type.JAVA - elif pb2_object.type == _spark_task.SparkApplication.SCALA: - application_type = _spark_type.SCALA - elif pb2_object.type == _spark_task.SparkApplication.R: - application_type = _spark_type.R - - return cls( - type=application_type, - spark_conf=pb2_object.sparkConf, - application_file=pb2_object.mainApplicationFile, - main_class=pb2_object.mainClass, - hadoop_conf=pb2_object.hadoopConf, - executor_path=pb2_object.executorPath, - ) - - class IOStrategy(_common.FlyteIdlEntity): """ Provides methods to manage data in and out of the Raw container using Download Modes. This can only be used if DataLoadingConfig is enabled. @@ -833,7 +199,7 @@ def __init__(self, image, command, args, resources, env, config, architecture=No :param Text image: The fully-qualified identifier for the image. :param list[Text] command: A list of 'words' for the command. i.e. ['aws', 's3', 'ls'] :param list[Text] args: A list of arguments for the command. i.e. ['s3://some/path', '/tmp/local/path'] - :param Resources resources: A definition of requisite compute resources. + :param flytekit.models.core.task.Resources resources: A definition of requisite compute resources. :param dict[Text, Text] env: A definition of key-value pairs for environment variables. :param dict[Text, Text] config: A definition of configuration key-value pairs. : param Architecture: Architecture supported by this container's image. @@ -884,7 +250,7 @@ def args(self): def resources(self): """ A definition of requisite compute resources. - :rtype: Resources + :rtype: flytekit.models.core.task.Resources """ return self._resources @@ -909,7 +275,7 @@ def config(self): @property def data_loading_config(self): """ - :rtype: DataLoadingConfig + :rtype: flytekit.models.core.task.DataLoadingConfig """ return self._data_loading_config @@ -931,7 +297,7 @@ def to_flyte_idl(self): @classmethod def from_flyte_idl(cls, pb2_object): """ - :param flyteidl.admin.task_pb2.Task pb2_object: + :param flyteidl.core.tasks_pb2.Container pb2_object: :rtype: Container """ return cls( @@ -1027,136 +393,438 @@ def statement(self) -> str: return self._statement @property - def dialect(self) -> int: - return self._dialect + def dialect(self) -> int: + return self._dialect + + def to_flyte_idl(self) -> _core_task.Sql: + return _core_task.Sql(statement=self.statement, dialect=self.dialect) + + @classmethod + def from_flyte_idl(cls, pb2_object: _core_task.Sql): + return cls( + statement=pb2_object.statement, + dialect=pb2_object.dialect, + ) + + +class SidecarJob(_common.FlyteIdlEntity): + def __init__(self, pod_spec, primary_container_name, annotations=None, labels=None): + """ + A sidecar job represents the full kubernetes pod spec and related metadata required for executing a sidecar + task. + + :param pod_spec: k8s.io.api.core.v1.PodSpec + :param primary_container_name: Text + :param dict[Text, Text] annotations: + :param dict[Text, Text] labels: + """ + self._pod_spec = pod_spec + self._primary_container_name = primary_container_name + self._annotations = annotations + self._labels = labels + + @property + def pod_spec(self): + """ + :rtype: k8s.io.api.core.v1.PodSpec + """ + return self._pod_spec + + @property + def primary_container_name(self): + """ + :rtype: Text + """ + return self._primary_container_name + + @property + def annotations(self): + """ + :rtype: dict[Text,Text] + """ + return self._annotations + + @property + def labels(self): + """ + :rtype: dict[Text,Text] + """ + return self._labels + + def to_flyte_idl(self): + """ + :rtype: flyteidl.core.tasks_pb2.SidecarJob + """ + return _lazy_flyteidl.plugins.sidecar_pb2.SidecarJob( + pod_spec=self.pod_spec, + primary_container_name=self.primary_container_name, + annotations=self.annotations, + labels=self.labels, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.admin.task_pb2.Task pb2_object: + :rtype: flytekit.models.core.task.Container + """ + return cls( + pod_spec=pb2_object.pod_spec, + primary_container_name=pb2_object.primary_container_name, + annotations=pb2_object.annotations, + labels=pb2_object.labels, + ) + + +class TaskMetadata(_common.FlyteIdlEntity): + def __init__( + self, + discoverable, + runtime, + timeout, + retries, + interruptible, + discovery_version, + deprecated_error_message, + ): + """ + Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, + and retries. + + :param bool discoverable: Whether or not the outputs of this task should be cached for discovery. + :param flytekit.models.core.task.RuntimeMetadata runtime: Metadata describing the runtime environment for this task. + :param datetime.timedelta timeout: The amount of time to wait before timing out. This includes queuing and + scheduler latency. + :param bool interruptible: Whether or not the task is interruptible. + :param flytekit.models.literals.RetryStrategy retries: Retry strategy for this task. 0 retries means only + try once. + :param Text discovery_version: This is the version used to create a logical version for data in the cache. + This is only used when `discoverable` is true. Data is considered discoverable if: the inputs to a given + task are the same and the discovery_version is also the same. + :param Text deprecated: This string can be used to mark the task as deprecated. Consumers of the task will + receive deprecation warnings. + """ + self._discoverable = discoverable + self._runtime = runtime + self._timeout = timeout + self._interruptible = interruptible + self._retries = retries + self._discovery_version = discovery_version + self._deprecated_error_message = deprecated_error_message + + @property + def discoverable(self): + """ + Whether or not the outputs of this task should be cached for discovery. + :rtype: bool + """ + return self._discoverable + + @property + def runtime(self): + """ + Metadata describing the runtime environment for this task. + :rtype: flytekit.models.core.task.RuntimeMetadata + """ + return self._runtime + + @property + def retries(self): + """ + Retry strategy for this task. 0 retries means only try once. + :rtype: flytekit.models.literals.RetryStrategy + """ + return self._retries + + @property + def timeout(self): + """ + The amount of time to wait before timing out. This includes queuing and scheduler latency. + :rtype: datetime.timedelta + """ + return self._timeout + + @property + def interruptible(self): + """ + Whether or not the task is interruptible. + :rtype: bool + """ + return self._interruptible + + @property + def discovery_version(self): + """ + This is the version used to create a logical version for data in the cache. + This is only used when `discoverable` is true. Data is considered discoverable if: the inputs to a given + task are the same and the discovery_version is also the same. + :rtype: Text + """ + return self._discovery_version + + @property + def deprecated_error_message(self): + """ + This string can be used to mark the task as deprecated. Consumers of the task will receive deprecation + warnings. + :rtype: Text + """ + return self._deprecated_error_message - def to_flyte_idl(self) -> _core_task.Sql: - return _core_task.Sql(statement=self.statement, dialect=self.dialect) + def to_flyte_idl(self): + """ + :rtype: flyteidl.core.task_pb2.TaskMetadata + """ + tm = _core_task.TaskMetadata( + discoverable=self.discoverable, + runtime=self.runtime.to_flyte_idl(), + retries=self.retries.to_flyte_idl(), + interruptible=self.interruptible, + discovery_version=self.discovery_version, + deprecated_error_message=self.deprecated_error_message, + ) + if self.timeout: + tm.timeout.FromTimedelta(self.timeout) + return tm @classmethod - def from_flyte_idl(cls, pb2_object: _core_task.Sql): + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.core.task_pb2.TaskMetadata pb2_object: + :rtype: TaskMetadata + """ return cls( - statement=pb2_object.statement, - dialect=pb2_object.dialect, + discoverable=pb2_object.discoverable, + runtime=RuntimeMetadata.from_flyte_idl(pb2_object.runtime), + timeout=pb2_object.timeout.ToTimedelta(), + interruptible=pb2_object.interruptible if pb2_object.HasField("interruptible") else None, + retries=_literals.RetryStrategy.from_flyte_idl(pb2_object.retries), + discovery_version=pb2_object.discovery_version, + deprecated_error_message=pb2_object.deprecated_error_message, ) -class SidecarJob(_common.FlyteIdlEntity): - def __init__(self, pod_spec, primary_container_name, annotations=None, labels=None): +class TaskTemplate(_common.FlyteIdlEntity): + def __init__( + self, + id, + type, + metadata, + interface, + custom, + container=None, + task_type_version=0, + security_context=None, + config=None, + k8s_pod=None, + sql=None, + ): """ - A sidecar job represents the full kubernetes pod spec and related metadata required for executing a sidecar - task. + A task template represents the full set of information necessary to perform a unit of work in the Flyte system. + It contains the metadata about what inputs and outputs are consumed or produced. It also contains the metadata + necessary for Flyte Propeller to do the appropriate work. - :param pod_spec: k8s.io.api.core.v1.PodSpec - :param primary_container_name: Text - :param dict[Text, Text] annotations: - :param dict[Text, Text] labels: + :param flytekit.models.core.identifier.Identifier id: This is generated by the system and uniquely identifies + the task. + :param Text type: This is used to define additional extensions for use by Propeller or SDK. + :param flytekit.models.core.task.TaskMetadata metadata: This contains information needed at runtime to determine behavior such as + whether or not outputs are discoverable, timeouts, and retries. + :param flytekit.models.interface.TypedInterface interface: The interface definition for this task. + :param dict[Text, T] custom: Dictionary that must be serializable to a protobuf Struct for custom task plugins. + :param Container container: Provides the necessary entrypoint information for execution. For instance, + a Container might be specified with the necessary command line arguments. + :param int task_type_version: Specific version of this task type used by plugins to potentially modify + execution behavior or serialization. + :param dict[str, str] config: For plugin tasks this represents additional configuration information to be used + in tandem with the custom. + :param dict[str, str] config: For plugin tasks this represents additional configuration information to be used + in tandem with the custom. + :param K8sPod k8s_pod: Alternative to the container used to execute this task. + :param Sql sql: This is used to execute query in FlytePropeller instead of running container or k8s_pod. """ - self._pod_spec = pod_spec - self._primary_container_name = primary_container_name - self._annotations = annotations - self._labels = labels + if ( + (container is not None and k8s_pod is not None) + or (container is not None and sql is not None) + or (k8s_pod is not None and sql is not None) + ): + raise ValueError("At most one of container, k8s_pod or sql can be set") + self._id = id + self._type = type + self._metadata = metadata + self._interface = interface + self._custom = custom + self._container = container + self._task_type_version = task_type_version + self._config = config + self._security_context = security_context + self._k8s_pod = k8s_pod + self._sql = sql @property - def pod_spec(self): + def id(self): """ - :rtype: k8s.io.api.core.v1.PodSpec + This is generated by the system and uniquely identifies the task. + :rtype: flytekit.models.core.identifier.Identifier """ - return self._pod_spec + return self._id @property - def primary_container_name(self): + def type(self): """ + This is used to identify additional extensions for use by Propeller or SDK. :rtype: Text """ - return self._primary_container_name + return self._type @property - def annotations(self): + def metadata(self): """ - :rtype: dict[Text,Text] + This contains information needed at runtime to determine behavior such as whether or not outputs are + discoverable, timeouts, and retries. + :rtype: flytekit.models.core.task.TaskMetadata """ - return self._annotations + return self._metadata @property - def labels(self): + def interface(self): """ - :rtype: dict[Text,Text] + The interface definition for this task. + :rtype: flytekit.models.interface.TypedInterface """ - return self._labels + return self._interface - def to_flyte_idl(self): + @property + def custom(self): """ - :rtype: flyteidl.core.tasks_pb2.SidecarJob + Arbitrary dictionary containing metadata for custom plugins. + :rtype: dict[Text, T] """ - return _lazy_flyteidl.plugins.sidecar_pb2.SidecarJob( - pod_spec=self.pod_spec, - primary_container_name=self.primary_container_name, - annotations=self.annotations, - labels=self.labels, - ) + return self._custom - @classmethod - def from_flyte_idl(cls, pb2_object): + @property + def task_type_version(self): + return self._task_type_version + + @property + def container(self): """ - :param flyteidl.admin.task_pb2.Task pb2_object: + If not None, the target of execution should be a container. :rtype: Container """ - return cls( - pod_spec=pb2_object.pod_spec, - primary_container_name=pb2_object.primary_container_name, - annotations=pb2_object.annotations, - labels=pb2_object.labels, - ) + return self._container + + @property + def config(self): + """ + Arbitrary dictionary containing metadata for parsing and handling custom plugins. + :rtype: dict[Text, T] + """ + return self._config + @property + def security_context(self): + return self._security_context -class PyTorchJob(_common.FlyteIdlEntity): - def __init__(self, workers_count): - self._workers_count = workers_count + @property + def k8s_pod(self): + return self._k8s_pod @property - def workers_count(self): - return self._workers_count + def sql(self): + return self._sql def to_flyte_idl(self): - return _pytorch_task.DistributedPyTorchTrainingTask( - workers=self.workers_count, + """ + :rtype: flyteidl.core.tasks_pb2.TaskTemplate + """ + task_template = _core_task.TaskTemplate( + id=self.id.to_flyte_idl(), + type=self.type, + metadata=self.metadata.to_flyte_idl(), + interface=self.interface.to_flyte_idl(), + custom=_json_format.Parse(_json.dumps(self.custom), _struct.Struct()) if self.custom else None, + container=self.container.to_flyte_idl() if self.container else None, + task_type_version=self.task_type_version, + security_context=self.security_context.to_flyte_idl() if self.security_context else None, + config={k: v for k, v in self.config.items()} if self.config is not None else None, + k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None, + sql=self.sql.to_flyte_idl() if self.sql else None, ) + return task_template @classmethod def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.core.tasks_pb2.TaskTemplate pb2_object: + :rtype: TaskTemplate + """ return cls( - workers_count=pb2_object.workers, + id=_identifier.Identifier.from_flyte_idl(pb2_object.id), + type=pb2_object.type, + metadata=TaskMetadata.from_flyte_idl(pb2_object.metadata), + interface=_interface.TypedInterface.from_flyte_idl(pb2_object.interface), + custom=_json_format.MessageToDict(pb2_object.custom) if pb2_object else None, + container=Container.from_flyte_idl(pb2_object.container) if pb2_object.HasField("container") else None, + task_type_version=pb2_object.task_type_version, + security_context=_sec.SecurityContext.from_flyte_idl(pb2_object.security_context) + if pb2_object.security_context + else None, + config={k: v for k, v in pb2_object.config.items()} if pb2_object.config is not None else None, + k8s_pod=K8sPod.from_flyte_idl(pb2_object.k8s_pod) if pb2_object.HasField("k8s_pod") else None, + sql=Sql.from_flyte_idl(pb2_object.sql) if pb2_object.HasField("sql") else None, ) -class TensorFlowJob(_common.FlyteIdlEntity): - def __init__(self, workers_count, ps_replicas_count, chief_replicas_count): - self._workers_count = workers_count - self._ps_replicas_count = ps_replicas_count - self._chief_replicas_count = chief_replicas_count +class RuntimeMetadata(_common.FlyteIdlEntity): + class RuntimeType(object): + OTHER = 0 + FLYTE_SDK = 1 + + def __init__(self, type, version, flavor): + """ + :param int type: Enum type from RuntimeMetadata.RuntimeType + :param Text version: Version string for SDK version. Can be used for metrics or managing breaking changes in + Admin or Propeller + :param Text flavor: Optional extra information about runtime environment (e.g. Python, GoLang, etc.) + """ + self._type = type + self._version = version + self._flavor = flavor @property - def workers_count(self): - return self._workers_count + def type(self): + """ + Enum type from RuntimeMetadata.RuntimeType + :rtype: int + """ + return self._type @property - def ps_replicas_count(self): - return self._ps_replicas_count + def version(self): + """ + Version string for SDK version. Can be used for metrics or managing breaking changes in Admin or Propeller + :rtype: Text + """ + return self._version @property - def chief_replicas_count(self): - return self._chief_replicas_count + def flavor(self): + """ + Optional extra information about runtime environment (e.g. Python, GoLang, etc.) + :rtype: Text + """ + return self._flavor def to_flyte_idl(self): - return _tensorflow_task.DistributedTensorflowTrainingTask( - workers=self.workers_count, ps_replicas=self.ps_replicas_count, chief_replicas=self.chief_replicas_count - ) + """ + :rtype: flyteidl.core.tasks_pb2.RuntimeMetadata + """ + return _core_task.RuntimeMetadata(type=self.type, version=self.version, flavor=self.flavor) @classmethod def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ps_replicas_count=pb2_object.ps_replicas, - chief_replicas_count=pb2_object.chief_replicas, - ) + """ + :param flyteidl.core.tasks_pb2.RuntimeMetadata pb2_object: + :rtype: RuntimeMetadata + """ + return cls(type=pb2_object.type, version=pb2_object.version, flavor=pb2_object.flavor) diff --git a/flytekit/models/core/types.py b/flytekit/models/core/types.py index 3ec44baddf..879293f1d1 100644 --- a/flytekit/models/core/types.py +++ b/flytekit/models/core/types.py @@ -1,10 +1,26 @@ +import json as _json import typing from flyteidl.core import types_pb2 as _types_pb2 +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct from flytekit.models import common as _common +class SimpleType(object): + NONE = _types_pb2.NONE + INTEGER = _types_pb2.INTEGER + FLOAT = _types_pb2.FLOAT + STRING = _types_pb2.STRING + BOOLEAN = _types_pb2.BOOLEAN + DATETIME = _types_pb2.DATETIME + DURATION = _types_pb2.DURATION + BINARY = _types_pb2.BINARY + ERROR = _types_pb2.ERROR + STRUCT = _types_pb2.STRUCT + + class EnumType(_common.FlyteIdlEntity): """ Models _types_pb2.EnumType @@ -69,3 +85,256 @@ def from_flyte_idl(cls, proto): :rtype: BlobType """ return cls(format=proto.format, dimensionality=proto.dimensionality) + + +class SchemaType(_common.FlyteIdlEntity): + class SchemaColumn(_common.FlyteIdlEntity): + class SchemaColumnType(object): + INTEGER = _types_pb2.SchemaType.SchemaColumn.INTEGER + FLOAT = _types_pb2.SchemaType.SchemaColumn.FLOAT + STRING = _types_pb2.SchemaType.SchemaColumn.STRING + DATETIME = _types_pb2.SchemaType.SchemaColumn.DATETIME + DURATION = _types_pb2.SchemaType.SchemaColumn.DURATION + BOOLEAN = _types_pb2.SchemaType.SchemaColumn.BOOLEAN + + def __init__(self, name, type): + """ + :param Text name: Name for the column + :param int type: Enum type from SchemaType.SchemaColumn.SchemaColumnType representing the type of the column + """ + self._name = name + self._type = type + + @property + def name(self): + """ + Name for the column + :rtype: Text + """ + return self._name + + @property + def type(self): + """ + Enum type from SchemaType.SchemaColumn.SchemaColumnType representing the type of the column + :rtype: int + """ + return self._type + + def to_flyte_idl(self): + """ + :rtype: flyteidl.core.types_pb2.SchemaType.SchemaColumn + """ + return _types_pb2.SchemaType.SchemaColumn(name=self.name, type=self.type) + + @classmethod + def from_flyte_idl(cls, proto): + """ + :param flyteidl.core.types_pb2.SchemaType.SchemaColumn proto: + :rtype: SchemaType.SchemaColumn + """ + return cls(name=proto.name, type=proto.type) + + def __init__(self, columns): + """ + :param list[SchemaType.SchemaColumn] columns: A list of columns defining the underlying data frame. + """ + self._columns = columns + + @property + def columns(self): + """ + A list of columns defining the underlying data frame. + :rtype: list[SchemaType.SchemaColumn] + """ + return self._columns + + def to_flyte_idl(self): + """ + :rtype: flyteidl.core.types_pb2.SchemaType + """ + return _types_pb2.SchemaType(columns=[c.to_flyte_idl() for c in self.columns]) + + @classmethod + def from_flyte_idl(cls, proto): + """ + :param flyteidl.core.types_pb2.SchemaType proto: + :rtype: SchemaType + """ + return cls(columns=[SchemaType.SchemaColumn.from_flyte_idl(c) for c in proto.columns]) + + +class LiteralType(_common.FlyteIdlEntity): + def __init__( + self, + simple=None, + schema=None, + collection_type=None, + map_value_type=None, + blob=None, + enum_type=None, + metadata=None, + ): + """ + Only one of the kwargs may be set. + :param int simple: Enum type from SimpleType + :param flytekit.models.core.types.SchemaType schema: Type definition for a dataframe-like object. + :param LiteralType collection_type: For list-like objects, this is the type of each entry in the list. + :param LiteralType map_value_type: For map objects, this is the type of the value. The key must always be a + string. + :param flytekit.models.core.types.BlobType blob: For blob objects, this describes the type. + :param flytekit.models.core.types.EnumType enum_type: For enum objects, describes an enum + :param dict[Text, T] metadata: Additional data describing the type + """ + self._simple = simple + self._schema = schema + self._collection_type = collection_type + self._map_value_type = map_value_type + self._blob = blob + self._enum_type = enum_type + self._metadata = metadata + + @property + def simple(self) -> SimpleType: + return self._simple + + @property + def schema(self) -> SchemaType: + return self._schema + + @property + def collection_type(self) -> "LiteralType": + """ + The collection value type + """ + return self._collection_type + + @property + def map_value_type(self) -> "LiteralType": + """ + The Value for a dictionary. Key is always string + """ + return self._map_value_type + + @property + def blob(self) -> BlobType: + return self._blob + + @property + def enum_type(self) -> EnumType: + return self._enum_type + + @property + def metadata(self): + """ + :rtype: dict[Text, T] + """ + return self._metadata + + def to_flyte_idl(self): + """ + :rtype: flyteidl.core.types_pb2.LiteralType + """ + if self.metadata is not None: + metadata = _json_format.Parse(_json.dumps(self.metadata), _struct.Struct()) + else: + metadata = None + t = _types_pb2.LiteralType( + simple=self.simple if self.simple is not None else None, + schema=self.schema.to_flyte_idl() if self.schema is not None else None, + collection_type=self.collection_type.to_flyte_idl() if self.collection_type is not None else None, + map_value_type=self.map_value_type.to_flyte_idl() if self.map_value_type is not None else None, + blob=self.blob.to_flyte_idl() if self.blob is not None else None, + enum_type=self.enum_type.to_flyte_idl() if self.enum_type else None, + metadata=metadata, + ) + return t + + @classmethod + def from_flyte_idl(cls, proto): + """ + :param flyteidl.core.types_pb2.LiteralType proto: + :rtype: LiteralType + """ + collection_type = None + map_value_type = None + if proto.HasField("collection_type"): + collection_type = LiteralType.from_flyte_idl(proto.collection_type) + if proto.HasField("map_value_type"): + map_value_type = LiteralType.from_flyte_idl(proto.map_value_type) + return cls( + simple=proto.simple if proto.HasField("simple") else None, + schema=SchemaType.from_flyte_idl(proto.schema) if proto.HasField("schema") else None, + collection_type=collection_type, + map_value_type=map_value_type, + blob=BlobType.from_flyte_idl(proto.blob) if proto.HasField("blob") else None, + enum_type=EnumType.from_flyte_idl(proto.enum_type) if proto.HasField("enum_type") else None, + metadata=_json_format.MessageToDict(proto.metadata) or None, + ) + + +class OutputReference(_common.FlyteIdlEntity): + def __init__(self, node_id, var): + """ + A reference to an output produced by a node. The type can be retrieved -and validated- from + the underlying interface of the node. + + :param Text node_id: Node id must exist at the graph layer. + :param Text var: Variable name must refer to an output variable for the node. + """ + self._node_id = node_id + self._var = var + + @property + def node_id(self): + """ + Node id must exist at the graph layer. + :rtype: Text + """ + return self._node_id + + @property + def var(self): + """ + Variable name must refer to an output variable for the node. + :rtype: Text + """ + return self._var + + @var.setter + def var(self, var_name): + self._var = var_name + + def to_flyte_idl(self): + """ + :rtype: flyteidl.core.types.OutputReference + """ + return _types_pb2.OutputReference(node_id=self.node_id, var=self.var) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.core.types.OutputReference pb2_object: + :rtype: OutputReference + """ + return cls(node_id=pb2_object.node_id, var=pb2_object.var) + + +class Error(_common.FlyteIdlEntity): + def __init__(self, failed_node_id: str, message: str): + self._message = message + self._failed_node_id = failed_node_id + + def to_flyte_idl(self) -> _types_pb2.Error: + return _types_pb2.Error( + message=self._message, + failed_node_id=self._failed_node_id, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: _types_pb2.Error) -> "Error": + """ + :param flyteidl.core.types.OutputReference pb2_object: + :rtype: flytekit.models.core.types.OutputReference + """ + return cls(failed_node_id=pb2_object.failed_node_id, message=pb2_object.message) diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 6338a6695f..3a31562414 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -3,14 +3,14 @@ from flyteidl.core import workflow_pb2 as _core_workflow +import flytekit.models.core.types from flytekit.models import common as _common -from flytekit.models import interface as _interface -from flytekit.models import types as _types from flytekit.models.core import condition as _condition from flytekit.models.core import identifier as _identifier -from flytekit.models.literals import Binding as _Binding -from flytekit.models.literals import RetryStrategy as _RetryStrategy -from flytekit.models.task import Resources +from flytekit.models.core import interface as _interface +from flytekit.models.core.literals import Binding as _Binding +from flytekit.models.core.literals import RetryStrategy as _RetryStrategy +from flytekit.models.core.task import Resources class IfBlock(_common.FlyteIdlEntity): @@ -61,7 +61,7 @@ def __init__(self, case, other=None, else_node=None, error=None): :param IfBlock case: :param list[IfBlock] other: :param Node else_node: - :param _types.Error error: + :param flytekit.models.core.types.Error error: """ self._case = case self._other = other @@ -121,7 +121,9 @@ def from_flyte_idl(cls, pb2_object): case=IfBlock.from_flyte_idl(pb2_object.case), other=[IfBlock.from_flyte_idl(a) for a in pb2_object.other], else_node=Node.from_flyte_idl(pb2_object.else_node) if pb2_object.HasField("else_node") else None, - error=_types.Error.from_flyte_idl(pb2_object.error) if pb2_object.HasField("error") else None, + error=flytekit.models.core.types.Error.from_flyte_idl(pb2_object.error) + if pb2_object.HasField("error") + else None, ) diff --git a/flytekit/models/workflow_closure.py b/flytekit/models/core/workflow_closure.py similarity index 82% rename from flytekit/models/workflow_closure.py rename to flytekit/models/core/workflow_closure.py index fbf0b08688..c21d36bb22 100644 --- a/flytekit/models/workflow_closure.py +++ b/flytekit/models/core/workflow_closure.py @@ -1,7 +1,7 @@ from flyteidl.core import workflow_closure_pb2 as _workflow_closure_pb2 +import flytekit.models.core.task from flytekit.models import common as _common -from flytekit.models import task as _task_models from flytekit.models.core import workflow as _core_workflow_models @@ -9,7 +9,7 @@ class WorkflowClosure(_common.FlyteIdlEntity): def __init__(self, workflow, tasks=None): """ :param flytekit.models.core.workflow.WorkflowTemplate workflow: Workflow template - :param list[flytekit.models.task.TaskTemplate] tasks: [Optional] + :param list[flytekit.models.core.task.TaskTemplate] tasks: [Optional] """ self._workflow = workflow self._tasks = tasks @@ -24,7 +24,7 @@ def workflow(self): @property def tasks(self): """ - :rtype: list[flytekit.models.task.TaskTemplate] + :rtype: list[flytekit.models.core.task.TaskTemplate] """ return self._tasks @@ -45,5 +45,5 @@ def from_flyte_idl(cls, pb2_object): """ return cls( workflow=_core_workflow_models.WorkflowTemplate.from_flyte_idl(pb2_object.workflow), - tasks=[_task_models.TaskTemplate.from_flyte_idl(t) for t in pb2_object.tasks], + tasks=[flytekit.models.core.task.TaskTemplate.from_flyte_idl(t) for t in pb2_object.tasks], ) diff --git a/flytekit/models/named_entity.py b/flytekit/models/named_entity.py deleted file mode 100644 index 63dd598d98..0000000000 --- a/flytekit/models/named_entity.py +++ /dev/null @@ -1,122 +0,0 @@ -from flyteidl.admin import common_pb2 as _common - -from flytekit.models import common as _common_models - - -class NamedEntityState(object): - ACTIVE = _common.NAMED_ENTITY_ACTIVE - ARCHIVED = _common.NAMED_ENTITY_ARCHIVED - - @classmethod - def enum_to_string(cls, val): - """ - :param int val: - :rtype: Text - """ - if val == cls.ACTIVE: - return "ACTIVE" - elif val == cls.ARCHIVED: - return "ARCHIVED" - else: - return "" - - -class NamedEntityIdentifier(_common_models.FlyteIdlEntity): - def __init__(self, project, domain, name): - """ - :param Text project: - :param Text domain: - :param Text name: - """ - self._project = project - self._domain = domain - self._name = name - - @property - def project(self): - """ - :rtype: Text - """ - return self._project - - @property - def domain(self): - """ - :rtype: Text - """ - return self._domain - - @property - def name(self): - """ - :rtype: Text - """ - return self._name - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifier - """ - return _common.NamedEntityIdentifier( - project=self.project, - domain=self.domain, - name=self.name, - ) - - @classmethod - def from_flyte_idl(cls, p): - """ - :param flyteidl.core.common_pb2.NamedEntityIdentifier p: - :rtype: Identifier - """ - return cls( - project=p.project, - domain=p.domain, - name=p.name, - ) - - -class NamedEntityMetadata(_common_models.FlyteIdlEntity): - def __init__(self, description, state): - """ - - :param Text description: - :param int state: enum value from NamedEntityState - """ - self._description = description - self._state = state - - @property - def description(self): - """ - :rtype: Text - """ - return self._description - - @property - def state(self): - """ - enum value from NamedEntityState - :rtype: int - """ - return self._state - - def to_flyte_idl(self): - """ - :rtype: flyteidl.admin.common_pb2.NamedEntityMetadata - """ - return _common.NamedEntityMetadata( - description=self.description, - state=self.state, - ) - - @classmethod - def from_flyte_idl(cls, p): - """ - :param flyteidl.core.common_pb2.NamedEntityMetadata p: - :rtype: Identifier - """ - return cls( - description=p.description, - state=p.state, - ) diff --git a/flytekit/models/plugins/__init__.py b/flytekit/models/plugins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/models/array_job.py b/flytekit/models/plugins/array_job.py similarity index 100% rename from flytekit/models/array_job.py rename to flytekit/models/plugins/array_job.py diff --git a/flytekit/models/presto.py b/flytekit/models/plugins/presto.py similarity index 100% rename from flytekit/models/presto.py rename to flytekit/models/plugins/presto.py diff --git a/flytekit/models/qubole.py b/flytekit/models/plugins/qubole.py similarity index 100% rename from flytekit/models/qubole.py rename to flytekit/models/plugins/qubole.py diff --git a/flytekit/models/plugins/task.py b/flytekit/models/plugins/task.py new file mode 100644 index 0000000000..56c2eb6c61 --- /dev/null +++ b/flytekit/models/plugins/task.py @@ -0,0 +1,201 @@ +import typing + +from flyteidl.plugins import pytorch_pb2 as _pytorch_task +from flyteidl.plugins import spark_pb2 as _spark_task +from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task + +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.models import common as _common +from flytekit.sdk.spark_types import SparkType as _spark_type + + +class SparkJob(_common.FlyteIdlEntity): + def __init__( + self, + spark_type, + application_file, + main_class, + spark_conf, + hadoop_conf, + executor_path, + ): + """ + This defines a SparkJob target. It will execute the appropriate SparkJob. + + :param application_file: The main application file to execute. + :param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job. + :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job. + """ + self._application_file = application_file + self._spark_type = spark_type + self._main_class = main_class + self._executor_path = executor_path + self._spark_conf = spark_conf + self._hadoop_conf = hadoop_conf + + def with_overrides( + self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None + ) -> "SparkJob": + if not new_spark_conf: + new_spark_conf = self.spark_conf + + if not new_hadoop_conf: + new_hadoop_conf = self.hadoop_conf + + return SparkJob( + spark_type=self.spark_type, + application_file=self.application_file, + main_class=self.main_class, + spark_conf=new_spark_conf, + hadoop_conf=new_hadoop_conf, + executor_path=self.executor_path, + ) + + @property + def main_class(self): + """ + The main class to execute + :rtype: Text + """ + return self._main_class + + @property + def spark_type(self): + """ + Spark Job Type + :rtype: Text + """ + return self._spark_type + + @property + def application_file(self): + """ + The main application file to execute + :rtype: Text + """ + return self._application_file + + @property + def executor_path(self): + """ + The python executable to use + :rtype: Text + """ + return self._executor_path + + @property + def spark_conf(self): + """ + A definition of key-value pairs for spark config for the job. + :rtype: dict[Text, Text] + """ + return self._spark_conf + + @property + def hadoop_conf(self): + """ + A definition of key-value pairs for hadoop config for the job. + :rtype: dict[Text, Text] + """ + return self._hadoop_conf + + def to_flyte_idl(self): + """ + :rtype: flyteidl.plugins.spark_pb2.SparkJob + """ + + if self.spark_type == _spark_type.PYTHON: + application_type = _spark_task.SparkApplication.PYTHON + elif self.spark_type == _spark_type.JAVA: + application_type = _spark_task.SparkApplication.JAVA + elif self.spark_type == _spark_type.SCALA: + application_type = _spark_task.SparkApplication.SCALA + elif self.spark_type == _spark_type.R: + application_type = _spark_task.SparkApplication.R + else: + raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified") + + return _spark_task.SparkJob( + applicationType=application_type, + mainApplicationFile=self.application_file, + mainClass=self.main_class, + executorPath=self.executor_path, + sparkConf=self.spark_conf, + hadoopConf=self.hadoop_conf, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.plugins.spark_pb2.SparkJob pb2_object: + :rtype: SparkJob + """ + + application_type = _spark_type.PYTHON + if pb2_object.type == _spark_task.SparkApplication.JAVA: + application_type = _spark_type.JAVA + elif pb2_object.type == _spark_task.SparkApplication.SCALA: + application_type = _spark_type.SCALA + elif pb2_object.type == _spark_task.SparkApplication.R: + application_type = _spark_type.R + + return cls( + type=application_type, + spark_conf=pb2_object.sparkConf, + application_file=pb2_object.mainApplicationFile, + main_class=pb2_object.mainClass, + hadoop_conf=pb2_object.hadoopConf, + executor_path=pb2_object.executorPath, + ) + + +class PyTorchJob(_common.FlyteIdlEntity): + def __init__(self, workers_count): + self._workers_count = workers_count + + @property + def workers_count(self): + return self._workers_count + + def to_flyte_idl(self): + return _pytorch_task.DistributedPyTorchTrainingTask( + workers=self.workers_count, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + return cls( + workers_count=pb2_object.workers, + ) + + +class TensorFlowJob(_common.FlyteIdlEntity): + def __init__(self, workers_count, ps_replicas_count, chief_replicas_count): + self._workers_count = workers_count + self._ps_replicas_count = ps_replicas_count + self._chief_replicas_count = chief_replicas_count + + @property + def workers_count(self): + return self._workers_count + + @property + def ps_replicas_count(self): + return self._ps_replicas_count + + @property + def chief_replicas_count(self): + return self._chief_replicas_count + + def to_flyte_idl(self): + return _tensorflow_task.DistributedTensorflowTrainingTask( + workers=self.workers_count, ps_replicas=self.ps_replicas_count, chief_replicas=self.chief_replicas_count + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + return cls( + workers_count=pb2_object.workers, + ps_replicas_count=pb2_object.ps_replicas, + chief_replicas_count=pb2_object.chief_replicas, + ) diff --git a/flytekit/models/types.py b/flytekit/models/types.py deleted file mode 100644 index 03b71ef44e..0000000000 --- a/flytekit/models/types.py +++ /dev/null @@ -1,274 +0,0 @@ -import json as _json - -from flyteidl.core import types_pb2 as _types_pb2 -from google.protobuf import json_format as _json_format -from google.protobuf import struct_pb2 as _struct - -from flytekit.models import common as _common -from flytekit.models.core import types as _core_types - - -class SimpleType(object): - NONE = _types_pb2.NONE - INTEGER = _types_pb2.INTEGER - FLOAT = _types_pb2.FLOAT - STRING = _types_pb2.STRING - BOOLEAN = _types_pb2.BOOLEAN - DATETIME = _types_pb2.DATETIME - DURATION = _types_pb2.DURATION - BINARY = _types_pb2.BINARY - ERROR = _types_pb2.ERROR - STRUCT = _types_pb2.STRUCT - - -class SchemaType(_common.FlyteIdlEntity): - class SchemaColumn(_common.FlyteIdlEntity): - class SchemaColumnType(object): - INTEGER = _types_pb2.SchemaType.SchemaColumn.INTEGER - FLOAT = _types_pb2.SchemaType.SchemaColumn.FLOAT - STRING = _types_pb2.SchemaType.SchemaColumn.STRING - DATETIME = _types_pb2.SchemaType.SchemaColumn.DATETIME - DURATION = _types_pb2.SchemaType.SchemaColumn.DURATION - BOOLEAN = _types_pb2.SchemaType.SchemaColumn.BOOLEAN - - def __init__(self, name, type): - """ - :param Text name: Name for the column - :param int type: Enum type from SchemaType.SchemaColumn.SchemaColumnType representing the type of the column - """ - self._name = name - self._type = type - - @property - def name(self): - """ - Name for the column - :rtype: Text - """ - return self._name - - @property - def type(self): - """ - Enum type from SchemaType.SchemaColumn.SchemaColumnType representing the type of the column - :rtype: int - """ - return self._type - - def to_flyte_idl(self): - """ - :rtype: flyteidl.core.types_pb2.SchemaType.SchemaColumn - """ - return _types_pb2.SchemaType.SchemaColumn(name=self.name, type=self.type) - - @classmethod - def from_flyte_idl(cls, proto): - """ - :param flyteidl.core.types_pb2.SchemaType.SchemaColumn proto: - :rtype: SchemaType.SchemaColumn - """ - return cls(name=proto.name, type=proto.type) - - def __init__(self, columns): - """ - :param list[SchemaType.SchemaColumn] columns: A list of columns defining the underlying data frame. - """ - self._columns = columns - - @property - def columns(self): - """ - A list of columns defining the underlying data frame. - :rtype: list[SchemaType.SchemaColumn] - """ - return self._columns - - def to_flyte_idl(self): - """ - :rtype: flyteidl.core.types_pb2.SchemaType - """ - return _types_pb2.SchemaType(columns=[c.to_flyte_idl() for c in self.columns]) - - @classmethod - def from_flyte_idl(cls, proto): - """ - :param flyteidl.core.types_pb2.SchemaType proto: - :rtype: SchemaType - """ - return cls(columns=[SchemaType.SchemaColumn.from_flyte_idl(c) for c in proto.columns]) - - -class LiteralType(_common.FlyteIdlEntity): - def __init__( - self, - simple=None, - schema=None, - collection_type=None, - map_value_type=None, - blob=None, - enum_type=None, - metadata=None, - ): - """ - Only one of the kwargs may be set. - :param int simple: Enum type from SimpleType - :param SchemaType schema: Type definition for a dataframe-like object. - :param LiteralType collection_type: For list-like objects, this is the type of each entry in the list. - :param LiteralType map_value_type: For map objects, this is the type of the value. The key must always be a - string. - :param flytekit.models.core.types.BlobType blob: For blob objects, this describes the type. - :param flytekit.models.core.types.EnumType enum_type: For enum objects, describes an enum - :param dict[Text, T] metadata: Additional data describing the type - """ - self._simple = simple - self._schema = schema - self._collection_type = collection_type - self._map_value_type = map_value_type - self._blob = blob - self._enum_type = enum_type - self._metadata = metadata - - @property - def simple(self) -> SimpleType: - return self._simple - - @property - def schema(self) -> SchemaType: - return self._schema - - @property - def collection_type(self) -> "LiteralType": - """ - The collection value type - """ - return self._collection_type - - @property - def map_value_type(self) -> "LiteralType": - """ - The Value for a dictionary. Key is always string - """ - return self._map_value_type - - @property - def blob(self) -> _core_types.BlobType: - return self._blob - - @property - def enum_type(self) -> _core_types.EnumType: - return self._enum_type - - @property - def metadata(self): - """ - :rtype: dict[Text, T] - """ - return self._metadata - - def to_flyte_idl(self): - """ - :rtype: flyteidl.core.types_pb2.LiteralType - """ - if self.metadata is not None: - metadata = _json_format.Parse(_json.dumps(self.metadata), _struct.Struct()) - else: - metadata = None - t = _types_pb2.LiteralType( - simple=self.simple if self.simple is not None else None, - schema=self.schema.to_flyte_idl() if self.schema is not None else None, - collection_type=self.collection_type.to_flyte_idl() if self.collection_type is not None else None, - map_value_type=self.map_value_type.to_flyte_idl() if self.map_value_type is not None else None, - blob=self.blob.to_flyte_idl() if self.blob is not None else None, - enum_type=self.enum_type.to_flyte_idl() if self.enum_type else None, - metadata=metadata, - ) - return t - - @classmethod - def from_flyte_idl(cls, proto): - """ - :param flyteidl.core.types_pb2.LiteralType proto: - :rtype: LiteralType - """ - collection_type = None - map_value_type = None - if proto.HasField("collection_type"): - collection_type = LiteralType.from_flyte_idl(proto.collection_type) - if proto.HasField("map_value_type"): - map_value_type = LiteralType.from_flyte_idl(proto.map_value_type) - return cls( - simple=proto.simple if proto.HasField("simple") else None, - schema=SchemaType.from_flyte_idl(proto.schema) if proto.HasField("schema") else None, - collection_type=collection_type, - map_value_type=map_value_type, - blob=_core_types.BlobType.from_flyte_idl(proto.blob) if proto.HasField("blob") else None, - enum_type=_core_types.EnumType.from_flyte_idl(proto.enum_type) if proto.HasField("enum_type") else None, - metadata=_json_format.MessageToDict(proto.metadata) or None, - ) - - -class OutputReference(_common.FlyteIdlEntity): - def __init__(self, node_id, var): - """ - A reference to an output produced by a node. The type can be retrieved -and validated- from - the underlying interface of the node. - - :param Text node_id: Node id must exist at the graph layer. - :param Text var: Variable name must refer to an output variable for the node. - """ - self._node_id = node_id - self._var = var - - @property - def node_id(self): - """ - Node id must exist at the graph layer. - :rtype: Text - """ - return self._node_id - - @property - def var(self): - """ - Variable name must refer to an output variable for the node. - :rtype: Text - """ - return self._var - - @var.setter - def var(self, var_name): - self._var = var_name - - def to_flyte_idl(self): - """ - :rtype: flyteidl.core.types.OutputReference - """ - return _types_pb2.OutputReference(node_id=self.node_id, var=self.var) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.core.types.OutputReference pb2_object: - :rtype: OutputReference - """ - return cls(node_id=pb2_object.node_id, var=pb2_object.var) - - -class Error(_common.FlyteIdlEntity): - def __init__(self, failed_node_id: str, message: str): - self._message = message - self._failed_node_id = failed_node_id - - def to_flyte_idl(self) -> _types_pb2.Error: - return _types_pb2.Error( - message=self._message, - failed_node_id=self._failed_node_id, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _types_pb2.Error) -> "Error": - """ - :param flyteidl.core.types.OutputReference pb2_object: - :rtype: OutputReference - """ - return cls(failed_node_id=pb2_object.failed_node_id, message=pb2_object.message) diff --git a/flytekit/remote/component_nodes.py b/flytekit/remote/component_nodes.py index 06b885abfd..33a2ecf7e3 100644 --- a/flytekit/remote/component_nodes.py +++ b/flytekit/remote/component_nodes.py @@ -1,10 +1,9 @@ import logging as _logging from typing import Dict -import flytekit +import flytekit.models.core.task from flytekit.common.exceptions import system as _system_exceptions -from flytekit.models import launch_plan as _launch_plan_model -from flytekit.models import task as _task_model +from flytekit.models.admin import launch_plan as _launch_plan_model from flytekit.models.core import workflow as _workflow_model from flytekit.remote import identifier as _identifier @@ -29,7 +28,7 @@ def flyte_task(self) -> "flytekit.remote.tasks.task.FlyteTask": def promote_from_model( cls, base_model: _workflow_model.TaskNode, - tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate], + tasks: Dict[_identifier.Identifier, flytekit.models.core.task.TaskTemplate], ) -> "FlyteTaskNode": """ Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the @@ -98,7 +97,7 @@ def promote_from_model( base_model: _workflow_model.WorkflowNode, sub_workflows: Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate], node_launch_plans: Dict[_identifier.Identifier, _launch_plan_model.LaunchPlanSpec], - tasks: Dict[_identifier.Identifier, _task_model.TaskTemplate], + tasks: Dict[_identifier.Identifier, flytekit.models.core.task.TaskTemplate], ) -> "FlyteWorkflowNode": from flytekit.remote import launch_plan as _launch_plan from flytekit.remote import workflow as _workflow diff --git a/flytekit/remote/interface.py b/flytekit/remote/interface.py index 6aeb0e236d..cabbf352e8 100644 --- a/flytekit/remote/interface.py +++ b/flytekit/remote/interface.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Tuple -from flytekit.models import interface as _interface_models -from flytekit.models import literals as _literal_models +from flytekit.models.core import interface as _interface_models +from flytekit.models.core import literals as _literal_models from flytekit.remote import nodes as _nodes diff --git a/flytekit/remote/launch_plan.py b/flytekit/remote/launch_plan.py index 200244f394..0c8e391ce8 100644 --- a/flytekit/remote/launch_plan.py +++ b/flytekit/remote/launch_plan.py @@ -5,9 +5,9 @@ from flytekit.core.interface import Interface from flytekit.core.type_engine import TypeEngine from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.models import interface as _interface_models -from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models.admin import launch_plan as _launch_plan_models from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import interface as _interface_models from flytekit.remote import identifier as _identifier from flytekit.remote import interface as _interface diff --git a/flytekit/remote/nodes.py b/flytekit/remote/nodes.py index 68d84f00ea..d8e12ab8f1 100644 --- a/flytekit/remote/nodes.py +++ b/flytekit/remote/nodes.py @@ -1,7 +1,7 @@ import logging as _logging from typing import Any, Dict, List, Optional, Union -import flytekit +import flytekit.models.core.task from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions from flytekit.common import constants as _constants from flytekit.common.exceptions import system as _system_exceptions @@ -11,9 +11,8 @@ from flytekit.common.utils import _dnsify from flytekit.core.promise import NodeOutput from flytekit.engines.flyte import engine as _flyte_engine -from flytekit.models import launch_plan as _launch_plan_model -from flytekit.models import node_execution as _node_execution_models -from flytekit.models import task as _task_model +from flytekit.models.admin import launch_plan as _launch_plan_model +from flytekit.models.admin import node_execution as _node_execution_models from flytekit.models.core import execution as _execution_models from flytekit.models.core import workflow as _workflow_model from flytekit.remote import component_nodes as _component_nodes @@ -72,7 +71,7 @@ def promote_from_model( model: _workflow_model.Node, sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_model.WorkflowTemplate]], node_launch_plans: Optional[Dict[_identifier.Identifier, _launch_plan_model.LaunchPlanSpec]], - tasks: Optional[Dict[_identifier.Identifier, _task_model.TaskTemplate]], + tasks: Optional[Dict[_identifier.Identifier, flytekit.models.core.task.TaskTemplate]], ) -> "FlyteNode": id = model.id if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}: diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 41a128caf9..96c504a59f 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -13,6 +13,8 @@ import grpc from flyteidl.core import literals_pb2 as literals_pb2 +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.common import utils as common_utils from flytekit.configuration import platform as platform_config @@ -22,6 +24,7 @@ from flytekit.loggers import remote_logger from flytekit.models import filters as filter_models from flytekit.models.admin import common as admin_common_models +from flytekit.models.admin import launch_plan as launch_plan_models try: from functools import singledispatchmethod @@ -41,18 +44,17 @@ from flytekit.core.launch_plan import LaunchPlan from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import WorkflowBase -from flytekit.models import common as common_models -from flytekit.models import launch_plan as launch_plan_models -from flytekit.models import literals as literal_models +from flytekit.models.admin.common import NamedEntityIdentifier as _namedEntityIdentifier from flytekit.models.admin.common import Sort -from flytekit.models.core.identifier import ResourceType -from flytekit.models.execution import ( +from flytekit.models.admin.execution import ( ExecutionMetadata, ExecutionSpec, NodeExecutionGetDataResponse, NotificationList, WorkflowExecutionGetDataResponse, ) +from flytekit.models.core import literals as literal_models +from flytekit.models.core.identifier import ResourceType from flytekit.remote.identifier import Identifier, WorkflowExecutionIdentifier from flytekit.remote.interface import TypedInterface from flytekit.remote.launch_plan import FlyteLaunchPlan @@ -76,7 +78,7 @@ class ResolvedIdentifiers: def _get_latest_version(list_entities_method: typing.Callable, project: str, domain: str, name: str): - named_entity = common_models.NamedEntityIdentifier(project, domain, name) + named_entity = _namedEntityIdentifier(project, domain, name) entity_list, _ = list_entities_method( named_entity, limit=1, @@ -153,7 +155,7 @@ def from_config( default_project=default_project or PROJECT.get() or None, default_domain=default_domain or DOMAIN.get() or None, file_access=file_access, - auth_role=common_models.AuthRole( + auth_role=flytekit.models.admin.common.AuthRole( assumable_iam_role=auth_config.ASSUMABLE_IAM_ROLE.get(), kubernetes_service_account=auth_config.KUBERNETES_SERVICE_ACCOUNT.get(), ), @@ -162,7 +164,7 @@ def from_config( annotations=None, image_config=get_image_config(), raw_output_data_config=( - common_models.RawOutputDataConfig(raw_output_data_prefix) if raw_output_data_prefix else None + admin_common_models.RawOutputDataConfig(raw_output_data_prefix) if raw_output_data_prefix else None ), grpc_credentials=grpc_credentials, ) @@ -174,12 +176,12 @@ def __init__( default_project: typing.Optional[str] = None, default_domain: typing.Optional[str] = None, file_access: typing.Optional[FileAccessProvider] = None, - auth_role: typing.Optional[common_models.AuthRole] = None, - notifications: typing.Optional[typing.List[common_models.Notification]] = None, - labels: typing.Optional[common_models.Labels] = None, - annotations: typing.Optional[common_models.Annotations] = None, + auth_role: typing.Optional[flytekit.models.admin.common.AuthRole] = None, + notifications: typing.Optional[typing.List[admin_common_models.Notification]] = None, + labels: typing.Optional[admin_common_models.Labels] = None, + annotations: typing.Optional[admin_common_models.Annotations] = None, image_config: typing.Optional[ImageConfig] = None, - raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None, + raw_output_data_config: typing.Optional[admin_common_models.RawOutputDataConfig] = None, grpc_credentials: typing.Optional[grpc.ChannelCredentials] = None, ): """Initialize a FlyteRemote object. @@ -290,12 +292,12 @@ def with_overrides( flyte_admin_url: typing.Optional[str] = None, insecure: typing.Optional[bool] = None, file_access: typing.Optional[FileAccessProvider] = None, - auth_role: typing.Optional[common_models.AuthRole] = None, - notifications: typing.Optional[typing.List[common_models.Notification]] = None, - labels: typing.Optional[common_models.Labels] = None, - annotations: typing.Optional[common_models.Annotations] = None, + auth_role: typing.Optional[flytekit.models.admin.common.AuthRole] = None, + notifications: typing.Optional[typing.List[admin_common_models.Notification]] = None, + labels: typing.Optional[admin_common_models.Labels] = None, + annotations: typing.Optional[admin_common_models.Annotations] = None, image_config: typing.Optional[ImageConfig] = None, - raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None, + raw_output_data_config: typing.Optional[admin_common_models.RawOutputDataConfig] = None, ): """Create a copy of the remote object, overriding the specified attributes.""" new_remote = deepcopy(self) @@ -486,7 +488,7 @@ def list_tasks_by_version( if not version: raise ValueError("Must specify a version") - named_entity_id = common_models.NamedEntityIdentifier( + named_entity_id = _namedEntityIdentifier( project=project or self.default_project, domain=domain or self.default_domain, ) @@ -581,11 +583,11 @@ def _( resolved_identifiers = asdict(self._resolve_identifier_kwargs(entity, project, domain, name, version)) serialized_lp: launch_plan_models.LaunchPlan = self._serialize(entity, **resolved_identifiers) if self.auth_role: - serialized_lp.spec._auth_role = common_models.AuthRole( + serialized_lp.spec._auth_role = flytekit.models.admin.common.AuthRole( self.auth_role.assumable_iam_role, self.auth_role.kubernetes_service_account ) if self.raw_output_data_config: - serialized_lp.spec._raw_output_data_config = common_models.RawOutputDataConfig( + serialized_lp.spec._raw_output_data_config = admin_common_models.RawOutputDataConfig( self.raw_output_data_config.output_location_prefix ) @@ -663,9 +665,9 @@ def _execute( domain: str, execution_name: typing.Optional[str] = None, wait: bool = False, - labels: typing.Optional[common_models.Labels] = None, - annotations: typing.Optional[common_models.Annotations] = None, - auth_role: typing.Optional[common_models.AuthRole] = None, + labels: typing.Optional[admin_common_models.Labels] = None, + annotations: typing.Optional[admin_common_models.Annotations] = None, + auth_role: typing.Optional[flytekit.models.admin.common.AuthRole] = None, ) -> FlyteWorkflowExecution: """Common method for execution across all entities. diff --git a/flytekit/remote/tasks/task.py b/flytekit/remote/tasks/task.py index 967f4e43aa..45c06907c8 100644 --- a/flytekit/remote/tasks/task.py +++ b/flytekit/remote/tasks/task.py @@ -1,16 +1,16 @@ from typing import Optional +import flytekit.models.core.task from flytekit.common.mixins import hash as _hash_mixin from flytekit.core.interface import Interface from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger -from flytekit.models import task as _task_model from flytekit.models.core import identifier as _identifier_model from flytekit.remote import identifier as _identifier from flytekit.remote import interface as _interfaces -class FlyteTask(_hash_mixin.HashOnReferenceMixin, _task_model.TaskTemplate): +class FlyteTask(_hash_mixin.HashOnReferenceMixin, flytekit.models.core.task.TaskTemplate): """A class encapsulating a remote Flyte task.""" def __init__(self, id, type, metadata, interface, custom, container=None, task_type_version=0, config=None): @@ -49,7 +49,7 @@ def guessed_python_interface(self, value): self._python_interface = value @classmethod - def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> "FlyteTask": + def promote_from_model(cls, base_model: flytekit.models.core.task.TaskTemplate) -> "FlyteTask": t = cls( id=base_model.id, type=base_model.type, diff --git a/flytekit/remote/workflow.py b/flytekit/remote/workflow.py index 66bb7f00e7..a5e311f813 100644 --- a/flytekit/remote/workflow.py +++ b/flytekit/remote/workflow.py @@ -1,13 +1,13 @@ from typing import Dict, List, Optional +import flytekit.models.core.task from flytekit.common import constants as _constants from flytekit.common.exceptions import system as _system_exceptions from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.mixins import hash as _hash_mixin from flytekit.core.interface import Interface from flytekit.core.type_engine import TypeEngine -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import task as _task_models +from flytekit.models.admin import launch_plan as _launch_plan_models from flytekit.models.core import identifier as _identifier_model from flytekit.models.core import workflow as _workflow_models from flytekit.remote import identifier as _identifier @@ -117,7 +117,7 @@ def promote_from_model( base_model: _workflow_models.WorkflowTemplate, sub_workflows: Optional[Dict[_identifier.Identifier, _workflow_models.WorkflowTemplate]] = None, node_launch_plans: Optional[Dict[_identifier.Identifier, _launch_plan_models.LaunchPlanSpec]] = None, - tasks: Optional[Dict[_identifier.Identifier, _task_models.TaskTemplate]] = None, + tasks: Optional[Dict[_identifier.Identifier, flytekit.models.core.task.TaskTemplate]] = None, ) -> "FlyteWorkflow": base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) sub_workflows = sub_workflows or {} diff --git a/flytekit/remote/workflow_execution.py b/flytekit/remote/workflow_execution.py index ee04d38648..b9dac0680d 100644 --- a/flytekit/remote/workflow_execution.py +++ b/flytekit/remote/workflow_execution.py @@ -1,7 +1,7 @@ from typing import Any, Dict from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import execution as _execution_models +from flytekit.models.admin import execution as _execution_models from flytekit.models.core import execution as _core_execution_models from flytekit.remote import identifier as _core_identifier from flytekit.remote import nodes as _nodes diff --git a/flytekit/sdk/tasks.py b/flytekit/sdk/tasks.py index 63cc8ba2b7..c5b5d5ad34 100644 --- a/flytekit/sdk/tasks.py +++ b/flytekit/sdk/tasks.py @@ -15,7 +15,7 @@ from flytekit.common.tasks import tensorflow_task as _sdk_tensorflow_tasks from flytekit.common.types import helpers as _type_helpers from flytekit.contrib.notebook import tasks as _nb_tasks -from flytekit.models import interface as _interface_model +from flytekit.models.core import interface as _interface_model from flytekit.sdk.spark_types import SparkType as _spark_type diff --git a/flytekit/type_engines/common.py b/flytekit/type_engines/common.py index 33927325ce..924be39753 100644 --- a/flytekit/type_engines/common.py +++ b/flytekit/type_engines/common.py @@ -16,7 +16,7 @@ def python_std_to_sdk_type(self, t): def get_sdk_type_from_literal_type(self, literal_type): """ Takes the Flyte spec language and converts to an SDK object. - :param flytekit.models.types.LiteralType literal_type: + :param flytekit.models.core.types.LiteralType literal_type: :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType """ pass diff --git a/flytekit/type_engines/default/flyte.py b/flytekit/type_engines/default/flyte.py index 0ec3a4c982..593e825caf 100644 --- a/flytekit/type_engines/default/flyte.py +++ b/flytekit/type_engines/default/flyte.py @@ -1,6 +1,7 @@ import importlib as _importer from typing import Type +import flytekit.models.core.types from flytekit.common.exceptions import system as _system_exceptions from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types @@ -10,7 +11,6 @@ from flytekit.common.types import primitives as _primitive_types from flytekit.common.types import proto as _proto from flytekit.common.types import schema as _schema -from flytekit.models import types as _literal_type_models from flytekit.models.core import types as _core_types @@ -58,14 +58,14 @@ def _generic_proto_sdk_type_from_tag(tag: str) -> Type[_proto.GenericProtobuf]: class FlyteDefaultTypeEngine(object): _SIMPLE_TYPE_LOOKUP_TABLE = { - _literal_type_models.SimpleType.INTEGER: _primitive_types.Integer, - _literal_type_models.SimpleType.FLOAT: _primitive_types.Float, - _literal_type_models.SimpleType.BOOLEAN: _primitive_types.Boolean, - _literal_type_models.SimpleType.DATETIME: _primitive_types.Datetime, - _literal_type_models.SimpleType.DURATION: _primitive_types.Timedelta, - _literal_type_models.SimpleType.NONE: _base_sdk_types.Void, - _literal_type_models.SimpleType.STRING: _primitive_types.String, - _literal_type_models.SimpleType.STRUCT: _primitive_types.Generic, + flytekit.models.core.types.SimpleType.INTEGER: _primitive_types.Integer, + flytekit.models.core.types.SimpleType.FLOAT: _primitive_types.Float, + flytekit.models.core.types.SimpleType.BOOLEAN: _primitive_types.Boolean, + flytekit.models.core.types.SimpleType.DATETIME: _primitive_types.Datetime, + flytekit.models.core.types.SimpleType.DURATION: _primitive_types.Timedelta, + flytekit.models.core.types.SimpleType.NONE: _base_sdk_types.Void, + flytekit.models.core.types.SimpleType.STRING: _primitive_types.String, + flytekit.models.core.types.SimpleType.STRUCT: _primitive_types.Generic, } def python_std_to_sdk_type(self, t): @@ -96,7 +96,7 @@ def python_std_to_sdk_type(self, t): def get_sdk_type_from_literal_type(self, literal_type): """ - :param flytekit.models.types.LiteralType literal_type: + :param flytekit.models.core.types.LiteralType literal_type: :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType """ if literal_type.collection_type is not None: @@ -109,12 +109,12 @@ def get_sdk_type_from_literal_type(self, literal_type): return self._get_blob_impl_from_type(literal_type.blob) elif literal_type.simple is not None: if ( - literal_type.simple == _literal_type_models.SimpleType.BINARY + literal_type.simple == flytekit.models.core.types.SimpleType.BINARY and _proto.Protobuf.PB_FIELD_KEY in literal_type.metadata ): return _proto_sdk_type_from_tag(literal_type.metadata[_proto.Protobuf.PB_FIELD_KEY]) if ( - literal_type.simple == _literal_type_models.SimpleType.STRUCT + literal_type.simple == flytekit.models.core.types.SimpleType.STRUCT and literal_type.metadata and _proto.Protobuf.PB_FIELD_KEY in literal_type.metadata ): diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index d2ef84f2d2..eca7b1b3a1 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -5,12 +5,12 @@ import typing from pathlib import Path +import flytekit.models.core.types from flytekit.core.context_manager import FlyteContext from flytekit.core.type_engine import TypeEngine, TypeTransformer -from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types -from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar -from flytekit.models.types import LiteralType +from flytekit.models.core.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.core.types import LiteralType T = typing.TypeVar("T") @@ -223,7 +223,9 @@ def assert_type(self, t: typing.Type[FlyteDirectory], v: typing.Union[FlyteDirec ) def get_literal_type(self, t: typing.Type[FlyteDirectory]) -> LiteralType: - return _type_models.LiteralType(blob=self._blob_type(format=FlyteDirToMultipartBlobTransformer.get_format(t))) + return flytekit.models.core.types.LiteralType( + blob=self._blob_type(format=FlyteDirToMultipartBlobTransformer.get_format(t)) + ) def to_literal( self, diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 9ebcc8e9a5..5032184038 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -7,9 +7,8 @@ from flytekit.core.context_manager import FlyteContext from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.loggers import logger -from flytekit.models.core.types import BlobType -from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar -from flytekit.models.types import LiteralType +from flytekit.models.core.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.core.types import BlobType, LiteralType def noop(): diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 89b6da3a90..04e0ff4912 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -12,8 +12,8 @@ from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import T, TypeEngine, TypeTransformer -from flytekit.models.literals import Literal, Scalar, Schema -from flytekit.models.types import LiteralType, SchemaType +from flytekit.models.core.literals import Literal, Scalar, Schema +from flytekit.models.core.types import LiteralType, SchemaType from flytekit.plugins import pandas diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index 41a5423c08..7303f828fe 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -7,8 +7,8 @@ from flytekit import FlyteContext from flytekit.configuration import sdk from flytekit.core.type_engine import T, TypeEngine, TypeTransformer -from flytekit.models.literals import Literal, Scalar, Schema -from flytekit.models.types import LiteralType, SchemaType +from flytekit.models.core.literals import Literal, Scalar, Schema +from flytekit.models.core.types import LiteralType, SchemaType from flytekit.types.schema import LocalIOSchemaReader, LocalIOSchemaWriter, SchemaEngine, SchemaFormat, SchemaHandler diff --git a/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py b/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py index 45ece39106..ac60ac0d12 100644 --- a/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py +++ b/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py @@ -4,7 +4,7 @@ from google.protobuf.json_format import MessageToDict from flytekit.extend import SerializationSettings, SQLTask -from flytekit.models.presto import PrestoQuery +from flytekit.models.plugins.presto import PrestoQuery from flytekit.types.schema import FlyteSchema diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py index cdf144d23e..94209fd58a 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py @@ -11,11 +11,11 @@ from flytekit import FlyteContext from flytekit.common.types import primitives from flytekit.extend import DictTransformer, PythonTask, SerializationSettings, TypeEngine, TypeTransformer -from flytekit.models.literals import Literal +from flytekit.models.core.literals import Literal +from flytekit.models.core.types import LiteralType from flytekit.models.sagemaker import hpo_job as _hpo_job_model from flytekit.models.sagemaker import parameter_ranges as _params from flytekit.models.sagemaker import training_job as _training_job_model -from flytekit.models.types import LiteralType @dataclass diff --git a/plugins/flytekit-dolt/flytekitplugins/dolt/schema.py b/plugins/flytekit-dolt/flytekitplugins/dolt/schema.py index 6e51f079c4..ebca97e6da 100644 --- a/plugins/flytekit-dolt/flytekitplugins/dolt/schema.py +++ b/plugins/flytekit-dolt/flytekitplugins/dolt/schema.py @@ -11,11 +11,11 @@ from google.protobuf.json_format import MessageToDict from google.protobuf.struct_pb2 import Struct +import flytekit.models.core.types from flytekit import FlyteContext from flytekit.extend import TypeEngine, TypeTransformer -from flytekit.models import types as _type_models -from flytekit.models.literals import Literal, Scalar -from flytekit.models.types import LiteralType +from flytekit.models.core.literals import Literal, Scalar +from flytekit.models.core.types import LiteralType logger = logging.getLogger("flytekitplugins.dolt") @@ -44,7 +44,7 @@ def __init__(self): super().__init__(name="DoltTable", t=DoltTable) def get_literal_type(self, t: Type[DoltTable]) -> LiteralType: - return LiteralType(simple=_type_models.SimpleType.STRUCT, metadata={}) + return LiteralType(simple=flytekit.models.core.types.SimpleType.STRUCT, metadata={}) def to_literal( self, diff --git a/plugins/flytekit-dolt/tests/test_schema.py b/plugins/flytekit-dolt/tests/test_schema.py index 40b5e2c7a6..925b30e3da 100644 --- a/plugins/flytekit-dolt/tests/test_schema.py +++ b/plugins/flytekit-dolt/tests/test_schema.py @@ -3,7 +3,7 @@ from flytekitplugins.dolt.schema import DoltConfig, DoltTable, DoltTableNameTransformer from google.protobuf.struct_pb2 import Struct -from flytekit.models.literals import Literal, Scalar +from flytekit.models.core.literals import Literal, Scalar def test_dolt_table_to_python_value(mocker): diff --git a/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py b/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py index 9e0cd88305..ea3bc217db 100644 --- a/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py +++ b/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py @@ -12,11 +12,11 @@ from great_expectations.core.util import convert_to_json_serializable from great_expectations.exceptions import ValidationError +import flytekit.models.core.types from flytekit import FlyteContext from flytekit.extend import TypeEngine, TypeTransformer -from flytekit.models import types as _type_models -from flytekit.models.literals import Literal, Primitive, Scalar -from flytekit.models.types import LiteralType +from flytekit.models.core.literals import Literal, Primitive, Scalar +from flytekit.models.core.types import LiteralType from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer, SchemaOpenMode @@ -99,7 +99,7 @@ def get_literal_type(self, t: Type[GreatExpectationsType]) -> LiteralType: datatype = GreatExpectationsTypeTransformer.get_config(t)[0] if issubclass(datatype, str): - return LiteralType(simple=_type_models.SimpleType.STRING, metadata={}) + return LiteralType(simple=flytekit.models.core.types.SimpleType.STRING, metadata={}) elif issubclass(datatype, FlyteFile): return FlyteFilePathTransformer().get_literal_type(datatype) elif issubclass(datatype, FlyteSchema): diff --git a/plugins/flytekit-hive/flytekitplugins/hive/task.py b/plugins/flytekit-hive/flytekitplugins/hive/task.py index 3280f4fb7a..6142476a11 100644 --- a/plugins/flytekit-hive/flytekitplugins/hive/task.py +++ b/plugins/flytekit-hive/flytekitplugins/hive/task.py @@ -4,7 +4,7 @@ from google.protobuf.json_format import MessageToDict from flytekit.extend import SerializationSettings, SQLTask -from flytekit.models.qubole import HiveQuery, QuboleHiveJob +from flytekit.models.plugins.qubole import HiveQuery, QuboleHiveJob from flytekit.types.schema import FlyteSchema diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py index 05b7cdd5e1..584017bbed 100644 --- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py +++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py @@ -4,15 +4,15 @@ from kubernetes.client import ApiClient from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements +import flytekit.models.core.task from flytekit import FlyteContext, PythonFunctionTask from flytekit.common.exceptions import user as _user_exceptions from flytekit.extend import Promise, SerializationSettings, TaskPlugins -from flytekit.models import task as _task_models _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" -def _sanitize_resource_name(resource: _task_models.Resources.ResourceEntry) -> str: +def _sanitize_resource_name(resource: flytekit.models.core.task.Resources.ResourceEntry) -> str: return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") @@ -104,16 +104,16 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any] return ApiClient().sanitize_for_serialization(self.task_config.pod_spec) - def get_k8s_pod(self, settings: SerializationSettings) -> _task_models.K8sPod: - return _task_models.K8sPod( + def get_k8s_pod(self, settings: SerializationSettings) -> flytekit.models.core.task.K8sPod: + return flytekit.models.core.task.K8sPod( pod_spec=self._serialize_pod_spec(settings), - metadata=_task_models.K8sObjectMetadata( + metadata=flytekit.models.core.task.K8sObjectMetadata( labels=self.task_config.labels, annotations=self.task_config.annotations, ), ) - def get_container(self, settings: SerializationSettings) -> _task_models.Container: + def get_container(self, settings: SerializationSettings) -> flytekit.models.core.task.Container: return None def get_config(self, settings: SerializationSettings) -> Dict[str, str]: diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 56a8cab9d1..433ef1121a 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -9,7 +9,7 @@ from flytekit import PythonFunctionTask, Resources from flytekit.extend import SerializationSettings, TaskPlugins -from flytekit.models import task as _task_model +from flytekit.models.plugins import task as _task_model @dataclass diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index f637b9d7cc..1dbe3bf5f0 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -9,7 +9,7 @@ from flytekit import PythonFunctionTask, Resources from flytekit.extend import SerializationSettings, TaskPlugins -from flytekit.models import task as _task_model +from flytekit.models.plugins import task as _task_model @dataclass diff --git a/plugins/flytekit-modin/flytekitplugins/modin/schema.py b/plugins/flytekit-modin/flytekitplugins/modin/schema.py index 41bb12af1f..b761844ab1 100644 --- a/plugins/flytekit-modin/flytekitplugins/modin/schema.py +++ b/plugins/flytekit-modin/flytekitplugins/modin/schema.py @@ -7,8 +7,8 @@ from flytekit import FlyteContext from flytekit.extend import T, TypeEngine, TypeTransformer -from flytekit.models.literals import Literal, Scalar, Schema -from flytekit.models.types import LiteralType, SchemaType +from flytekit.models.core.literals import Literal, Scalar, Schema +from flytekit.models.core.types import LiteralType, SchemaType from flytekit.types.schema import LocalIOSchemaReader, LocalIOSchemaWriter, SchemaEngine, SchemaFormat, SchemaHandler from flytekit.types.schema.types import FlyteSchemaTransformer diff --git a/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py b/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py index 4f6b49973e..ce7fd4ed7a 100644 --- a/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py +++ b/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py @@ -6,8 +6,8 @@ from flytekit import FlyteContext from flytekit.extend import TypeEngine, TypeTransformer -from flytekit.models.literals import Literal, Scalar, Schema -from flytekit.models.types import LiteralType, SchemaType +from flytekit.models.core.literals import Literal, Scalar, Schema +from flytekit.models.core.types import LiteralType, SchemaType from flytekit.types.schema import FlyteSchema, PandasSchemaWriter, SchemaFormat, SchemaOpenMode from flytekit.types.schema.types import FlyteSchemaTransformer diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 2647a2cf66..0213f16973 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -13,7 +13,7 @@ from flytekit import FlyteContext, PythonInstanceTask from flytekit.common.tasks.sdk_runnable import ExecutionParameters from flytekit.extend import Interface, TaskPlugins, TypeEngine -from flytekit.models.literals import LiteralMap +from flytekit.models.core.literals import LiteralMap from flytekit.types.file import HTMLPage, PythonNotebook T = typing.TypeVar("T") diff --git a/plugins/flytekit-papermill/requirements.txt b/plugins/flytekit-papermill/requirements.txt index 119628099c..5afa5046db 100644 --- a/plugins/flytekit-papermill/requirements.txt +++ b/plugins/flytekit-papermill/requirements.txt @@ -64,8 +64,10 @@ flytekit==0.22.1 # via # flytekitplugins-papermill # flytekitplugins-spark -flytekitplugins-spark==0.22.1 - # via -r ./requirements.in +flytekitplugins-spark==0.23.1 + # via + # -r ./requirements.in + # flytekitplugins-papermill grpcio==1.40.0 # via flytekit idna==3.2 diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 7652fcfb1f..8ab0380dfc 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -1,8 +1,8 @@ from dataclasses import dataclass from typing import Dict, Optional, Type +import flytekit.models.core.task from flytekit.extend import SerializationSettings, SQLTask -from flytekit.models import task as _task_model from flytekit.types.schema import FlyteSchema _ACCOUNT_FIELD = "account" @@ -81,6 +81,8 @@ def get_config(self, settings: SerializationSettings) -> Dict[str, str]: _WAREHOUSE_FIELD: self.task_config.warehouse, } - def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]: - sql = _task_model.Sql(statement=self.query_template, dialect=_task_model.Sql.Dialect.ANSI) + def get_sql(self, settings: SerializationSettings) -> Optional[flytekit.models.core.task.Sql]: + sql = flytekit.models.core.task.Sql( + statement=self.query_template, dialect=flytekit.models.core.task.Sql.Dialect.ANSI + ) return sql diff --git a/plugins/flytekit-spark/flytekitplugins/spark/schema.py b/plugins/flytekit-spark/flytekitplugins/spark/schema.py index 1cae101295..1764ef2703 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/schema.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/schema.py @@ -5,8 +5,8 @@ from flytekit import FlyteContext from flytekit.extend import T, TypeEngine, TypeTransformer -from flytekit.models.literals import Literal, Scalar, Schema -from flytekit.models.types import LiteralType, SchemaType +from flytekit.models.core.literals import Literal, Scalar, Schema +from flytekit.models.core.types import LiteralType, SchemaType from flytekit.types.schema import SchemaEngine, SchemaFormat, SchemaHandler, SchemaReader, SchemaWriter diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index a2893a06ad..a38a551033 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -9,7 +9,7 @@ from flytekit import FlyteContextManager, PythonFunctionTask from flytekit.common.tasks.sdk_runnable import ExecutionParameters from flytekit.extend import ExecutionState, SerializationSettings, TaskPlugins -from flytekit.models import task as _task_model +from flytekit.models.plugins import task as _task_model from flytekit.sdk.spark_types import SparkType diff --git a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py index a04f143e58..3729d9a104 100644 --- a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py +++ b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py @@ -4,13 +4,13 @@ import pandas as pd from sqlalchemy import create_engine # type: ignore +import flytekit.models.core.task from flytekit import current_context, kwtypes from flytekit.core.base_sql_task import SQLTask from flytekit.core.context_manager import SerializationSettings from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor -from flytekit.models import task as task_models -from flytekit.models.security import Secret +from flytekit.models.core.security import Secret from flytekit.types.schema import FlyteSchema @@ -105,7 +105,7 @@ def get_custom(self, settings: SerializationSettings) -> typing.Dict[str, typing class SQLAlchemyTaskExecutor(ShimTaskExecutor[SQLAlchemyTask]): - def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> typing.Any: + def execute_from_model(self, tt: flytekit.models.core.task.TaskTemplate, **kwargs) -> typing.Any: if tt.custom["secret_connect_args"] is not None: for key, secret_dict in tt.custom["secret_connect_args"].items(): value = current_context().secrets.get(group=secret_dict["group"], key=secret_dict["key"]) diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index 2c853f7811..af34d1c196 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -11,7 +11,7 @@ from flytekitplugins.sqlalchemy.task import SQLAlchemyTaskExecutor from flytekit import kwtypes, task, workflow -from flytekit.models.security import Secret +from flytekit.models.core.security import Secret from flytekit.testing import SecretsManager from flytekit.types.schema import FlyteSchema diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index fbc8f07ba5..f066273dd5 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -3,53 +3,70 @@ from six.moves import range +import flytekit.models.core.task +import flytekit.models.core.types from flytekit.common.types.impl import blobs as _blob_impl from flytekit.common.types.impl import schema as _schema_impl -from flytekit.models import interface, literals, security, task, types -from flytekit.models.core import identifier +from flytekit.models.admin import task as task +from flytekit.models.core import identifier, interface, literals, security from flytekit.models.core import types as _core_types +from flytekit.models.core.compiler import CompiledTask as _compiledTask +from flytekit.models.core.task import Container as _task_container +from flytekit.models.core.task import Resources as _task_resource LIST_OF_SCALAR_LITERAL_TYPES = [ - types.LiteralType(simple=types.SimpleType.BINARY), - types.LiteralType(simple=types.SimpleType.BOOLEAN), - types.LiteralType(simple=types.SimpleType.DATETIME), - types.LiteralType(simple=types.SimpleType.DURATION), - types.LiteralType(simple=types.SimpleType.ERROR), - types.LiteralType(simple=types.SimpleType.FLOAT), - types.LiteralType(simple=types.SimpleType.INTEGER), - types.LiteralType(simple=types.SimpleType.NONE), - types.LiteralType(simple=types.SimpleType.STRING), - types.LiteralType( - schema=types.SchemaType( + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BINARY), + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN), + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.DATETIME), + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.DURATION), + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.ERROR), + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.FLOAT), + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.INTEGER), + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.NONE), + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.STRING), + flytekit.models.core.types.LiteralType( + schema=flytekit.models.core.types.SchemaType( [ - types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), + flytekit.models.core.types.SchemaType.SchemaColumn( + "a", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "b", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "c", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "d", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "e", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "f", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING + ), ] ) ), - types.LiteralType( + flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ), - types.LiteralType( + flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ), - types.LiteralType( + flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, ) ), - types.LiteralType( + flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, @@ -59,11 +76,13 @@ LIST_OF_COLLECTION_LITERAL_TYPES = [ - types.LiteralType(collection_type=literal_type) for literal_type in LIST_OF_SCALAR_LITERAL_TYPES + flytekit.models.core.types.LiteralType(collection_type=literal_type) + for literal_type in LIST_OF_SCALAR_LITERAL_TYPES ] LIST_OF_NESTED_COLLECTION_LITERAL_TYPES = [ - types.LiteralType(collection_type=literal_type) for literal_type in LIST_OF_COLLECTION_LITERAL_TYPES + flytekit.models.core.types.LiteralType(collection_type=literal_type) + for literal_type in LIST_OF_COLLECTION_LITERAL_TYPES ] LIST_OF_ALL_LITERAL_TYPES = ( @@ -80,11 +99,11 @@ LIST_OF_RESOURCE_ENTRIES = [ - task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "1"), - task.Resources.ResourceEntry(task.Resources.ResourceName.GPU, "1"), - task.Resources.ResourceEntry(task.Resources.ResourceName.MEMORY, "1G"), - task.Resources.ResourceEntry(task.Resources.ResourceName.STORAGE, "1G"), - task.Resources.ResourceEntry(task.Resources.ResourceName.EPHEMERAL_STORAGE, "1G"), + _task_resource.ResourceEntry(_task_resource.ResourceName.CPU, "1"), + _task_resource.ResourceEntry(_task_resource.ResourceName.GPU, "1"), + _task_resource.ResourceEntry(_task_resource.ResourceName.MEMORY, "1G"), + _task_resource.ResourceEntry(_task_resource.ResourceName.STORAGE, "1G"), + _task_resource.ResourceEntry(_task_resource.ResourceName.EPHEMERAL_STORAGE, "1G"), ] @@ -92,14 +111,18 @@ LIST_OF_RESOURCES = [ - task.Resources(request, limit) + _task_resource(request, limit) for request, limit in product(LIST_OF_RESOURCE_ENTRY_LISTS, LIST_OF_RESOURCE_ENTRY_LISTS) ] LIST_OF_RUNTIME_METADATA = [ - task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.OTHER, "1.0.0", "python"), - task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0b0", "golang"), + flytekit.models.core.task.RuntimeMetadata( + flytekit.models.core.task.RuntimeMetadata.RuntimeType.OTHER, "1.0.0", "python" + ), + flytekit.models.core.task.RuntimeMetadata( + flytekit.models.core.task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0b0", "golang" + ), ] @@ -108,7 +131,7 @@ LIST_OF_INTERRUPTIBLE = [None, True, False] LIST_OF_TASK_METADATA = [ - task.TaskMetadata( + flytekit.models.core.task.TaskMetadata( discoverable, runtime_metadata, timeout, @@ -130,13 +153,13 @@ LIST_OF_TASK_TEMPLATES = [ - task.TaskTemplate( + flytekit.models.core.task.TaskTemplate( identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), "python", task_metadata, interfaces, {"a": 1, "b": [1, 2, 3], "c": "abc", "d": {"x": 1, "y": 2, "z": 3}}, - container=task.Container( + container=_task_container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], @@ -149,7 +172,7 @@ ] LIST_OF_CONTAINERS = [ - task.Container( + _task_container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], @@ -160,7 +183,7 @@ for resources in LIST_OF_RESOURCES ] -LIST_OF_TASK_CLOSURES = [task.TaskClosure(task.CompiledTask(template)) for template in LIST_OF_TASK_TEMPLATES] +LIST_OF_TASK_CLOSURES = [task.TaskClosure(_compiledTask(template)) for template in LIST_OF_TASK_TEMPLATES] LIST_OF_SCALARS_AND_PYTHON_VALUES = [ (literals.Scalar(primitive=literals.Primitive(integer=100)), 100), @@ -212,14 +235,26 @@ literals.Scalar( schema=literals.Schema( "s3://some/where/", - types.SchemaType( + flytekit.models.core.types.SchemaType( [ - types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), + flytekit.models.core.types.SchemaType.SchemaColumn( + "a", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "b", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "c", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "d", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "e", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "f", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING + ), ] ), ) @@ -227,14 +262,26 @@ _schema_impl.Schema( "s3://some/where/", _schema_impl.SchemaType.promote_from_model( - types.SchemaType( + flytekit.models.core.types.SchemaType( [ - types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), + flytekit.models.core.types.SchemaType.SchemaColumn( + "a", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "b", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "c", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "d", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "e", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "f", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING + ), ] ) ), diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index d54d7dc52b..f803079c7b 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -24,9 +24,9 @@ from flytekit.core.type_engine import TypeEngine from flytekit.extras.persistence.gcs_gsutil import GCSPersistence from flytekit.extras.persistence.s3_awscli import S3Persistence -from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models from flytekit.models.core import execution as execution_models +from flytekit.models.core import literals as _literal_models from tests.flytekit.common import task_definitions as _task_defs diff --git a/tests/flytekit/unit/cli/test_cli_helpers.py b/tests/flytekit/unit/cli/test_cli_helpers.py index 9188b2fc3a..480a056856 100644 --- a/tests/flytekit/unit/cli/test_cli_helpers.py +++ b/tests/flytekit/unit/cli/test_cli_helpers.py @@ -7,10 +7,11 @@ from flyteidl.core import workflow_pb2 as _core_workflow_pb2 from flyteidl.core.identifier_pb2 import LAUNCH_PLAN +import flytekit.models.core.types from flytekit.clis import helpers from flytekit.clis.helpers import _hydrate_identifier, _hydrate_workflow_template_nodes, hydrate_registration_parameters -from flytekit.models import literals, types -from flytekit.models.interface import Parameter, ParameterMap, Variable +from flytekit.models.core import literals +from flytekit.models.core.interface import Parameter, ParameterMap, Variable def test_parse_args_into_dict(): @@ -29,7 +30,10 @@ def test_parse_args_into_dict(): def test_construct_literal_map_from_variable_map(): - v = Variable(type=types.LiteralType(simple=types.SimpleType.INTEGER), description="some description") + v = Variable( + type=flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.INTEGER), + description="some description", + ) variable_map = { "inputa": v, } @@ -43,7 +47,10 @@ def test_construct_literal_map_from_variable_map(): def test_construct_literal_map_from_parameter_map(): - v = Variable(type=types.LiteralType(simple=types.SimpleType.INTEGER), description="some description") + v = Variable( + type=flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.INTEGER), + description="some description", + ) p = Parameter(var=v, required=True) pm = ParameterMap(parameters={"inputa": p}) diff --git a/tests/flytekit/unit/cli/test_flyte_cli.py b/tests/flytekit/unit/cli/test_flyte_cli.py index 3049590036..c10ebc6de3 100644 --- a/tests/flytekit/unit/cli/test_flyte_cli.py +++ b/tests/flytekit/unit/cli/test_flyte_cli.py @@ -9,8 +9,8 @@ from flytekit.configuration import TemporaryConfiguration from flytekit.models import filters as _filters from flytekit.models.admin import common as _admin_common +from flytekit.models.admin.project import Project as _Project from flytekit.models.core import identifier as _core_identifier -from flytekit.models.project import Project as _Project from flytekit.sdk.tasks import inputs, outputs, python_task mm = _mock.MagicMock() diff --git a/tests/flytekit/unit/clients/test_friendly.py b/tests/flytekit/unit/clients/test_friendly.py index e2e147dc1d..913d8ca0ac 100644 --- a/tests/flytekit/unit/clients/test_friendly.py +++ b/tests/flytekit/unit/clients/test_friendly.py @@ -2,7 +2,7 @@ from flyteidl.admin import project_pb2 as _project_pb2 from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient -from flytekit.models.project import Project as _Project +from flytekit.models.admin.project import Project as _Project @_mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.update_project") diff --git a/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py b/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py index 29402bea55..c580efdc14 100644 --- a/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py +++ b/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py @@ -4,7 +4,7 @@ from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import sdk_runnable from flytekit.common.types import primitives -from flytekit.models import interface +from flytekit.models.core import interface def test_basic_unit_test(): diff --git a/tests/flytekit/unit/common_tests/tasks/test_task.py b/tests/flytekit/unit/common_tests/tasks/test_task.py index e33a757412..91aa49b7a3 100644 --- a/tests/flytekit/unit/common_tests/tasks/test_task.py +++ b/tests/flytekit/unit/common_tests/tasks/test_task.py @@ -10,7 +10,7 @@ from flytekit.common.tasks.presto_task import SdkPrestoTask from flytekit.common.types import primitives from flytekit.configuration import TemporaryConfiguration -from flytekit.models import task as _task_models +from flytekit.models.admin import task as _task_models from flytekit.models.core import identifier as _identifier from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.types import Types diff --git a/tests/flytekit/unit/common_tests/test_launch_plan.py b/tests/flytekit/unit/common_tests/test_launch_plan.py index 92943d37f1..acbaeed857 100644 --- a/tests/flytekit/unit/common_tests/test_launch_plan.py +++ b/tests/flytekit/unit/common_tests/test_launch_plan.py @@ -2,14 +2,14 @@ import pytest as _pytest +import flytekit.models.core.types from flytekit import configuration as _configuration from flytekit.common import launch_plan as _launch_plan from flytekit.common import notifications as _notifications from flytekit.common import schedules as _schedules from flytekit.common.exceptions import user as _user_exceptions -from flytekit.models import common as _common_models -from flytekit.models import schedule as _schedule -from flytekit.models import types as _type_models +from flytekit.models.admin import common as _common_models +from flytekit.models.admin import schedule as _schedule from flytekit.models.core import execution as _execution from flytekit.models.core import identifier as _identifier from flytekit.sdk import types as _types @@ -273,7 +273,10 @@ def test_launch_plan_node(): # Test that outputs are promised n.assign_id_and_return("node-id") - assert n.outputs["out"].sdk_type.to_flyte_literal_type().collection_type.simple == _type_models.SimpleType.INTEGER + assert ( + n.outputs["out"].sdk_type.to_flyte_literal_type().collection_type.simple + == flytekit.models.core.types.SimpleType.INTEGER + ) assert n.outputs["out"].var == "out" assert n.outputs["out"].node_id == "node-id" diff --git a/tests/flytekit/unit/common_tests/test_nodes.py b/tests/flytekit/unit/common_tests/test_nodes.py index 0c462c552c..b44d65d77e 100644 --- a/tests/flytekit/unit/common_tests/test_nodes.py +++ b/tests/flytekit/unit/common_tests/test_nodes.py @@ -6,8 +6,8 @@ from flytekit.common import interface as _interface from flytekit.common import nodes as _nodes from flytekit.common.exceptions import system as _system_exceptions -from flytekit.models import literals as _literals from flytekit.models.core import identifier as _identifier +from flytekit.models.core import literals as _literals from flytekit.models.core import workflow as _core_workflow_models from flytekit.sdk import tasks as _tasks from flytekit.sdk import types as _types diff --git a/tests/flytekit/unit/common_tests/test_promise.py b/tests/flytekit/unit/common_tests/test_promise.py index e06cca7dfa..abac1f2420 100644 --- a/tests/flytekit/unit/common_tests/test_promise.py +++ b/tests/flytekit/unit/common_tests/test_promise.py @@ -7,7 +7,7 @@ from flytekit.core.interface import Interface from flytekit.core.promise import Promise, create_native_named_tuple, extract_obj_name from flytekit.core.type_engine import TypeEngine -from flytekit.models.types import LiteralType, SimpleType +from flytekit.models.core.types import LiteralType, SimpleType def test_input(): diff --git a/tests/flytekit/unit/common_tests/test_workflow.py b/tests/flytekit/unit/common_tests/test_workflow.py index 13a9e65d8e..0828bec657 100644 --- a/tests/flytekit/unit/common_tests/test_workflow.py +++ b/tests/flytekit/unit/common_tests/test_workflow.py @@ -7,8 +7,8 @@ from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.local_workflow import build_sdk_workflow_from_metaclass from flytekit.common.types import containers, primitives -from flytekit.models import literals as _literals from flytekit.models.core import identifier as _identifier +from flytekit.models.core import literals as _literals from flytekit.models.core import workflow as _workflow_models from flytekit.sdk import types as _types from flytekit.sdk.tasks import inputs, outputs, python_task diff --git a/tests/flytekit/unit/common_tests/test_workflow_promote.py b/tests/flytekit/unit/common_tests/test_workflow_promote.py index cdd1231107..0e39fcb881 100644 --- a/tests/flytekit/unit/common_tests/test_workflow_promote.py +++ b/tests/flytekit/unit/common_tests/test_workflow_promote.py @@ -5,15 +5,18 @@ from flyteidl.core import workflow_pb2 as _workflow_pb2 from mock import patch as _patch +import flytekit.models.core.task +import flytekit.models.core.types from flytekit.common import workflow as _workflow_common from flytekit.common.tasks import task as _task -from flytekit.models import interface as _interface -from flytekit.models import literals as _literals -from flytekit.models import task as _task_model -from flytekit.models import types as _types from flytekit.models.core import compiler as _compiler_model from flytekit.models.core import identifier as _identifier +from flytekit.models.core import interface as _interface +from flytekit.models.core import literals as _literals from flytekit.models.core import workflow as _workflow_model +from flytekit.models.core.task import RuntimeMetadata as _runtimeMetadata +from flytekit.models.core.task import TaskMetadata as _taskMetadata +from flytekit.models.core.task import TaskTemplate as _taskTemplate from flytekit.sdk import tasks as _sdk_tasks from flytekit.sdk import workflow as _sdk_workflow from flytekit.sdk.tasks import inputs, outputs, python_task @@ -35,10 +38,12 @@ def get_sample_container(): """ :rtype: flytekit.models.task.Container """ - cpu_resource = _task_model.Resources.ResourceEntry(_task_model.Resources.ResourceName.CPU, "1") - resources = _task_model.Resources(requests=[cpu_resource], limits=[cpu_resource]) + cpu_resource = flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.CPU, "1" + ) + resources = flytekit.models.core.task.Resources(requests=[cpu_resource], limits=[cpu_resource]) - return _task_model.Container( + return flytekit.models.core.task.Container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], @@ -50,11 +55,11 @@ def get_sample_container(): def get_sample_task_metadata(): """ - :rtype: flytekit.models.task.TaskMetadata + :rtype: flytekit.models.core.task.TaskMetadata """ - return _task_model.TaskMetadata( + return _taskMetadata( True, - _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + _runtimeMetadata(_runtimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), timedelta(days=1), _literals.RetryStrategy(3), True, @@ -125,7 +130,7 @@ class TestPromoteExampleWf(object): wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c, sdk_type=_Types.Integer) # This section uses the TaskTemplate stored in Admin to promote back to an Sdk Workflow - int_type = _types.LiteralType(_types.SimpleType.INTEGER) + int_type = flytekit.models.core.types.LiteralType(flytekit.models.core.types.SimpleType.INTEGER) task_interface = _interface.TypedInterface( # inputs {"a": _interface.Variable(int_type, "description1")}, @@ -133,7 +138,7 @@ class TestPromoteExampleWf(object): {"b": _interface.Variable(int_type, "description2"), "c": _interface.Variable(int_type, "description3")}, ) # Since the promotion of a workflow requires retrieving the task from Admin, we mock the SdkTask to return - task_template = _task_model.TaskTemplate( + task_template = _taskTemplate( _identifier.Identifier( _identifier.ResourceType.TASK, "project", diff --git a/tests/flytekit/unit/common_tests/types/impl/test_schema.py b/tests/flytekit/unit/common_tests/types/impl/test_schema.py index de5bba4a46..327f681f40 100644 --- a/tests/flytekit/unit/common_tests/types/impl/test_schema.py +++ b/tests/flytekit/unit/common_tests/types/impl/test_schema.py @@ -7,13 +7,13 @@ import pytest as _pytest import six.moves as _six_moves +import flytekit.models.core.types from flytekit.common import utils as _utils from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import blobs as _blobs from flytekit.common.types import primitives as _primitives from flytekit.common.types.impl import schema as _schema_impl -from flytekit.models import literals as _literal_models -from flytekit.models import types as _type_models +from flytekit.models.core import literals as _literal_models from flytekit.sdk import test_utils as _test_utils @@ -385,14 +385,26 @@ def empty_list(): def test_promote_from_model_schema_type(): - m = _type_models.SchemaType( + m = flytekit.models.core.types.SchemaType( [ - _type_models.SchemaType.SchemaColumn("a", _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - _type_models.SchemaType.SchemaColumn("b", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - _type_models.SchemaType.SchemaColumn("c", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - _type_models.SchemaType.SchemaColumn("d", _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _type_models.SchemaType.SchemaColumn("e", _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - _type_models.SchemaType.SchemaColumn("f", _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING), + flytekit.models.core.types.SchemaType.SchemaColumn( + "a", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "b", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "c", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "d", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "e", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "f", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING + ), ] ) s = _schema_impl.SchemaType.promote_from_model(m) @@ -409,22 +421,26 @@ def test_promote_from_model_schema_type(): def test_promote_from_model_schema(): m = _literal_models.Schema( "s3://some/place/", - _type_models.SchemaType( + flytekit.models.core.types.SchemaType( [ - _type_models.SchemaType.SchemaColumn( - "a", _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + flytekit.models.core.types.SchemaType.SchemaColumn( + "a", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN ), - _type_models.SchemaType.SchemaColumn( - "b", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + flytekit.models.core.types.SchemaType.SchemaColumn( + "b", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME ), - _type_models.SchemaType.SchemaColumn( - "c", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION + flytekit.models.core.types.SchemaType.SchemaColumn( + "c", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION ), - _type_models.SchemaType.SchemaColumn("d", _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _type_models.SchemaType.SchemaColumn( - "e", _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + flytekit.models.core.types.SchemaType.SchemaColumn( + "d", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "e", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "f", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING ), - _type_models.SchemaType.SchemaColumn("f", _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING), ] ), ) diff --git a/tests/flytekit/unit/common_tests/types/test_blobs.py b/tests/flytekit/unit/common_tests/types/test_blobs.py index 3057b11c6a..748bcdb864 100644 --- a/tests/flytekit/unit/common_tests/types/test_blobs.py +++ b/tests/flytekit/unit/common_tests/types/test_blobs.py @@ -1,6 +1,6 @@ from flytekit.common.types import blobs from flytekit.common.types.impl import blobs as blob_impl -from flytekit.models import literals as _literal_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.sdk import test_utils diff --git a/tests/flytekit/unit/common_tests/types/test_containers.py b/tests/flytekit/unit/common_tests/types/test_containers.py index dcb8730b08..c544b577ec 100644 --- a/tests/flytekit/unit/common_tests/types/test_containers.py +++ b/tests/flytekit/unit/common_tests/types/test_containers.py @@ -1,10 +1,10 @@ import pytest from six.moves import range as _range +import flytekit.models.core.types from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import containers, primitives -from flytekit.models import literals -from flytekit.models import types as literal_types +from flytekit.models.core import literals def test_list(): @@ -12,7 +12,7 @@ def test_list(): assert list_type.to_flyte_literal_type().simple is None assert list_type.to_flyte_literal_type().map_value_type is None assert list_type.to_flyte_literal_type().schema is None - assert list_type.to_flyte_literal_type().collection_type.simple == literal_types.SimpleType.INTEGER + assert list_type.to_flyte_literal_type().collection_type.simple == flytekit.models.core.types.SimpleType.INTEGER list_value = list_type.from_python_std([1, 2, 3, 4]) assert list_value.to_python_std() == [1, 2, 3, 4] @@ -85,7 +85,10 @@ def test_nested_list(): assert list_type.to_flyte_literal_type().collection_type.simple is None assert list_type.to_flyte_literal_type().collection_type.map_value_type is None assert list_type.to_flyte_literal_type().collection_type.schema is None - assert list_type.to_flyte_literal_type().collection_type.collection_type.simple == literal_types.SimpleType.INTEGER + assert ( + list_type.to_flyte_literal_type().collection_type.collection_type.simple + == flytekit.models.core.types.SimpleType.INTEGER + ) gt = [[1, 2, 3], [4, 5, 6], []] list_value = list_type.from_python_std(gt) diff --git a/tests/flytekit/unit/common_tests/types/test_helpers.py b/tests/flytekit/unit/common_tests/types/test_helpers.py index dd8b45af23..fe472c1212 100644 --- a/tests/flytekit/unit/common_tests/types/test_helpers.py +++ b/tests/flytekit/unit/common_tests/types/test_helpers.py @@ -1,20 +1,22 @@ +import flytekit.models.core.types from flytekit.common.types import base_sdk_types as _base_sdk_types from flytekit.common.types import helpers as _type_helpers -from flytekit.models import literals as _literals -from flytekit.models import types as _model_types +from flytekit.models.core import literals as _literals from flytekit.sdk import types as _sdk_types def test_python_std_to_sdk_type(): o = _type_helpers.python_std_to_sdk_type(_sdk_types.Types.Integer) - assert o.to_flyte_literal_type().simple == _model_types.SimpleType.INTEGER + assert o.to_flyte_literal_type().simple == flytekit.models.core.types.SimpleType.INTEGER o = _type_helpers.python_std_to_sdk_type([_sdk_types.Types.Boolean]) - assert o.to_flyte_literal_type().collection_type.simple == _model_types.SimpleType.BOOLEAN + assert o.to_flyte_literal_type().collection_type.simple == flytekit.models.core.types.SimpleType.BOOLEAN def test_get_sdk_type_from_literal_type(): - o = _type_helpers.get_sdk_type_from_literal_type(_model_types.LiteralType(simple=_model_types.SimpleType.FLOAT)) + o = _type_helpers.get_sdk_type_from_literal_type( + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.FLOAT) + ) assert o == _sdk_types.Types.Float diff --git a/tests/flytekit/unit/common_tests/types/test_primitives.py b/tests/flytekit/unit/common_tests/types/test_primitives.py index b161c751dd..f2b29b62ce 100644 --- a/tests/flytekit/unit/common_tests/types/test_primitives.py +++ b/tests/flytekit/unit/common_tests/types/test_primitives.py @@ -3,14 +3,14 @@ import pytest from dateutil import tz +import flytekit.models.core.types from flytekit.common.exceptions import user as user_exceptions from flytekit.common.types import base_sdk_types, primitives -from flytekit.models import types as literal_types def test_integer(): # Check type specification - assert primitives.Integer.to_flyte_literal_type().simple == literal_types.SimpleType.INTEGER + assert primitives.Integer.to_flyte_literal_type().simple == flytekit.models.core.types.SimpleType.INTEGER # Test value behavior obj = primitives.Integer.from_python_std(1) @@ -45,7 +45,7 @@ def test_integer(): def test_float(): # Check type specification - assert primitives.Float.to_flyte_literal_type().simple == literal_types.SimpleType.FLOAT + assert primitives.Float.to_flyte_literal_type().simple == flytekit.models.core.types.SimpleType.FLOAT # Test value behavior obj = primitives.Float.from_python_std(1.0) @@ -80,7 +80,7 @@ def test_float(): def test_boolean(): # Check type specification - assert primitives.Boolean.to_flyte_literal_type().simple == literal_types.SimpleType.BOOLEAN + assert primitives.Boolean.to_flyte_literal_type().simple == flytekit.models.core.types.SimpleType.BOOLEAN # Test value behavior obj = primitives.Boolean.from_python_std(True) @@ -119,7 +119,7 @@ def test_boolean(): def test_string(): # Check type specification - assert primitives.String.to_flyte_literal_type().simple == literal_types.SimpleType.STRING + assert primitives.String.to_flyte_literal_type().simple == flytekit.models.core.types.SimpleType.STRING # Test value behavior obj = primitives.String.from_python_std("abc") @@ -172,7 +172,7 @@ def dst(self, dt): def test_datetime(): # Check type specification - assert primitives.Datetime.to_flyte_literal_type().simple == literal_types.SimpleType.DATETIME + assert primitives.Datetime.to_flyte_literal_type().simple == flytekit.models.core.types.SimpleType.DATETIME # Test value behavior dt = datetime.datetime.now(tz=tz.UTC) @@ -205,7 +205,7 @@ def test_datetime(): def test_timedelta(): # Check type specification - assert primitives.Timedelta.to_flyte_literal_type().simple == literal_types.SimpleType.DURATION + assert primitives.Timedelta.to_flyte_literal_type().simple == flytekit.models.core.types.SimpleType.DURATION # Test value behavior obj = primitives.Timedelta.from_python_std(datetime.timedelta(seconds=1)) @@ -258,7 +258,7 @@ def test_void(): def test_generic(): # Check type specification - assert primitives.Generic.to_flyte_literal_type().simple == literal_types.SimpleType.STRUCT + assert primitives.Generic.to_flyte_literal_type().simple == flytekit.models.core.types.SimpleType.STRUCT # Test value behavior d = {"a": [1, 2, 3], "b": "abc", "c": 1, "d": {"a": 1}} diff --git a/tests/flytekit/unit/common_tests/types/test_proto.py b/tests/flytekit/unit/common_tests/types/test_proto.py index 074fb1dfd8..a0d91ccb5d 100644 --- a/tests/flytekit/unit/common_tests/types/test_proto.py +++ b/tests/flytekit/unit/common_tests/types/test_proto.py @@ -3,10 +3,10 @@ import pytest as _pytest from flyteidl.core import errors_pb2 as _errors_pb2 +import flytekit.models.core.types from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import proto as _proto from flytekit.common.types.proto import ProtobufType -from flytekit.models import types as _type_models def test_wrong_type(): @@ -16,7 +16,7 @@ def test_wrong_type(): def test_proto_to_literal_type(): proto_type = _proto.create_protobuf(_errors_pb2.ContainerError) - assert proto_type.to_flyte_literal_type().simple == _type_models.SimpleType.BINARY + assert proto_type.to_flyte_literal_type().simple == flytekit.models.core.types.SimpleType.BINARY assert len(proto_type.to_flyte_literal_type().metadata) == 1 assert ( proto_type.to_flyte_literal_type().metadata[_proto.Protobuf.PB_FIELD_KEY] diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 37f5d10195..e9cb253a41 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -3,7 +3,7 @@ from flytekit.core import launch_plan from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.models import literals as _literal_models +from flytekit.models.core import literals as _literal_models def test_wf1_with_subwf(): diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index db7e8c3d10..fcdb80910e 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -15,8 +15,8 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.models.core.literals import LiteralMap from flytekit.models.core.types import BlobType -from flytekit.models.literals import LiteralMap from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 3a34445f08..c5030c1718 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -14,8 +14,8 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.models.core.literals import LiteralMap from flytekit.models.core.types import BlobType -from flytekit.models.literals import LiteralMap from flytekit.types.file.file import FlyteFile diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index fd45c831e8..15f58a0f98 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -13,7 +13,7 @@ from flytekit.core.task import reference_task, task from flytekit.core.workflow import ImperativeWorkflow, get_promise, workflow from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task -from flytekit.models import literals as literal_models +from flytekit.models.core import literals as literal_models from flytekit.types.file import FlyteFile from flytekit.types.schema import FlyteSchema diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py index 72d10fb2d1..12c04072ad 100644 --- a/tests/flytekit/unit/core/test_launch_plan.py +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -10,7 +10,7 @@ from flytekit.core.schedule import CronSchedule from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.models.common import Annotations, AuthRole, Labels, RawOutputDataConfig +from flytekit.models.admin.common import Annotations, AuthRole, Labels, RawOutputDataConfig from flytekit.models.core import execution as _execution_model from flytekit.models.core import identifier as identifier_models diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 4b1c59c479..78d14d29e2 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -13,8 +13,8 @@ from flytekit.core.node_creation import create_node from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.models import literals as _literal_models -from flytekit.models.task import Resources as _resources_models +from flytekit.models.core import literals as _literal_models +from flytekit.models.core.task import Resources as _resources_models def test_normal_task(): diff --git a/tests/flytekit/unit/core/test_notifications.py b/tests/flytekit/unit/core/test_notifications.py index 05169ede83..1862d9b6b1 100644 --- a/tests/flytekit/unit/core/test_notifications.py +++ b/tests/flytekit/unit/core/test_notifications.py @@ -4,7 +4,7 @@ from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.models import common as _common_model +from flytekit.models.admin import common as _common_model from flytekit.models.core import execution as _execution_model _workflow_execution_succeeded = _execution_model.WorkflowExecutionPhase.SUCCEEDED diff --git a/tests/flytekit/unit/core/test_protobuf.py b/tests/flytekit/unit/core/test_protobuf.py index 6509e5229a..5b1091e9ec 100644 --- a/tests/flytekit/unit/core/test_protobuf.py +++ b/tests/flytekit/unit/core/test_protobuf.py @@ -5,7 +5,7 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow -from flytekit.models.types import LiteralType, SimpleType +from flytekit.models.core.types import LiteralType, SimpleType def test_proto(): diff --git a/tests/flytekit/unit/core/test_schedule.py b/tests/flytekit/unit/core/test_schedule.py index bd76a24405..c4483ad9d3 100644 --- a/tests/flytekit/unit/core/test_schedule.py +++ b/tests/flytekit/unit/core/test_schedule.py @@ -6,7 +6,7 @@ from flytekit.core.schedule import CronSchedule, FixedRate from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.models import schedule as _schedule_models +from flytekit.models.admin import schedule as _schedule_models def test_cron(): diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index fac26994a1..d6b090a11a 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -12,7 +12,7 @@ from flytekit.core.context_manager import Image, ImageConfig, SerializationSettings, get_image_config from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.models.types import SimpleType +from flytekit.models.core.types import SimpleType default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = context_manager.SerializationSettings( diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 74628bb415..3912cae4b3 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -12,6 +12,7 @@ from google.protobuf import struct_pb2 as _struct from marshmallow_jsonschema import JSONSchema +import flytekit.models.core.types from flytekit.common.exceptions import user as user_exceptions from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import ( @@ -23,10 +24,8 @@ convert_json_schema_to_python_class, dataclass_from_dict, ) -from flytekit.models import types as model_types -from flytekit.models.core.types import BlobType -from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar -from flytekit.models.types import LiteralType, SimpleType +from flytekit.models.core.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar +from flytekit.models.core.types import BlobType, LiteralType, SimpleType from flytekit.types.directory.types import FlyteDirectory from flytekit.types.file import JPEGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer @@ -35,18 +34,18 @@ def test_type_engine(): t = int lt = TypeEngine.to_literal_type(t) - assert lt.simple == model_types.SimpleType.INTEGER + assert lt.simple == flytekit.models.core.types.SimpleType.INTEGER t = typing.Dict[str, typing.List[typing.Dict[str, timedelta]]] lt = TypeEngine.to_literal_type(t) - assert lt.map_value_type.collection_type.map_value_type.simple == model_types.SimpleType.DURATION + assert lt.map_value_type.collection_type.map_value_type.simple == flytekit.models.core.types.SimpleType.DURATION def test_named_tuple(): t = typing.NamedTuple("Outputs", [("x_str", str), ("y_int", int)]) var_map = TypeEngine.named_tuple_to_variable_map(t) - assert var_map.variables["x_str"].type.simple == model_types.SimpleType.STRING - assert var_map.variables["y_int"].type.simple == model_types.SimpleType.INTEGER + assert var_map.variables["x_str"].type.simple == flytekit.models.core.types.SimpleType.STRING + assert var_map.variables["y_int"].type.simple == flytekit.models.core.types.SimpleType.INTEGER def test_type_resolution(): @@ -322,43 +321,43 @@ def test_protos(): def test_guessing_basic(): - b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN) + b = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN) pt = TypeEngine.guess_python_type(b) assert pt is bool - lt = model_types.LiteralType(simple=model_types.SimpleType.INTEGER) + lt = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.INTEGER) pt = TypeEngine.guess_python_type(lt) assert pt is int - lt = model_types.LiteralType(simple=model_types.SimpleType.STRING) + lt = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.STRING) pt = TypeEngine.guess_python_type(lt) assert pt is str - lt = model_types.LiteralType(simple=model_types.SimpleType.DURATION) + lt = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.DURATION) pt = TypeEngine.guess_python_type(lt) assert pt is timedelta - lt = model_types.LiteralType(simple=model_types.SimpleType.DATETIME) + lt = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.DATETIME) pt = TypeEngine.guess_python_type(lt) assert pt is datetime.datetime - lt = model_types.LiteralType(simple=model_types.SimpleType.FLOAT) + lt = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.FLOAT) pt = TypeEngine.guess_python_type(lt) assert pt is float - lt = model_types.LiteralType(simple=model_types.SimpleType.NONE) + lt = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.NONE) pt = TypeEngine.guess_python_type(lt) assert pt is None def test_guessing_containers(): - b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN) - lt = model_types.LiteralType(collection_type=b) + b = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN) + lt = flytekit.models.core.types.LiteralType(collection_type=b) pt = TypeEngine.guess_python_type(lt) assert pt == typing.List[bool] - dur = model_types.LiteralType(simple=model_types.SimpleType.DURATION) - lt = model_types.LiteralType(map_value_type=dur) + dur = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.DURATION) + lt = flytekit.models.core.types.LiteralType(map_value_type=dur) pt = TypeEngine.guess_python_type(lt) assert pt == typing.Dict[str, timedelta] diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index c2957ad9e0..dacfa392bd 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -26,11 +26,11 @@ from flytekit.core.testing import patch, task_mock from flytekit.core.type_engine import RestrictedTypeError, TypeEngine from flytekit.core.workflow import workflow -from flytekit.models import literals as _literal_models +from flytekit.models.core import literals as _literal_models from flytekit.models.core import types as _core_types -from flytekit.models.interface import Parameter -from flytekit.models.task import Resources as _resource_models -from flytekit.models.types import LiteralType, SimpleType +from flytekit.models.core.interface import Parameter +from flytekit.models.core.task import Resources as _resource_models +from flytekit.models.core.types import LiteralType, SimpleType from flytekit.types.schema import FlyteSchema, SchemaOpenMode serialization_settings = context_manager.SerializationSettings( diff --git a/tests/flytekit/unit/engines/flyte/test_engine.py b/tests/flytekit/unit/engines/flyte/test_engine.py index ba7d478cd0..69969179af 100644 --- a/tests/flytekit/unit/engines/flyte/test_engine.py +++ b/tests/flytekit/unit/engines/flyte/test_engine.py @@ -8,13 +8,12 @@ from flytekit.common.exceptions import scopes from flytekit.configuration import TemporaryConfiguration from flytekit.engines.flyte import engine -from flytekit.models import common as _common_models -from flytekit.models import execution as _execution_models -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models import literals -from flytekit.models import task as _task_models -from flytekit.models.admin import common as _common -from flytekit.models.core import errors, identifier +from flytekit.models.admin import common as _common_models +from flytekit.models.admin import execution as _execution_models +from flytekit.models.admin import launch_plan as _launch_plan_models +from flytekit.models.admin import task as _task_models +from flytekit.models.admin.common import NamedEntityIdentifier as _namedEntityIdentifier +from flytekit.models.core import errors, identifier, literals from flytekit.sdk import test_utils _INPUT_MAP = literals.LiteralMap( @@ -290,7 +289,7 @@ def test_fetch_active_launch_plan(mock_client_factory): ) assert lp.id == identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "p1", "d1", "n1", "v1") - mock_client.get_active_launch_plan.assert_called_once_with(_common_models.NamedEntityIdentifier("p", "d", "n")) + mock_client.get_active_launch_plan.assert_called_once_with(_namedEntityIdentifier("p", "d", "n")) @patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) @@ -793,7 +792,7 @@ def test_fetch_latest_task(mock_client_factory, tasks): mock_client.list_tasks_paginated = MagicMock(return_value=(tasks, 0)) mock_client_factory.return_value = mock_client - task = engine.FlyteEngineFactory().fetch_latest_task(_common_models.NamedEntityIdentifier("p", "d", "n")) + task = engine.FlyteEngineFactory().fetch_latest_task(_namedEntityIdentifier("p", "d", "n")) if tasks: assert task.id == tasks[0].id @@ -801,7 +800,7 @@ def test_fetch_latest_task(mock_client_factory, tasks): assert not task mock_client.list_tasks_paginated.assert_called_once_with( - _common_models.NamedEntityIdentifier("p", "d", "n"), + _namedEntityIdentifier("p", "d", "n"), limit=1, - sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), + sort_by=_common_models.Sort("created_at", _common_models.Sort.Direction.DESCENDING), ) diff --git a/tests/flytekit/unit/models/admin/test_node_executions.py b/tests/flytekit/unit/models/admin/test_node_executions.py index 84d8785d09..1ccb63f410 100644 --- a/tests/flytekit/unit/models/admin/test_node_executions.py +++ b/tests/flytekit/unit/models/admin/test_node_executions.py @@ -1,4 +1,4 @@ -from flytekit.models import node_execution as node_execution_models +from flytekit.models.admin import node_execution as node_execution_models def test_metadata(): diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index de83f66f78..1b00826d19 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -1,12 +1,12 @@ from datetime import timedelta -from flytekit.models import interface as _interface -from flytekit.models import literals as _literals -from flytekit.models import types as _types +import flytekit.models.core.types from flytekit.models.core import condition as _condition from flytekit.models.core import identifier as _identifier +from flytekit.models.core import interface as _interface +from flytekit.models.core import literals as _literals from flytekit.models.core import workflow as _workflow -from flytekit.models.task import Resources +from flytekit.models.core.task import Resources _generic_id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version") @@ -34,7 +34,7 @@ def test_alias(): def test_workflow_template(): task = _workflow.TaskNode(reference_id=_generic_id) nm = _get_sample_node_metadata() - int_type = _types.LiteralType(_types.SimpleType.INTEGER) + int_type = flytekit.models.core.types.LiteralType(flytekit.models.core.types.SimpleType.INTEGER) wf_metadata = _workflow.WorkflowMetadata() wf_metadata_defaults = _workflow.WorkflowMetadataDefaults() typed_interface = _interface.TypedInterface( diff --git a/tests/flytekit/unit/models/test_common.py b/tests/flytekit/unit/models/test_common.py index 48e8c0ba1f..ed9478f5d2 100644 --- a/tests/flytekit/unit/models/test_common.py +++ b/tests/flytekit/unit/models/test_common.py @@ -1,25 +1,27 @@ -from flytekit.models import common as _common +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan +from flytekit.models.admin import common as _admin_common from flytekit.models.core import execution as _execution def test_notification_email(): - obj = _common.EmailNotification(["a", "b", "c"]) + obj = _admin_common.EmailNotification(["a", "b", "c"]) assert obj.recipients_email == ["a", "b", "c"] - obj2 = _common.EmailNotification.from_flyte_idl(obj.to_flyte_idl()) + obj2 = _admin_common.EmailNotification.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_notification_pagerduty(): - obj = _common.PagerDutyNotification(["a", "b", "c"]) + obj = _admin_common.PagerDutyNotification(["a", "b", "c"]) assert obj.recipients_email == ["a", "b", "c"] - obj2 = _common.PagerDutyNotification.from_flyte_idl(obj.to_flyte_idl()) + obj2 = _admin_common.PagerDutyNotification.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_notification_slack(): - obj = _common.SlackNotification(["a", "b", "c"]) + obj = _admin_common.SlackNotification(["a", "b", "c"]) assert obj.recipients_email == ["a", "b", "c"] - obj2 = _common.SlackNotification.from_flyte_idl(obj.to_flyte_idl()) + obj2 = _admin_common.SlackNotification.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj @@ -30,76 +32,78 @@ def test_notification(): ] recipients = ["a", "b", "c"] - obj = _common.Notification(phases, email=_common.EmailNotification(recipients)) + obj = _admin_common.Notification(phases, email=_admin_common.EmailNotification(recipients)) assert obj.phases == phases assert obj.email.recipients_email == recipients - obj2 = _common.Notification.from_flyte_idl(obj.to_flyte_idl()) + obj2 = _admin_common.Notification.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.phases == phases assert obj2.email.recipients_email == recipients - obj = _common.Notification(phases, pager_duty=_common.PagerDutyNotification(recipients)) + obj = _admin_common.Notification(phases, pager_duty=_admin_common.PagerDutyNotification(recipients)) assert obj.phases == phases assert obj.pager_duty.recipients_email == recipients - obj2 = _common.Notification.from_flyte_idl(obj.to_flyte_idl()) + obj2 = _admin_common.Notification.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.phases == phases assert obj2.pager_duty.recipients_email == recipients - obj = _common.Notification(phases, slack=_common.SlackNotification(recipients)) + obj = _admin_common.Notification(phases, slack=_admin_common.SlackNotification(recipients)) assert obj.phases == phases assert obj.slack.recipients_email == recipients - obj2 = _common.Notification.from_flyte_idl(obj.to_flyte_idl()) + obj2 = _admin_common.Notification.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.phases == phases assert obj2.slack.recipients_email == recipients def test_labels(): - obj = _common.Labels({"my": "label"}) + obj = _admin_common.Labels({"my": "label"}) assert obj.values == {"my": "label"} - obj2 = _common.Labels.from_flyte_idl(obj.to_flyte_idl()) + obj2 = _admin_common.Labels.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_annotations(): - obj = _common.Annotations({"my": "annotation"}) + obj = _admin_common.Annotations({"my": "annotation"}) assert obj.values == {"my": "annotation"} - obj2 = _common.Annotations.from_flyte_idl(obj.to_flyte_idl()) + obj2 = _admin_common.Annotations.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_auth_role(): - obj = _common.AuthRole(assumable_iam_role="rollie-pollie") + obj = flytekit.models.admin.common.AuthRole(assumable_iam_role="rollie-pollie") assert obj.assumable_iam_role == "rollie-pollie" assert not obj.kubernetes_service_account - obj2 = _common.AuthRole.from_flyte_idl(obj.to_flyte_idl()) + obj2 = flytekit.models.admin.common.AuthRole.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 - obj = _common.AuthRole(kubernetes_service_account="service-account-name") + obj = flytekit.models.admin.common.AuthRole(kubernetes_service_account="service-account-name") assert obj.kubernetes_service_account == "service-account-name" assert not obj.assumable_iam_role - obj2 = _common.AuthRole.from_flyte_idl(obj.to_flyte_idl()) + obj2 = flytekit.models.admin.common.AuthRole.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 - obj = _common.AuthRole(assumable_iam_role="rollie-pollie", kubernetes_service_account="service-account-name") + obj = flytekit.models.admin.common.AuthRole( + assumable_iam_role="rollie-pollie", kubernetes_service_account="service-account-name" + ) assert obj.assumable_iam_role == "rollie-pollie" assert obj.kubernetes_service_account == "service-account-name" - obj2 = _common.AuthRole.from_flyte_idl(obj.to_flyte_idl()) + obj2 = flytekit.models.admin.common.AuthRole.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 def test_raw_output_data_config(): - obj = _common.RawOutputDataConfig("s3://bucket") + obj = _admin_common.RawOutputDataConfig("s3://bucket") assert obj.output_location_prefix == "s3://bucket" - obj2 = _common.RawOutputDataConfig.from_flyte_idl(obj.to_flyte_idl()) + obj2 = _admin_common.RawOutputDataConfig.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_auth_role_empty(): # This test is here to ensure we can serialize launch plans with an empty auth role. # Auth roles are empty because they are filled in at registration time. - obj = _common.AuthRole() + obj = flytekit.models.admin.common.AuthRole() x = obj.to_flyte_idl() - y = _common.AuthRole.from_flyte_idl(x) + y = flytekit.models.admin.common.AuthRole.from_flyte_idl(x) assert y == obj diff --git a/tests/flytekit/unit/models/test_dynamic_job.py b/tests/flytekit/unit/models/test_dynamic_job.py index 1aff800abd..efd5184c01 100644 --- a/tests/flytekit/unit/models/test_dynamic_job.py +++ b/tests/flytekit/unit/models/test_dynamic_job.py @@ -4,22 +4,23 @@ import pytest from google.protobuf import text_format -from flytekit.models import array_job as _array_job -from flytekit.models import dynamic_job as _dynamic_job -from flytekit.models import literals as _literals -from flytekit.models import task as _task +import flytekit.models.core.task +from flytekit.models.core import dynamic_job as _dynamic_job from flytekit.models.core import identifier as _identifier +from flytekit.models.core import literals as _literals from flytekit.models.core import workflow as _workflow +from flytekit.models.core.task import TaskTemplate as _taskTemplate +from flytekit.models.plugins import array_job as _array_job from tests.flytekit.common import parameterizers LIST_OF_DYNAMIC_TASKS = [ - _task.TaskTemplate( + _taskTemplate( _identifier.Identifier(_identifier.ResourceType.TASK, "p", "d", "n", "v"), "python", task_metadata, interfaces, _array_job.ArrayJob(2, 2, 2).to_dict(), - container=_task.Container( + container=flytekit.models.core.task.Container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], diff --git a/tests/flytekit/unit/models/test_execution.py b/tests/flytekit/unit/models/test_execution.py index c0fdf5ba2a..91d7c07bdc 100644 --- a/tests/flytekit/unit/models/test_execution.py +++ b/tests/flytekit/unit/models/test_execution.py @@ -1,10 +1,10 @@ import pytest -from flytekit.models import common as _common_models -from flytekit.models import execution as _execution -from flytekit.models import literals as _literals +from flytekit.models.admin import common as _common_models +from flytekit.models.admin import execution as _execution from flytekit.models.core import execution as _core_exec from flytekit.models.core import identifier as _identifier +from flytekit.models.core import literals as _literals from tests.flytekit.common import parameterizers as _parameterizers _INPUT_MAP = _literals.LiteralMap( diff --git a/tests/flytekit/unit/models/test_interface.py b/tests/flytekit/unit/models/test_interface.py index 03f89f69e4..ac026d2d85 100644 --- a/tests/flytekit/unit/models/test_interface.py +++ b/tests/flytekit/unit/models/test_interface.py @@ -1,6 +1,7 @@ import pytest -from flytekit.models import interface, types +import flytekit.models.core.types +from flytekit.models.core import interface from tests.flytekit.common.parameterizers import LIST_OF_ALL_LITERAL_TYPES @@ -43,7 +44,9 @@ def test_typed_interface(literal_type): def test_parameter(): - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") + v = interface.Variable( + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN), "asdf asdf asdf" + ) obj = interface.Parameter(var=v) assert obj.var == v @@ -53,7 +56,9 @@ def test_parameter(): def test_parameter_map(): - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") + v = interface.Variable( + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN), "asdf asdf asdf" + ) p = interface.Parameter(var=v) obj = interface.ParameterMap({"ppp": p}) @@ -62,7 +67,9 @@ def test_parameter_map(): def test_variable_map(): - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") + v = interface.Variable( + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN), "asdf asdf asdf" + ) obj = interface.VariableMap({"vvv": v}) obj2 = interface.VariableMap.from_flyte_idl(obj.to_flyte_idl()) diff --git a/tests/flytekit/unit/models/test_launch_plan.py b/tests/flytekit/unit/models/test_launch_plan.py index 6c0b8b2f13..5e81bfdd4f 100644 --- a/tests/flytekit/unit/models/test_launch_plan.py +++ b/tests/flytekit/unit/models/test_launch_plan.py @@ -1,7 +1,11 @@ from flyteidl.admin import launch_plan_pb2 as _launch_plan_idl -from flytekit.models import common, interface, launch_plan, literals, schedule, types -from flytekit.models.core import identifier +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan +import flytekit.models.core.types +from flytekit.models.admin import common as _common +from flytekit.models.admin import launch_plan, schedule +from flytekit.models.core import identifier, interface, literals def test_metadata(): @@ -20,7 +24,9 @@ def test_metadata_schedule(): def test_lp_closure(): - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") + v = interface.Variable( + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN), "asdf asdf asdf" + ) p = interface.Parameter(var=v) parameter_map = interface.ParameterMap({"ppp": p}) parameter_map.to_flyte_idl() @@ -45,7 +51,9 @@ def test_launch_plan_spec(): s = schedule.Schedule("asdf", "1 3 4 5 6 7") launch_plan_metadata_model = launch_plan.LaunchPlanMetadata(schedule=s, notifications=[]) - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") + v = interface.Variable( + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN), "asdf asdf asdf" + ) p = interface.Parameter(var=v) parameter_map = interface.ParameterMap({"ppp": p}) @@ -53,12 +61,12 @@ def test_launch_plan_spec(): {"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))} ) - labels_model = common.Labels({}) - annotations_model = common.Annotations({"my": "annotation"}) + labels_model = _common.Labels({}) + annotations_model = _common.Annotations({"my": "annotation"}) - auth_role_model = common.AuthRole(assumable_iam_role="my:iam:role") - raw_data_output_config = common.RawOutputDataConfig("s3://bucket") - empty_raw_data_output_config = common.RawOutputDataConfig("") + auth_role_model = flytekit.models.admin.common.AuthRole(assumable_iam_role="my:iam:role") + raw_data_output_config = _common.RawOutputDataConfig("s3://bucket") + empty_raw_data_output_config = _common.RawOutputDataConfig("") max_parallelism = 100 lp_spec_raw_output_prefixed = launch_plan.LaunchPlanSpec( @@ -98,7 +106,9 @@ def test_old_style_role(): s = schedule.Schedule("asdf", "1 3 4 5 6 7") launch_plan_metadata_model = launch_plan.LaunchPlanMetadata(schedule=s, notifications=[]) - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") + v = interface.Variable( + flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.BOOLEAN), "asdf asdf asdf" + ) p = interface.Parameter(var=v) parameter_map = interface.ParameterMap({"ppp": p}) @@ -106,10 +116,10 @@ def test_old_style_role(): {"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))} ) - labels_model = common.Labels({}) - annotations_model = common.Annotations({"my": "annotation"}) + labels_model = _common.Labels({}) + annotations_model = _common.Annotations({"my": "annotation"}) - raw_data_output_config = common.RawOutputDataConfig("s3://bucket") + raw_data_output_config = _common.RawOutputDataConfig("s3://bucket") old_role = _launch_plan_idl.Auth(kubernetes_service_account="my:service:account") diff --git a/tests/flytekit/unit/models/test_literals.py b/tests/flytekit/unit/models/test_literals.py index b6eec52cd7..6decac0ac9 100644 --- a/tests/flytekit/unit/models/test_literals.py +++ b/tests/flytekit/unit/models/test_literals.py @@ -3,8 +3,8 @@ import pytest import pytz -from flytekit.models import literals -from flytekit.models import types as _types +import flytekit.models.core.types +from flytekit.models.core import literals from tests.flytekit.common import parameterizers @@ -324,14 +324,26 @@ def test_scalar_binary(): def test_scalar_schema(): - schema_type = _types.SchemaType( + schema_type = flytekit.models.core.types.SchemaType( [ - _types.SchemaType.SchemaColumn("a", _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - _types.SchemaType.SchemaColumn("b", _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _types.SchemaType.SchemaColumn("c", _types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - _types.SchemaType.SchemaColumn("d", _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - _types.SchemaType.SchemaColumn("e", _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - _types.SchemaType.SchemaColumn("f", _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + flytekit.models.core.types.SchemaType.SchemaColumn( + "a", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "b", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "c", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "d", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "e", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "f", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + ), ] ) @@ -398,7 +410,7 @@ def test_binding_data_map(): def test_binding_data_promise(): - obj = literals.BindingData(promise=_types.OutputReference("some_node", "myvar")) + obj = literals.BindingData(promise=flytekit.models.core.types.OutputReference("some_node", "myvar")) assert obj.scalar is None assert obj.promise is not None assert obj.collection is None diff --git a/tests/flytekit/unit/models/test_matchable_resource.py b/tests/flytekit/unit/models/test_matchable_resource.py index 41aa1bf793..ce33fc2ae1 100644 --- a/tests/flytekit/unit/models/test_matchable_resource.py +++ b/tests/flytekit/unit/models/test_matchable_resource.py @@ -1,4 +1,4 @@ -from flytekit.models import matchable_resource +from flytekit.models.admin import matchable_resource def test_cluster_resource_attributes(): diff --git a/tests/flytekit/unit/models/test_named_entity.py b/tests/flytekit/unit/models/test_named_entity.py index bb3db95439..bd56ddca79 100644 --- a/tests/flytekit/unit/models/test_named_entity.py +++ b/tests/flytekit/unit/models/test_named_entity.py @@ -1,13 +1,15 @@ -from flytekit.models import named_entity +import flytekit.models.admin.common def test_identifier(): - obj = named_entity.NamedEntityIdentifier("proj", "development", "MyWorkflow") - obj2 = named_entity.NamedEntityIdentifier.from_flyte_idl(obj.to_flyte_idl()) + obj = flytekit.models.admin.common.NamedEntityIdentifier("proj", "development", "MyWorkflow") + obj2 = flytekit.models.admin.common.NamedEntityIdentifier.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 def test_metadata(): - obj = named_entity.NamedEntityMetadata("i am a description", named_entity.NamedEntityState.ACTIVE) - obj2 = named_entity.NamedEntityMetadata.from_flyte_idl(obj.to_flyte_idl()) + obj = flytekit.models.admin.common.NamedEntityMetadata( + "i am a description", flytekit.models.admin.common.NamedEntityState.ACTIVE + ) + obj2 = flytekit.models.admin.common.NamedEntityMetadata.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 diff --git a/tests/flytekit/unit/models/test_project.py b/tests/flytekit/unit/models/test_project.py index 74006ad9b2..41ac84703c 100644 --- a/tests/flytekit/unit/models/test_project.py +++ b/tests/flytekit/unit/models/test_project.py @@ -1,6 +1,6 @@ import pytest as _pytest -from flytekit.models import project +from flytekit.models.admin import project def test_project_with_default_state(): diff --git a/tests/flytekit/unit/models/test_qubole.py b/tests/flytekit/unit/models/test_qubole.py index 34655379f8..ca795ef4ec 100644 --- a/tests/flytekit/unit/models/test_qubole.py +++ b/tests/flytekit/unit/models/test_qubole.py @@ -1,4 +1,4 @@ -from flytekit.models import qubole +from flytekit.models.plugins import qubole def test_hive_query(): diff --git a/tests/flytekit/unit/models/test_schedule.py b/tests/flytekit/unit/models/test_schedule.py index b7fad79124..71044b481c 100644 --- a/tests/flytekit/unit/models/test_schedule.py +++ b/tests/flytekit/unit/models/test_schedule.py @@ -1,6 +1,6 @@ import pytest as _pytest -from flytekit.models import schedule as _schedule +from flytekit.models.admin import schedule as _schedule def test_schedule_cron_expression(): diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index f3ae95fb3a..e29b0e2adc 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -6,65 +6,76 @@ from google.protobuf import text_format from k8s.io.api.core.v1 import generated_pb2 -import flytekit.models.interface as interface_models -import flytekit.models.literals as literal_models -from flytekit.models import literals, task, types -from flytekit.models.core import identifier +import flytekit.models.core.interface as interface_models +import flytekit.models.core.literals as literal_models +import flytekit.models.core.task +import flytekit.models.core.types +from flytekit.models.admin.task import Task as _task +from flytekit.models.core import identifier, literals +from flytekit.models.core.task import RuntimeMetadata as _runtimeMetadata +from flytekit.models.core.task import TaskMetadata as _taskMetadata +from flytekit.models.core.task import TaskTemplate as _taskTemplate from tests.flytekit.common import parameterizers def test_resource_entry(): - obj = task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "blah") - assert task.Resources.ResourceEntry.from_flyte_idl(obj.to_flyte_idl()) == obj - assert obj != task.Resources.ResourceEntry(task.Resources.ResourceName.GPU, "blah") - assert obj != task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "bloop") - assert obj.name == task.Resources.ResourceName.CPU + obj = flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.CPU, "blah" + ) + assert flytekit.models.core.task.Resources.ResourceEntry.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj != flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.GPU, "blah" + ) + assert obj != flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.CPU, "bloop" + ) + assert obj.name == flytekit.models.core.task.Resources.ResourceName.CPU assert obj.value == "blah" @pytest.mark.parametrize("resource_list", parameterizers.LIST_OF_RESOURCE_ENTRY_LISTS) def test_resources(resource_list): - obj = task.Resources(resource_list, resource_list) - obj1 = task.Resources([], resource_list) - obj2 = task.Resources(resource_list, []) + obj = flytekit.models.core.task.Resources(resource_list, resource_list) + obj1 = flytekit.models.core.task.Resources([], resource_list) + obj2 = flytekit.models.core.task.Resources(resource_list, []) assert obj.requests == obj2.requests assert obj.limits == obj1.limits - assert obj == task.Resources.from_flyte_idl(obj.to_flyte_idl()) + assert obj == flytekit.models.core.task.Resources.from_flyte_idl(obj.to_flyte_idl()) def test_runtime_metadata(): - obj = task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python") - assert obj.type == task.RuntimeMetadata.RuntimeType.FLYTE_SDK + obj = _runtimeMetadata(_runtimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python") + assert obj.type == _runtimeMetadata.RuntimeType.FLYTE_SDK assert obj.version == "1.0.0" assert obj.flavor == "python" - assert obj == task.RuntimeMetadata.from_flyte_idl(obj.to_flyte_idl()) - assert obj != task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.1", "python") - assert obj != task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.OTHER, "1.0.0", "python") - assert obj != task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "golang") + assert obj == _runtimeMetadata.from_flyte_idl(obj.to_flyte_idl()) + assert obj != _runtimeMetadata(_runtimeMetadata.RuntimeType.FLYTE_SDK, "1.0.1", "python") + assert obj != _runtimeMetadata(_runtimeMetadata.RuntimeType.OTHER, "1.0.0", "python") + assert obj != _runtimeMetadata(_runtimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "golang") def test_task_metadata_interruptible_from_flyte_idl(): # Interruptible not set idl = TaskMetadata() - obj = task.TaskMetadata.from_flyte_idl(idl) + obj = _taskMetadata.from_flyte_idl(idl) assert obj.interruptible is None idl = TaskMetadata() idl.interruptible = True - obj = task.TaskMetadata.from_flyte_idl(idl) + obj = _taskMetadata.from_flyte_idl(idl) assert obj.interruptible is True idl = TaskMetadata() idl.interruptible = False - obj = task.TaskMetadata.from_flyte_idl(idl) + obj = _taskMetadata.from_flyte_idl(idl) assert obj.interruptible is False def test_task_metadata(): - obj = task.TaskMetadata( + obj = _taskMetadata( True, - task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + _runtimeMetadata(_runtimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), timedelta(days=1), literals.RetryStrategy(3), True, @@ -77,11 +88,11 @@ def test_task_metadata(): assert obj.interruptible is True assert obj.timeout == timedelta(days=1) assert obj.runtime.flavor == "python" - assert obj.runtime.type == task.RuntimeMetadata.RuntimeType.FLYTE_SDK + assert obj.runtime.type == _runtimeMetadata.RuntimeType.FLYTE_SDK assert obj.runtime.version == "1.0.0" assert obj.deprecated_error_message == "This is deprecated!" assert obj.discovery_version == "0.1.1b0" - assert obj == task.TaskMetadata.from_flyte_idl(obj.to_flyte_idl()) + assert obj == _taskMetadata.from_flyte_idl(obj.to_flyte_idl()) @pytest.mark.parametrize( @@ -90,13 +101,13 @@ def test_task_metadata(): ) def test_task_template(in_tuple): task_metadata, interfaces, resources = in_tuple - obj = task.TaskTemplate( + obj = _taskTemplate( identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), "python", task_metadata, interfaces, {"a": 1, "b": {"c": 2, "d": 3}}, - container=task.Container( + container=flytekit.models.core.task.Container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], @@ -118,19 +129,19 @@ def test_task_template(in_tuple): assert obj.container.image == "my_image" assert obj.container.resources == resources assert text_format.MessageToString(obj.to_flyte_idl()) == text_format.MessageToString( - task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl() + _taskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl() ) assert obj.config == {"a": "b"} def test_task_template__k8s_pod_target(): - int_type = types.LiteralType(types.SimpleType.INTEGER) - obj = task.TaskTemplate( + int_type = flytekit.models.core.types.LiteralType(flytekit.models.core.types.SimpleType.INTEGER) + obj = _taskTemplate( identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), "python", - task.TaskMetadata( + _taskMetadata( False, - task.RuntimeMetadata(1, "v", "f"), + _runtimeMetadata(1, "v", "f"), timedelta(days=1), literal_models.RetryStrategy(5), False, @@ -148,8 +159,8 @@ def test_task_template__k8s_pod_target(): ), {"a": 1, "b": {"c": 2, "d": 3}}, config={"a": "b"}, - k8s_pod=task.K8sPod( - metadata=task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}), + k8s_pod=flytekit.models.core.task.K8sPod( + metadata=flytekit.models.core.task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}), pod_spec={"str": "val", "int": 1}, ), ) @@ -160,23 +171,25 @@ def test_task_template__k8s_pod_target(): assert obj.id.version == "version" assert obj.type == "python" assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}} - assert obj.k8s_pod.metadata == task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}) + assert obj.k8s_pod.metadata == flytekit.models.core.task.K8sObjectMetadata( + labels={"label": "foo"}, annotations={"anno": "bar"} + ) assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1} assert text_format.MessageToString(obj.to_flyte_idl()) == text_format.MessageToString( - task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl() + _taskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl() ) assert obj.config == {"a": "b"} @pytest.mark.parametrize("sec_ctx", parameterizers.LIST_OF_SECURITY_CONTEXT) def test_task_template_security_context(sec_ctx): - obj = task.TaskTemplate( + obj = _taskTemplate( identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), "python", parameterizers.LIST_OF_TASK_METADATA[0], parameterizers.LIST_OF_INTERFACES[0], {"a": 1, "b": {"c": 2, "d": 3}}, - container=task.Container( + container=flytekit.models.core.task.Container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], @@ -188,13 +201,13 @@ def test_task_template_security_context(sec_ctx): ) assert obj.security_context == sec_ctx assert text_format.MessageToString(obj.to_flyte_idl()) == text_format.MessageToString( - task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl() + _taskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl() ) @pytest.mark.parametrize("task_closure", parameterizers.LIST_OF_TASK_CLOSURES) def test_task(task_closure): - obj = task.Task( + obj = _task( identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), task_closure, ) @@ -203,12 +216,12 @@ def test_task(task_closure): assert obj.id.name == "name" assert obj.id.version == "version" assert obj.closure == task_closure - assert obj == task.Task.from_flyte_idl(obj.to_flyte_idl()) + assert obj == _task.from_flyte_idl(obj.to_flyte_idl()) @pytest.mark.parametrize("resources", parameterizers.LIST_OF_RESOURCES) def test_container(resources): - obj = task.Container( + obj = flytekit.models.core.task.Container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], @@ -222,14 +235,14 @@ def test_container(resources): obj.resources == resources obj.env == {"a": "b"} obj.config == {"d": "e"} - assert obj == task.Container.from_flyte_idl(obj.to_flyte_idl()) + assert obj == flytekit.models.core.task.Container.from_flyte_idl(obj.to_flyte_idl()) def test_sidecar_task(): pod_spec = generated_pb2.PodSpec() container = generated_pb2.Container(name="containery") pod_spec.containers.extend([container]) - obj = task.SidecarJob( + obj = flytekit.models.core.task.SidecarJob( pod_spec=pod_spec, primary_container_name="primary", annotations={"a1": "a1"}, @@ -241,55 +254,60 @@ def test_sidecar_task(): assert obj.annotations["a1"] == "a1" assert obj.labels["b1"] == "b1" - obj2 = task.SidecarJob.from_flyte_idl(obj.to_flyte_idl()) + obj2 = flytekit.models.core.task.SidecarJob.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_sidecar_task_label_annotation_not_provided(): pod_spec = generated_pb2.PodSpec() - obj = task.SidecarJob(pod_spec=pod_spec, primary_container_name="primary") + obj = flytekit.models.core.task.SidecarJob(pod_spec=pod_spec, primary_container_name="primary") assert obj.primary_container_name == "primary" - obj2 = task.SidecarJob.from_flyte_idl(obj.to_flyte_idl()) + obj2 = flytekit.models.core.task.SidecarJob.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_dataloadingconfig(): - dlc = task.DataLoadingConfig( + dlc = flytekit.models.core.task.DataLoadingConfig( "s3://input/path", "s3://output/path", True, - task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, + flytekit.models.core.task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, ) - dlc2 = task.DataLoadingConfig.from_flyte_idl(dlc.to_flyte_idl()) + dlc2 = flytekit.models.core.task.DataLoadingConfig.from_flyte_idl(dlc.to_flyte_idl()) assert dlc2 == dlc - dlc = task.DataLoadingConfig( + dlc = flytekit.models.core.task.DataLoadingConfig( "s3://input/path", "s3://output/path", True, - task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, - io_strategy=task.IOStrategy(), + flytekit.models.core.task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, + io_strategy=flytekit.models.core.task.IOStrategy(), ) - dlc2 = task.DataLoadingConfig.from_flyte_idl(dlc.to_flyte_idl()) + dlc2 = flytekit.models.core.task.DataLoadingConfig.from_flyte_idl(dlc.to_flyte_idl()) assert dlc2 == dlc def test_ioconfig(): - io = task.IOStrategy(task.IOStrategy.DOWNLOAD_MODE_NO_DOWNLOAD, task.IOStrategy.UPLOAD_MODE_NO_UPLOAD) - assert io == task.IOStrategy.from_flyte_idl(io.to_flyte_idl()) + io = flytekit.models.core.task.IOStrategy( + flytekit.models.core.task.IOStrategy.DOWNLOAD_MODE_NO_DOWNLOAD, + flytekit.models.core.task.IOStrategy.UPLOAD_MODE_NO_UPLOAD, + ) + assert io == flytekit.models.core.task.IOStrategy.from_flyte_idl(io.to_flyte_idl()) def test_k8s_metadata(): - obj = task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}) + obj = flytekit.models.core.task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}) assert obj.labels == {"label": "foo"} assert obj.annotations == {"anno": "bar"} - assert obj == task.K8sObjectMetadata.from_flyte_idl(obj.to_flyte_idl()) + assert obj == flytekit.models.core.task.K8sObjectMetadata.from_flyte_idl(obj.to_flyte_idl()) def test_k8s_pod(): - obj = task.K8sPod(metadata=task.K8sObjectMetadata(labels={"label": "foo"}), pod_spec={"pod_spec": "bar"}) + obj = flytekit.models.core.task.K8sPod( + metadata=flytekit.models.core.task.K8sObjectMetadata(labels={"label": "foo"}), pod_spec={"pod_spec": "bar"} + ) assert obj.metadata.labels == {"label": "foo"} assert obj.pod_spec == {"pod_spec": "bar"} - assert obj == task.K8sPod.from_flyte_idl(obj.to_flyte_idl()) + assert obj == flytekit.models.core.task.K8sPod.from_flyte_idl(obj.to_flyte_idl()) diff --git a/tests/flytekit/unit/models/test_types.py b/tests/flytekit/unit/models/test_types.py index 27b5ddf595..42a66028d1 100644 --- a/tests/flytekit/unit/models/test_types.py +++ b/tests/flytekit/unit/models/test_types.py @@ -1,40 +1,70 @@ import pytest from flyteidl.core import types_pb2 -from flytekit.models import types as _types +import flytekit.models.core.types from tests.flytekit.common import parameterizers def test_simple_type(): - assert _types.SimpleType.NONE == types_pb2.NONE - assert _types.SimpleType.INTEGER == types_pb2.INTEGER - assert _types.SimpleType.FLOAT == types_pb2.FLOAT - assert _types.SimpleType.STRING == types_pb2.STRING - assert _types.SimpleType.BOOLEAN == types_pb2.BOOLEAN - assert _types.SimpleType.DATETIME == types_pb2.DATETIME - assert _types.SimpleType.DURATION == types_pb2.DURATION - assert _types.SimpleType.BINARY == types_pb2.BINARY - assert _types.SimpleType.ERROR == types_pb2.ERROR + assert flytekit.models.core.types.SimpleType.NONE == types_pb2.NONE + assert flytekit.models.core.types.SimpleType.INTEGER == types_pb2.INTEGER + assert flytekit.models.core.types.SimpleType.FLOAT == types_pb2.FLOAT + assert flytekit.models.core.types.SimpleType.STRING == types_pb2.STRING + assert flytekit.models.core.types.SimpleType.BOOLEAN == types_pb2.BOOLEAN + assert flytekit.models.core.types.SimpleType.DATETIME == types_pb2.DATETIME + assert flytekit.models.core.types.SimpleType.DURATION == types_pb2.DURATION + assert flytekit.models.core.types.SimpleType.BINARY == types_pb2.BINARY + assert flytekit.models.core.types.SimpleType.ERROR == types_pb2.ERROR def test_schema_column(): - assert _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER == types_pb2.SchemaType.SchemaColumn.INTEGER - assert _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT == types_pb2.SchemaType.SchemaColumn.FLOAT - assert _types.SchemaType.SchemaColumn.SchemaColumnType.STRING == types_pb2.SchemaType.SchemaColumn.STRING - assert _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME == types_pb2.SchemaType.SchemaColumn.DATETIME - assert _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION == types_pb2.SchemaType.SchemaColumn.DURATION - assert _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN == types_pb2.SchemaType.SchemaColumn.BOOLEAN + assert ( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + == types_pb2.SchemaType.SchemaColumn.INTEGER + ) + assert ( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + == types_pb2.SchemaType.SchemaColumn.FLOAT + ) + assert ( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING + == types_pb2.SchemaType.SchemaColumn.STRING + ) + assert ( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + == types_pb2.SchemaType.SchemaColumn.DATETIME + ) + assert ( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION + == types_pb2.SchemaType.SchemaColumn.DURATION + ) + assert ( + flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + == types_pb2.SchemaType.SchemaColumn.BOOLEAN + ) def test_schema_type(): - obj = _types.SchemaType( + obj = flytekit.models.core.types.SchemaType( [ - _types.SchemaType.SchemaColumn("a", _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - _types.SchemaType.SchemaColumn("b", _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _types.SchemaType.SchemaColumn("c", _types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - _types.SchemaType.SchemaColumn("d", _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - _types.SchemaType.SchemaColumn("e", _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - _types.SchemaType.SchemaColumn("f", _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + flytekit.models.core.types.SchemaType.SchemaColumn( + "a", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "b", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "c", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "d", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "e", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "f", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + ), ] ) @@ -45,55 +75,67 @@ def test_schema_type(): assert obj.columns[4].name == "e" assert obj.columns[5].name == "f" - assert obj.columns[0].type == _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER - assert obj.columns[1].type == _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT - assert obj.columns[2].type == _types.SchemaType.SchemaColumn.SchemaColumnType.STRING - assert obj.columns[3].type == _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME - assert obj.columns[4].type == _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION - assert obj.columns[5].type == _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + assert obj.columns[0].type == flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + assert obj.columns[1].type == flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + assert obj.columns[2].type == flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING + assert obj.columns[3].type == flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + assert obj.columns[4].type == flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION + assert obj.columns[5].type == flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN - assert obj == _types.SchemaType.from_flyte_idl(obj.to_flyte_idl()) + assert obj == flytekit.models.core.types.SchemaType.from_flyte_idl(obj.to_flyte_idl()) def test_literal_types(): - obj = _types.LiteralType(simple=_types.SimpleType.INTEGER) - assert obj.simple == _types.SimpleType.INTEGER + obj = flytekit.models.core.types.LiteralType(simple=flytekit.models.core.types.SimpleType.INTEGER) + assert obj.simple == flytekit.models.core.types.SimpleType.INTEGER assert obj.schema is None assert obj.collection_type is None assert obj.map_value_type is None - assert obj == _types.LiteralType.from_flyte_idl(obj.to_flyte_idl()) + assert obj == flytekit.models.core.types.LiteralType.from_flyte_idl(obj.to_flyte_idl()) - schema_type = _types.SchemaType( + schema_type = flytekit.models.core.types.SchemaType( [ - _types.SchemaType.SchemaColumn("a", _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - _types.SchemaType.SchemaColumn("b", _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _types.SchemaType.SchemaColumn("c", _types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - _types.SchemaType.SchemaColumn("d", _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - _types.SchemaType.SchemaColumn("e", _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - _types.SchemaType.SchemaColumn("f", _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + flytekit.models.core.types.SchemaType.SchemaColumn( + "a", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "b", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "c", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.STRING + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "d", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "e", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.DURATION + ), + flytekit.models.core.types.SchemaType.SchemaColumn( + "f", flytekit.models.core.types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN + ), ] ) - obj = _types.LiteralType(schema=schema_type) + obj = flytekit.models.core.types.LiteralType(schema=schema_type) assert obj.simple is None assert obj.schema == schema_type assert obj.collection_type is None assert obj.map_value_type is None - assert obj == _types.LiteralType.from_flyte_idl(obj.to_flyte_idl()) + assert obj == flytekit.models.core.types.LiteralType.from_flyte_idl(obj.to_flyte_idl()) @pytest.mark.parametrize("literal_type", parameterizers.LIST_OF_ALL_LITERAL_TYPES) def test_literal_collections(literal_type): - obj = _types.LiteralType(collection_type=literal_type) + obj = flytekit.models.core.types.LiteralType(collection_type=literal_type) assert obj.collection_type == literal_type assert obj.simple is None assert obj.schema is None assert obj.map_value_type is None - assert obj == _types.LiteralType.from_flyte_idl(obj.to_flyte_idl()) + assert obj == flytekit.models.core.types.LiteralType.from_flyte_idl(obj.to_flyte_idl()) def test_output_reference(): - obj = _types.OutputReference(node_id="node1", var="var1") + obj = flytekit.models.core.types.OutputReference(node_id="node1", var="var1") assert obj.node_id == "node1" assert obj.var == "var1" - obj2 = _types.OutputReference.from_flyte_idl(obj.to_flyte_idl()) + obj2 = flytekit.models.core.types.OutputReference.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py index 3e19a80657..da65a4df46 100644 --- a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -1,16 +1,19 @@ from datetime import timedelta -from flytekit.models import interface as _interface -from flytekit.models import literals as _literals -from flytekit.models import task as _task -from flytekit.models import types as _types -from flytekit.models import workflow_closure as _workflow_closure +import flytekit.models.core.task +import flytekit.models.core.types from flytekit.models.core import identifier as _identifier +from flytekit.models.core import interface as _interface +from flytekit.models.core import literals as _literals from flytekit.models.core import workflow as _workflow +from flytekit.models.core import workflow_closure as _workflow_closure +from flytekit.models.core.task import RuntimeMetadata as _runtimeMetadata +from flytekit.models.core.task import TaskMetadata as _taskMetadata +from flytekit.models.core.task import TaskTemplate as _taskTemplate def test_workflow_closure(): - int_type = _types.LiteralType(_types.SimpleType.INTEGER) + int_type = flytekit.models.core.types.LiteralType(flytekit.models.core.types.SimpleType.INTEGER) typed_interface = _interface.TypedInterface( {"a": _interface.Variable(int_type, "description1")}, {"b": _interface.Variable(int_type, "description2"), "c": _interface.Variable(int_type, "description3")}, @@ -20,16 +23,20 @@ def test_workflow_closure(): "a", _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))), ) - b1 = _literals.Binding("b", _literals.BindingData(promise=_types.OutputReference("my_node", "b"))) - b2 = _literals.Binding("c", _literals.BindingData(promise=_types.OutputReference("my_node", "c"))) + b1 = _literals.Binding( + "b", _literals.BindingData(promise=flytekit.models.core.types.OutputReference("my_node", "b")) + ) + b2 = _literals.Binding( + "c", _literals.BindingData(promise=flytekit.models.core.types.OutputReference("my_node", "c")) + ) node_metadata = _workflow.NodeMetadata( name="node1", timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0) ) - task_metadata = _task.TaskMetadata( + task_metadata = _taskMetadata( True, - _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + _runtimeMetadata(_runtimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), timedelta(days=1), _literals.RetryStrategy(3), True, @@ -37,16 +44,18 @@ def test_workflow_closure(): "This is deprecated!", ) - cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1") - resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource]) + cpu_resource = flytekit.models.core.task.Resources.ResourceEntry( + flytekit.models.core.task.Resources.ResourceName.CPU, "1" + ) + resources = flytekit.models.core.task.Resources(requests=[cpu_resource], limits=[cpu_resource]) - task = _task.TaskTemplate( + task = _taskTemplate( _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version"), "python", task_metadata, typed_interface, {"a": 1, "b": {"c": 2, "d": 3}}, - container=_task.Container( + container=flytekit.models.core.task.Container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 873f67d752..2eba580f25 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -3,9 +3,15 @@ import pytest from mock import MagicMock, patch +import flytekit.models.admin.common +import flytekit.models.admin.launch_plan from flytekit.common.exceptions import user as user_exceptions from flytekit.configuration import internal -from flytekit.models import common as common_models +from flytekit.models.admin import common as _common +from flytekit.models.admin.execution import Execution +from flytekit.models.admin.launch_plan import LaunchPlan +from flytekit.models.admin.node_execution import NodeExecution, NodeExecutionMetaData +from flytekit.models.admin.task import Task from flytekit.models.admin.workflow import Workflow from flytekit.models.core.identifier import ( Identifier, @@ -13,12 +19,8 @@ ResourceType, WorkflowExecutionIdentifier, ) -from flytekit.models.execution import Execution -from flytekit.models.interface import TypedInterface, Variable -from flytekit.models.launch_plan import LaunchPlan -from flytekit.models.node_execution import NodeExecution, NodeExecutionMetaData -from flytekit.models.task import Task -from flytekit.models.types import LiteralType, SimpleType +from flytekit.models.core.interface import TypedInterface, Variable +from flytekit.models.core.types import LiteralType, SimpleType from flytekit.remote import FlyteWorkflow from flytekit.remote.remote import FlyteRemote @@ -153,8 +155,8 @@ def test_underscore_execute_uses_launch_plan_attributes(mock_insecure, mock_url, def local_assertions(*args, **kwargs): execution_spec = args[3] assert execution_spec.auth_role.kubernetes_service_account == "svc" - assert execution_spec.labels == common_models.Labels({"a": "my_label_value"}) - assert execution_spec.annotations == common_models.Annotations({"b": "my_annotation_value"}) + assert execution_spec.labels == _common.Labels({"a": "my_label_value"}) + assert execution_spec.annotations == _common.Annotations({"b": "my_annotation_value"}) mock_client.create_execution.side_effect = local_assertions @@ -165,9 +167,9 @@ def local_assertions(*args, **kwargs): inputs={}, project="proj", domain="dev", - labels=common_models.Labels({"a": "my_label_value"}), - annotations=common_models.Annotations({"b": "my_annotation_value"}), - auth_role=common_models.AuthRole(kubernetes_service_account="svc"), + labels=_common.Labels({"a": "my_label_value"}), + annotations=_common.Annotations({"b": "my_annotation_value"}), + auth_role=flytekit.models.admin.common.AuthRole(kubernetes_service_account="svc"), ) diff --git a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py b/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py index e81edec5de..adc44b68f0 100644 --- a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py @@ -13,7 +13,7 @@ from flytekit.common.types import schema as _schema from flytekit.common.types.impl.schema import Schema from flytekit.engines import common as _common_engine -from flytekit.models import literals as _literals +from flytekit.models.core import literals as _literals from flytekit.models.core.identifier import WorkflowExecutionIdentifier from flytekit.sdk.tasks import hive_task, inputs, outputs, qubole_hive_task from flytekit.sdk.types import Types diff --git a/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py b/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py index 8472cde292..1c9b5c62e9 100644 --- a/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py +++ b/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py @@ -1,9 +1,9 @@ import datetime as _datetime +import flytekit.models.core.types from flytekit.common import constants as _common_constants from flytekit.common.tasks import pytorch_task as _pytorch_task from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import types as _type_models from flytekit.models.core import identifier as _identifier from flytekit.sdk.tasks import inputs, outputs, pytorch_task from flytekit.sdk.types import Types @@ -23,12 +23,12 @@ def test_simple_pytorch_task(): assert isinstance(simple_pytorch_task, _pytorch_task.SdkPyTorchTask) assert isinstance(simple_pytorch_task, _sdk_runnable.SdkRunnableTask) assert simple_pytorch_task.interface.inputs["in1"].description == "" - assert simple_pytorch_task.interface.inputs["in1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.INTEGER + assert simple_pytorch_task.interface.inputs["in1"].type == flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.INTEGER ) assert simple_pytorch_task.interface.outputs["out1"].description == "" - assert simple_pytorch_task.interface.outputs["out1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.STRING + assert simple_pytorch_task.interface.outputs["out1"].type == flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.STRING ) assert simple_pytorch_task.type == _common_constants.SdkTaskType.PYTORCH_TASK assert simple_pytorch_task.task_function_name == "simple_pytorch_task" diff --git a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py index 19b4b8e2a1..e411e73851 100644 --- a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py @@ -7,6 +7,7 @@ from flyteidl.plugins.sagemaker.training_job_pb2 import TrainingJobResourceConfig as _pb2_TrainingJobResourceConfig from google.protobuf.json_format import ParseDict +import flytekit.models.core.types from flytekit.common import constants as _common_constants from flytekit.common import utils as _utils from flytekit.common.core.identifier import WorkflowExecutionIdentifier @@ -22,9 +23,8 @@ from flytekit.common.types import helpers as _type_helpers from flytekit.engines import common as _common_engine from flytekit.engines.unit.mock_stats import MockStats -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types from flytekit.models.core import identifier as _identifier +from flytekit.models.core import literals as _literals from flytekit.models.core import types as _core_types from flytekit.models.sagemaker.hpo_job import HyperparameterTuningJobConfig as _HyperparameterTuningJobConfig from flytekit.models.sagemaker.hpo_job import ( @@ -101,7 +101,7 @@ def test_builtin_algorithm_training_job_task(): assert isinstance(builtin_algorithm_training_job_task, SdkBuiltinAlgorithmTrainingJobTask) assert isinstance(builtin_algorithm_training_job_task, _sdk_task.SdkTask) assert builtin_algorithm_training_job_task.interface.inputs["train"].description == "" - assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( + assert builtin_algorithm_training_job_task.interface.inputs["train"].type == flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, @@ -116,7 +116,7 @@ def test_builtin_algorithm_training_job_task(): builtin_algorithm_training_job_task.interface.inputs["validation"].type == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() ) - assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( + assert builtin_algorithm_training_job_task.interface.inputs["train"].type == flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, @@ -185,7 +185,7 @@ def test_simple_hpo_job_task(): simple_xgboost_hpo_job_task.interface.inputs["train"].type == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() ) - assert simple_xgboost_hpo_job_task.interface.inputs["train"].type == _idl_types.LiteralType( + assert simple_xgboost_hpo_job_task.interface.inputs["train"].type == flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, @@ -196,7 +196,7 @@ def test_simple_hpo_job_task(): simple_xgboost_hpo_job_task.interface.inputs["validation"].type == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() ) - assert simple_xgboost_hpo_job_task.interface.inputs["validation"].type == _idl_types.LiteralType( + assert simple_xgboost_hpo_job_task.interface.inputs["validation"].type == flytekit.models.core.types.LiteralType( blob=_core_types.BlobType( format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, diff --git a/tests/flytekit/unit/sdk/tasks/test_spark_task.py b/tests/flytekit/unit/sdk/tasks/test_spark_task.py index cfccd5ecde..44e7a53add 100644 --- a/tests/flytekit/unit/sdk/tasks/test_spark_task.py +++ b/tests/flytekit/unit/sdk/tasks/test_spark_task.py @@ -2,11 +2,11 @@ import os as _os import sys as _sys +import flytekit.models.core.types from flytekit.bin import entrypoint as _entrypoint from flytekit.common import constants as _common_constants from flytekit.common.tasks import sdk_runnable as _sdk_runnable from flytekit.common.tasks import spark_task as _spark_task -from flytekit.models import types as _type_models from flytekit.models.core import identifier as _identifier from flytekit.sdk.tasks import inputs, outputs, spark_task from flytekit.sdk.types import Types @@ -26,10 +26,12 @@ def test_default_python_task(): assert isinstance(default_task, _spark_task.SdkSparkTask) assert isinstance(default_task, _sdk_runnable.SdkRunnableTask) assert default_task.interface.inputs["in1"].description == "" - assert default_task.interface.inputs["in1"].type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER) + assert default_task.interface.inputs["in1"].type == flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.INTEGER + ) assert default_task.interface.outputs["out1"].description == "" - assert default_task.interface.outputs["out1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.STRING + assert default_task.interface.outputs["out1"].type == flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.STRING ) assert default_task.type == _common_constants.SdkTaskType.SPARK_TASK assert default_task.task_function_name == "default_task" diff --git a/tests/flytekit/unit/sdk/tasks/test_tasks.py b/tests/flytekit/unit/sdk/tasks/test_tasks.py index 33e0287199..430a7d6e07 100644 --- a/tests/flytekit/unit/sdk/tasks/test_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_tasks.py @@ -1,11 +1,11 @@ import datetime as _datetime import os as _os +import flytekit.models.core.task +import flytekit.models.core.types from flytekit import configuration as _configuration from flytekit.common import constants as _common_constants from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import task as _task_models -from flytekit.models import types as _type_models from flytekit.models.core import identifier as _identifier from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.types import Types @@ -24,10 +24,12 @@ def default_task(wf_params, in1, out1): def test_default_python_task(): assert isinstance(default_task, _sdk_runnable.SdkRunnableTask) assert default_task.interface.inputs["in1"].description == "" - assert default_task.interface.inputs["in1"].type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER) + assert default_task.interface.inputs["in1"].type == flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.INTEGER + ) assert default_task.interface.outputs["out1"].description == "" - assert default_task.interface.outputs["out1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.STRING + assert default_task.interface.outputs["out1"].type == flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.STRING ) assert default_task.type == _common_constants.SdkTaskType.PYTHON_TASK assert default_task.task_function_name == "default_task" @@ -59,15 +61,15 @@ def default_task2(wf_params, in1, out1): limit_map = {l.name: l.value for l in default_task2.container.resources.limits} - assert request_map[_task_models.Resources.ResourceName.CPU] == "500m" - assert request_map[_task_models.Resources.ResourceName.MEMORY] == "500Gi" - assert request_map[_task_models.Resources.ResourceName.GPU] == "1" - assert request_map[_task_models.Resources.ResourceName.STORAGE] == "500Gi" + assert request_map[flytekit.models.core.task.Resources.ResourceName.CPU] == "500m" + assert request_map[flytekit.models.core.task.Resources.ResourceName.MEMORY] == "500Gi" + assert request_map[flytekit.models.core.task.Resources.ResourceName.GPU] == "1" + assert request_map[flytekit.models.core.task.Resources.ResourceName.STORAGE] == "500Gi" - assert limit_map[_task_models.Resources.ResourceName.CPU] == "501m" - assert limit_map[_task_models.Resources.ResourceName.MEMORY] == "501Gi" - assert limit_map[_task_models.Resources.ResourceName.GPU] == "2" - assert limit_map[_task_models.Resources.ResourceName.STORAGE] == "501Gi" + assert limit_map[flytekit.models.core.task.Resources.ResourceName.CPU] == "501m" + assert limit_map[flytekit.models.core.task.Resources.ResourceName.MEMORY] == "501Gi" + assert limit_map[flytekit.models.core.task.Resources.ResourceName.GPU] == "2" + assert limit_map[flytekit.models.core.task.Resources.ResourceName.STORAGE] == "501Gi" def test_overriden_resources(): @@ -97,12 +99,12 @@ def default_task2(wf_params, in1, out1): limit_map = {l.name: l.value for l in default_task2.container.resources.limits} - assert request_map[_task_models.Resources.ResourceName.CPU] == "500m" - assert request_map[_task_models.Resources.ResourceName.MEMORY] == "50Gi" - assert request_map[_task_models.Resources.ResourceName.GPU] == "0" - assert request_map[_task_models.Resources.ResourceName.STORAGE] == "100Gi" + assert request_map[flytekit.models.core.task.Resources.ResourceName.CPU] == "500m" + assert request_map[flytekit.models.core.task.Resources.ResourceName.MEMORY] == "50Gi" + assert request_map[flytekit.models.core.task.Resources.ResourceName.GPU] == "0" + assert request_map[flytekit.models.core.task.Resources.ResourceName.STORAGE] == "100Gi" - assert limit_map[_task_models.Resources.ResourceName.CPU] == "1000m" - assert limit_map[_task_models.Resources.ResourceName.MEMORY] == "100Gi" - assert limit_map[_task_models.Resources.ResourceName.GPU] == "1" - assert limit_map[_task_models.Resources.ResourceName.STORAGE] == "200Gi" + assert limit_map[flytekit.models.core.task.Resources.ResourceName.CPU] == "1000m" + assert limit_map[flytekit.models.core.task.Resources.ResourceName.MEMORY] == "100Gi" + assert limit_map[flytekit.models.core.task.Resources.ResourceName.GPU] == "1" + assert limit_map[flytekit.models.core.task.Resources.ResourceName.STORAGE] == "200Gi" diff --git a/tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py b/tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py index 9e6caf29d3..fd1c0ac5f9 100644 --- a/tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py +++ b/tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py @@ -1,9 +1,9 @@ import datetime as _datetime +import flytekit.models.core.types from flytekit.common import constants as _common_constants from flytekit.common.tasks import sdk_runnable as _sdk_runnable from flytekit.common.tasks import tensorflow_task as _tensorflow_task -from flytekit.models import types as _type_models from flytekit.models.core import identifier as _identifier from flytekit.sdk.tasks import inputs, outputs, tensorflow_task from flytekit.sdk.types import Types @@ -25,12 +25,12 @@ def test_simple_tensorflow_task(): assert isinstance(simple_tensorflow_task, _tensorflow_task.SdkTensorFlowTask) assert isinstance(simple_tensorflow_task, _sdk_runnable.SdkRunnableTask) assert simple_tensorflow_task.interface.inputs["in1"].description == "" - assert simple_tensorflow_task.interface.inputs["in1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.INTEGER + assert simple_tensorflow_task.interface.inputs["in1"].type == flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.INTEGER ) assert simple_tensorflow_task.interface.outputs["out1"].description == "" - assert simple_tensorflow_task.interface.outputs["out1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.STRING + assert simple_tensorflow_task.interface.outputs["out1"].type == flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.STRING ) assert simple_tensorflow_task.type == _common_constants.SdkTaskType.TENSORFLOW_TASK assert simple_tensorflow_task.task_function_name == "simple_tensorflow_task" diff --git a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py index 8157739456..0ea5675c11 100644 --- a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py +++ b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py @@ -1,17 +1,17 @@ import pytest from flyteidl.core import errors_pb2 as _errors_pb2 +import flytekit.models.core.types from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import proto as _proto -from flytekit.models import literals as _literal_models -from flytekit.models import types as _type_models +from flytekit.models.core import literals as _literal_models from flytekit.type_engines.default import flyte as _flyte_engine def test_proto_from_literal_type(): sdk_type = _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( - _type_models.LiteralType( - simple=_type_models.SimpleType.BINARY, + flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.BINARY, metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerError"}, ) ) @@ -21,8 +21,8 @@ def test_proto_from_literal_type(): def test_generic_proto_from_literal_type(): sdk_type = _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( - _type_models.LiteralType( - simple=_type_models.SimpleType.STRUCT, + flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.STRUCT, metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerError"}, ) ) @@ -33,8 +33,8 @@ def test_generic_proto_from_literal_type(): def test_unloadable_module_from_literal_type(): with pytest.raises(_user_exceptions.FlyteAssertion): _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( - _type_models.LiteralType( - simple=_type_models.SimpleType.BINARY, + flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.BINARY, metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2_no_exist.ContainerError"}, ) ) @@ -43,8 +43,8 @@ def test_unloadable_module_from_literal_type(): def test_unloadable_proto_from_literal_type(): with pytest.raises(_user_exceptions.FlyteAssertion): _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( - _type_models.LiteralType( - simple=_type_models.SimpleType.BINARY, + flytekit.models.core.types.LiteralType( + simple=flytekit.models.core.types.SimpleType.BINARY, metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerErrorNoExist"}, ) )