diff --git a/flytekit/common/tasks/sagemaker/hpo_job_task.py b/flytekit/common/tasks/sagemaker/hpo_job_task.py index 918a4388fd..518816cf2b 100644 --- a/flytekit/common/tasks/sagemaker/hpo_job_task.py +++ b/flytekit/common/tasks/sagemaker/hpo_job_task.py @@ -1,7 +1,6 @@ import datetime as _datetime import typing -from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job from google.protobuf.json_format import MessageToDict from flytekit import __version__ @@ -10,13 +9,13 @@ from flytekit.common.tasks import task as _sdk_task from flytekit.common.tasks.sagemaker.built_in_training_job_task import SdkBuiltinAlgorithmTrainingJobTask from flytekit.common.tasks.sagemaker.custom_training_job_task import CustomTrainingJobTask +from flytekit.common.tasks.sagemaker.types import HyperparameterTuningJobConfig, ParameterRange from flytekit.models import interface as _interface_model from flytekit.models import literals as _literal_models from flytekit.models import task as _task_models from flytekit.models import types as _types_models from flytekit.models.core import types as _core_types from flytekit.models.sagemaker import hpo_job as _hpo_job_model -from flytekit.sdk import types as _sdk_types class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask): @@ -28,17 +27,24 @@ def __init__( retries: int = 0, cacheable: bool = False, cache_version: str = "", + tunable_parameters: typing.List[str] = None, ): """ - - :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this + :param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this hyperparameter tuning job - :param max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter + :param int max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter tuning job in parallel - :param training_job: The reference to the training job 
definition - :param retries: Number of retries to attempt - :param cacheable: The flag to set if the user wants the output of the task execution to be cached - :param cache_version: String describing the caching version for task discovery purposes + :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference to the training job definition + :param int retries: Number of retries to attempt + :param bool cacheable: The flag to set if the user wants the output of the task execution to be cached + :param str cache_version: String describing the caching version for task discovery purposes + :param typing.List[str] tunable_parameters: A list of parameters to tune. If you are tuning a built-in + algorithm, refer to the algorithm's documentation to understand the possible values for the tunable + parameters. E.g. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for the + list of hyperparameters for the Image Classification built-in algorithm. If you are passing a custom + training job, the list of tunable parameters must be a strict subset of the list of inputs defined on + that job. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html + for the list of supported hyperparameter types. 
""" # Use the training job model as a measure of type checking hpo_job = _hpo_job_model.HyperparameterTuningJob( @@ -52,14 +58,25 @@ def __init__( # TODO: Discuss whether this is a viable interface or contract timeout = _datetime.timedelta(seconds=0) - inputs = { - "hyperparameter_tuning_job_config": _interface_model.Variable( - _sdk_types.Types.Proto(_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(), "", - ), - } + inputs = {} inputs.update(training_job.interface.inputs) + inputs.update( + { + "hyperparameter_tuning_job_config": _interface_model.Variable( + HyperparameterTuningJobConfig.to_flyte_literal_type(), "", + ), + } + ) + + if tunable_parameters: + inputs.update( + { + param: _interface_model.Variable(ParameterRange.to_flyte_literal_type(), "") + for param in tunable_parameters + } + ) - super(SdkSimpleHyperparameterTuningJobTask, self).__init__( + super().__init__( type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK, metadata=_task_models.TaskMetadata( runtime=_task_models.RuntimeMetadata( @@ -87,3 +104,7 @@ def __init__( ), custom=MessageToDict(hpo_job), ) + + def __call__(self, *args, **kwargs): + # Overriding the call function just so we clear up the docs and avoid IDEs complaining about the signature. 
+ return super().__call__(*args, **kwargs) diff --git a/flytekit/common/tasks/sagemaker/types.py b/flytekit/common/tasks/sagemaker/types.py new file mode 100644 index 0000000000..5efa8fd6bf --- /dev/null +++ b/flytekit/common/tasks/sagemaker/types.py @@ -0,0 +1,7 @@ +from flytekit.models.sagemaker import hpo_job as _hpo_models +from flytekit.models.sagemaker import parameter_ranges as _parameter_range_models +from flytekit.sdk import types as _sdk_types + +HyperparameterTuningJobConfig = _sdk_types.Types.GenericProto(_hpo_models.HyperparameterTuningJobConfig) + +ParameterRange = _sdk_types.Types.GenericProto(_parameter_range_models.ParameterRangeOneOf) diff --git a/flytekit/common/types/base_sdk_types.py b/flytekit/common/types/base_sdk_types.py index 7ac659c3a6..ab12969865 100644 --- a/flytekit/common/types/base_sdk_types.py +++ b/flytekit/common/types/base_sdk_types.py @@ -1,5 +1,7 @@ import abc as _abc +from flyteidl.core.literals_pb2 import Literal + from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions from flytekit.models import common as _common_models @@ -54,7 +56,7 @@ def __hash__(cls): class FlyteSdkValue(_literal_models.Literal, metaclass=FlyteSdkType): @classmethod - def from_flyte_idl(cls, pb2_object): + def from_flyte_idl(cls, pb2_object: Literal): """ :param flyteidl.core.literals_pb2.Literal pb2_object: :rtype: FlyteSdkValue diff --git a/flytekit/common/types/proto.py b/flytekit/common/types/proto.py index d6202cb39b..8d7700e215 100644 --- a/flytekit/common/types/proto.py +++ b/flytekit/common/types/proto.py @@ -1,35 +1,29 @@ import base64 as _base64 +from typing import Type, Union import six as _six from google.protobuf import reflection as _proto_reflection +from google.protobuf.json_format import Error +from google.protobuf.json_format import MessageToDict as _MessageToDict +from google.protobuf.json_format import ParseDict as _ParseDict +from google.protobuf.reflection import 
GeneratedProtocolMessageType +from google.protobuf.struct_pb2 import Struct from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types from flytekit.models import literals as _literals from flytekit.models import types as _idl_types +from flytekit.models.common import FlyteIdlEntity, FlyteType +from flytekit.models.types import LiteralType - -def create_protobuf(pb_type): - """ - :param T pb_type: - :rtype: ProtobufType - """ - if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType): - raise _user_exceptions.FlyteTypeException( - expected_type=_proto_reflection.GeneratedProtocolMessageType, - received_type=type(pb_type), - received_value=pb_type, - ) - - class _Protobuf(Protobuf): - _pb_type = pb_type - - return _Protobuf +ProtobufT = Type[_proto_reflection.GeneratedProtocolMessageType] class ProtobufType(_base_sdk_types.FlyteSdkType): + _pb_type = Struct + @property - def pb_type(cls): + def pb_type(cls) -> GeneratedProtocolMessageType: """ :rtype: GeneratedProtocolMessageType """ @@ -51,15 +45,34 @@ def tag(cls): class Protobuf(_base_sdk_types.FlyteSdkValue, metaclass=ProtobufType): - PB_FIELD_KEY = "pb_type" TAG_PREFIX = "{}=".format(PB_FIELD_KEY) - def __init__(self, pb_object): + def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]): """ - :param T pb_object: + :param Union[T, FlyteIdlEntity] pb_object: """ - data = pb_object.SerializeToString() + v = pb_object + # This section converts an existing proto object (or a subclass of) to the right type expected by this instance + # of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType). This makes it + # a bit tricky to figure out the right version of the underlying raw proto class to use to populate the final + # struct. + # If the provided object has to_flyte_idl(), call it to produce a raw proto. 
+ if isinstance(pb_object, FlyteIdlEntity): + v = pb_object.to_flyte_idl() + + # A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt to + # convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class + # is initialized with one. + expected_type = type(self).pb_type + if expected_type != type(v) and expected_type != type(pb_object): + if isinstance(type(self).pb_type, FlyteType): + v = expected_type.from_flyte_idl(v).to_flyte_idl() + else: + raise _user_exceptions.FlyteTypeException( + received_type=type(pb_object), expected_type=expected_type, received_value=pb_object + ) + data = v.SerializeToString() super(Protobuf, self).__init__( scalar=_literals.Scalar( binary=_literals.Binary(value=bytes(data) if _six.PY2 else data, tag=type(self).tag) @@ -97,7 +110,7 @@ def from_python_std(cls, t_value): """ if t_value is None: return _base_sdk_types.Void() - elif isinstance(t_value, cls.pb_type): + elif isinstance(t_value, cls.pb_type) or isinstance(t_value, FlyteIdlEntity): return cls(t_value) else: raise _user_exceptions.FlyteTypeException(type(t_value), cls.pb_type, received_value=t_value) @@ -148,3 +161,159 @@ def short_string(self): :rtype: Text """ return "{}".format(self.to_python_std()) + + +def create_protobuf(pb_type: Type[GeneratedProtocolMessageType]) -> Type[Protobuf]: + """ + :param Type[GeneratedProtocolMessageType] pb_type: + :rtype: Type[Protobuf] + """ + if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType): + raise _user_exceptions.FlyteTypeException( + expected_type=_proto_reflection.GeneratedProtocolMessageType, + received_type=type(pb_type), + received_value=pb_type, + ) + + class _Protobuf(Protobuf): + _pb_type = pb_type + + return _Protobuf + + +class GenericProtobuf(_base_sdk_types.FlyteSdkValue, metaclass=ProtobufType): + PB_FIELD_KEY = "pb_type" + TAG_PREFIX = "{}=".format(PB_FIELD_KEY) + + def __init__(self, pb_object: 
Union[GeneratedProtocolMessageType, FlyteIdlEntity]): + """ + :param Union[T, FlyteIdlEntity] pb_object: + """ + struct = Struct() + v = pb_object + + # This section converts an existing proto object (or a subclass of) to the right type expected by this instance + # of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType). This makes it + # a bit tricky to figure out the right version of the underlying raw proto class to use to populate the final + # struct. + # If the provided object has to_flyte_idl(), call it to produce a raw proto. + if isinstance(pb_object, FlyteIdlEntity): + v = pb_object.to_flyte_idl() + + # A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt to + # convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class + # is initialized with one. + expected_type = type(self).pb_type + if expected_type != type(v) and expected_type != type(pb_object): + if isinstance(type(self).pb_type, FlyteType): + v = expected_type.from_flyte_idl(v).to_flyte_idl() + else: + raise _user_exceptions.FlyteTypeException( + received_type=type(pb_object), expected_type=expected_type, received_value=pb_object + ) + + struct.update(_MessageToDict(v)) + super().__init__(scalar=_literals.Scalar(generic=struct,)) + + @classmethod + def is_castable_from(cls, other): + """ + :param flytekit.common.types.base_literal_types.FlyteSdkType other: + :rtype: bool + """ + return isinstance(other, ProtobufType) and other.pb_type is cls.pb_type + + @classmethod + def from_python_std(cls, t_value: Union[GeneratedProtocolMessageType, FlyteIdlEntity]): + """ + :param Union[T, FlyteIdlEntity] t_value: It is up to each individual object as to whether or not this value can be cast. 
+ :rtype: _base_sdk_types.FlyteSdkValue + :raises: flytekit.common.exceptions.user.FlyteTypeException + """ + if t_value is None: + return _base_sdk_types.Void() + elif isinstance(t_value, cls.pb_type) or isinstance(t_value, FlyteIdlEntity): + return cls(t_value) + else: + raise _user_exceptions.FlyteTypeException(type(t_value), cls.pb_type, received_value=t_value) + + @classmethod + def to_flyte_literal_type(cls) -> LiteralType: + """ + :rtype: flytekit.models.types.LiteralType + """ + return _idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT, metadata={cls.PB_FIELD_KEY: cls.descriptor},) + + @classmethod + def promote_from_model(cls, literal_model): + """ + Creates an object of this type from the model primitive defining it. + :param flytekit.models.literals.Literal literal_model: + :rtype: Protobuf + """ + pb_obj = cls.pb_type() + try: + dictionary = _MessageToDict(literal_model.scalar.generic) + pb_obj = _ParseDict(dictionary, pb_obj) + except Error as err: + raise _user_exceptions.FlyteTypeException( + received_type="generic", + expected_type=cls.pb_type, + received_value=_base64.b64encode(literal_model.scalar.generic), + additional_msg=f"Can not deserialize. Error: {err.__str__()}", + ) + + return cls(pb_obj) + + @classmethod + def short_class_string(cls) -> str: + """ + :rtype: Text + """ + return "Types.GenericProto({})".format(cls.descriptor) + + def to_python_std(self): + """ + :returns: The protobuf object as defined by the user. + :rtype: T + """ + pb_obj = type(self).pb_type() + try: + dictionary = _MessageToDict(self.scalar.generic) + pb_obj = _ParseDict(dictionary, pb_obj) + except Error as err: + raise _user_exceptions.FlyteTypeException( + received_type="generic", + expected_type=type(self).pb_type, + received_value=_base64.b64encode(self.scalar.generic), + additional_msg=f"Can not deserialize. 
Error: {err.__str__()}", + ) + return pb_obj + + def short_string(self) -> str: + """ + :rtype: Text + """ + return "{}".format(self.to_python_std()) + + +def create_generic(pb_type: Type[GeneratedProtocolMessageType]) -> Type[GenericProtobuf]: + """ + Creates a generic protobuf type that represents protobuf type ProtobufT and that will get serialized into a struct. + + :param Type[GeneratedProtocolMessageType] pb_type: + :rtype: Type[GenericProtobuf] + """ + if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType) and not issubclass( + pb_type, FlyteIdlEntity + ): + raise _user_exceptions.FlyteTypeException( + expected_type=_proto_reflection.GeneratedProtocolMessageType, + received_type=type(pb_type), + received_value=pb_type, + ) + + class _Protobuf(GenericProtobuf): + _pb_type = pb_type + + return _Protobuf diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 0961f80108..3b8405e106 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -3,6 +3,7 @@ import pytz as _pytz import six as _six from flyteidl.core import literals_pb2 as _literals_pb2 +from google.protobuf.struct_pb2 import Struct from flytekit.common.exceptions import user as _user_exceptions from flytekit.models import common as _common @@ -196,119 +197,6 @@ def from_flyte_idl(cls, pb2_object): return cls(value=pb2_object.value, tag=pb2_object.tag) -class Scalar(_common.FlyteIdlEntity): - def __init__( - self, primitive=None, blob=None, binary=None, schema=None, none_type=None, error=None, generic=None, - ): - """ - Scalar wrapper around Flyte types. Only one can be specified. 
- - :param Primitive primitive: - :param Blob blob: - :param Binary binary: - :param Schema schema: - :param Void none_type: - :param error: - :param google.protobuf.struct_pb2.Struct generic: - """ - - self._primitive = primitive - self._blob = blob - self._binary = binary - self._schema = schema - self._none_type = none_type - self._error = error - self._generic = generic - - @property - def primitive(self): - """ - :rtype: Primitive - """ - return self._primitive - - @property - def blob(self): - """ - :rtype: Blob - """ - return self._blob - - @property - def binary(self): - """ - :rtype: Binary - """ - return self._binary - - @property - def schema(self): - """ - :rtype: Schema - """ - return self._schema - - @property - def none_type(self): - """ - :rtype: Void - """ - return self._none_type - - @property - def error(self): - """ - :rtype: TODO - """ - return self._error - - @property - def generic(self): - """ - :rtype: google.protobuf.struct_pb2.Struct - """ - return self._generic - - @property - def value(self): - """ - Returns whichever value is set - :rtype: T - """ - return self.primitive or self.blob or self.binary or self.schema or self.none_type or self.error - - def to_flyte_idl(self): - """ - :rtype: flyteidl.core.literals_pb2.Scalar - """ - return _literals_pb2.Scalar( - primitive=self.primitive.to_flyte_idl() if self.primitive is not None else None, - blob=self.blob.to_flyte_idl() if self.blob is not None else None, - binary=self.binary.to_flyte_idl() if self.binary is not None else None, - schema=self.schema.to_flyte_idl() if self.schema is not None else None, - none_type=self.none_type.to_flyte_idl() if self.none_type is not None else None, - error=self.error if self.error is not None else None, - generic=self.generic, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - """ - :param flyteidl.core.literals_pb2.Scalar pb2_object: - :rtype: flytekit.models.literals.Scalar - """ - # todo finish - return cls( - 
primitive=Primitive.from_flyte_idl(pb2_object.primitive) if pb2_object.HasField("primitive") else None, - blob=Blob.from_flyte_idl(pb2_object.blob) if pb2_object.HasField("blob") else None, - binary=Binary.from_flyte_idl(pb2_object.binary) if pb2_object.HasField("binary") else None, - schema=Schema.from_flyte_idl(pb2_object.schema) if pb2_object.HasField("schema") else None, - none_type=Void.from_flyte_idl(pb2_object.none_type) if pb2_object.HasField("none_type") else None, - error=pb2_object.error if pb2_object.HasField("error") else None, - generic=pb2_object.generic if pb2_object.HasField("generic") else None, - ) - - class BlobMetadata(_common.FlyteIdlEntity): def __init__(self, type): """ @@ -701,8 +589,128 @@ def from_flyte_idl(cls, pb2_object): return cls({k: Literal.from_flyte_idl(v) for k, v in _six.iteritems(pb2_object.literals)}) +class Scalar(_common.FlyteIdlEntity): + def __init__( + self, + primitive: Primitive = None, + blob: Blob = None, + binary: Binary = None, + schema: Schema = None, + none_type: Void = None, + error=None, + generic: Struct = None, + ): + """ + Scalar wrapper around Flyte types. Only one can be specified. 
+ + :param Primitive primitive: + :param Blob blob: + :param Binary binary: + :param Schema schema: + :param Void none_type: + :param error: + :param google.protobuf.struct_pb2.Struct generic: + """ + + self._primitive = primitive + self._blob = blob + self._binary = binary + self._schema = schema + self._none_type = none_type + self._error = error + self._generic = generic + + @property + def primitive(self): + """ + :rtype: Primitive + """ + return self._primitive + + @property + def blob(self): + """ + :rtype: Blob + """ + return self._blob + + @property + def binary(self): + """ + :rtype: Binary + """ + return self._binary + + @property + def schema(self): + """ + :rtype: Schema + """ + return self._schema + + @property + def none_type(self): + """ + :rtype: Void + """ + return self._none_type + + @property + def error(self): + """ + :rtype: TODO + """ + return self._error + + @property + def generic(self): + """ + :rtype: google.protobuf.struct_pb2.Struct + """ + return self._generic + + @property + def value(self): + """ + Returns whichever value is set + :rtype: T + """ + return self.primitive or self.blob or self.binary or self.schema or self.none_type or self.error + + def to_flyte_idl(self): + """ + :rtype: flyteidl.core.literals_pb2.Scalar + """ + return _literals_pb2.Scalar( + primitive=self.primitive.to_flyte_idl() if self.primitive is not None else None, + blob=self.blob.to_flyte_idl() if self.blob is not None else None, + binary=self.binary.to_flyte_idl() if self.binary is not None else None, + schema=self.schema.to_flyte_idl() if self.schema is not None else None, + none_type=self.none_type.to_flyte_idl() if self.none_type is not None else None, + error=self.error if self.error is not None else None, + generic=self.generic, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.core.literals_pb2.Scalar pb2_object: + :rtype: flytekit.models.literals.Scalar + """ + # todo finish + return cls( + 
primitive=Primitive.from_flyte_idl(pb2_object.primitive) if pb2_object.HasField("primitive") else None, + blob=Blob.from_flyte_idl(pb2_object.blob) if pb2_object.HasField("blob") else None, + binary=Binary.from_flyte_idl(pb2_object.binary) if pb2_object.HasField("binary") else None, + schema=Schema.from_flyte_idl(pb2_object.schema) if pb2_object.HasField("schema") else None, + none_type=Void.from_flyte_idl(pb2_object.none_type) if pb2_object.HasField("none_type") else None, + error=pb2_object.error if pb2_object.HasField("error") else None, + generic=pb2_object.generic if pb2_object.HasField("generic") else None, + ) + + class Literal(_common.FlyteIdlEntity): - def __init__(self, scalar=None, collection=None, map=None): + def __init__(self, scalar: Scalar = None, collection: LiteralCollection = None, map: LiteralMap = None): """ :param Scalar scalar: :param LiteralCollection collection: diff --git a/flytekit/models/sagemaker/hpo_job.py b/flytekit/models/sagemaker/hpo_job.py index 74f2900615..ee13f9c318 100644 --- a/flytekit/models/sagemaker/hpo_job.py +++ b/flytekit/models/sagemaker/hpo_job.py @@ -1,7 +1,6 @@ from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job from flytekit.models import common as _common -from flytekit.models.sagemaker import parameter_ranges as _parameter_ranges_models from flytekit.models.sagemaker import training_job as _training_job @@ -72,25 +71,14 @@ class HyperparameterTuningJobConfig(_common.FlyteIdlEntity): def __init__( self, - hyperparameter_ranges: _parameter_ranges_models.ParameterRanges, tuning_strategy: int, tuning_objective: HyperparameterTuningObjective, training_job_early_stopping_type: int, ): - self._hyperparameter_ranges = hyperparameter_ranges self._tuning_strategy = tuning_strategy self._tuning_objective = tuning_objective self._training_job_early_stopping_type = training_job_early_stopping_type - @property - def hyperparameter_ranges(self) -> _parameter_ranges_models.ParameterRanges: - """ - 
hyperparameter_ranges is a structure containing a map that maps hyperparameter name to the corresponding - hyperparameter range object - :rtype: _parameter_ranges_models.ParameterRanges - """ - return self._hyperparameter_ranges - @property def tuning_strategy(self) -> int: """ @@ -122,7 +110,6 @@ def training_job_early_stopping_type(self) -> int: def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningJobConfig: return _pb2_hpo_job.HyperparameterTuningJobConfig( - hyperparameter_ranges=self._hyperparameter_ranges.to_flyte_idl(), tuning_strategy=self._tuning_strategy, tuning_objective=self._tuning_objective.to_flyte_idl(), training_job_early_stopping_type=self._training_job_early_stopping_type, @@ -132,9 +119,6 @@ def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningJobConfig: def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningJobConfig): return cls( - hyperparameter_ranges=( - _parameter_ranges_models.ParameterRanges.from_flyte_idl(pb2_object.hyperparameter_ranges) - ), tuning_strategy=pb2_object.tuning_strategy, tuning_objective=HyperparameterTuningObjective.from_flyte_idl(pb2_object.tuning_objective), training_job_early_stopping_type=pb2_object.training_job_early_stopping_type, diff --git a/flytekit/models/sagemaker/parameter_ranges.py b/flytekit/models/sagemaker/parameter_ranges.py index d673afa33d..e9016f2a1d 100644 --- a/flytekit/models/sagemaker/parameter_ranges.py +++ b/flytekit/models/sagemaker/parameter_ranges.py @@ -1,7 +1,8 @@ -from typing import Dict, List +from typing import Dict, List, Optional, Union from flyteidl.plugins.sagemaker import parameter_ranges_pb2 as _idl_parameter_ranges +from flytekit.common.exceptions import user from flytekit.models import common as _common @@ -177,8 +178,15 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.ParameterRanges: converted[k] = _idl_parameter_ranges.ParameterRangeOneOf(integer_parameter_range=v.to_flyte_idl()) elif isinstance(v, ContinuousParameterRange): converted[k] = 
_idl_parameter_ranges.ParameterRangeOneOf(continuous_parameter_range=v.to_flyte_idl()) - else: + elif isinstance(v, CategoricalParameterRange): converted[k] = _idl_parameter_ranges.ParameterRangeOneOf(categorical_parameter_range=v.to_flyte_idl()) + else: + raise user.FlyteTypeException( + received_type=type(v), + expected_type=type( + Union[IntegerParameterRange, ContinuousParameterRange, CategoricalParameterRange] + ), + ) return _idl_parameter_ranges.ParameterRanges(parameter_range_map=converted,) @@ -199,3 +207,88 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ParameterRanges): converted[k] = CategoricalParameterRange.from_flyte_idl(v.categorical_parameter_range) return cls(parameter_range_map=converted,) + + +class ParameterRangeOneOf(_common.FlyteIdlEntity): + def __init__(self, param: Union[IntegerParameterRange, ContinuousParameterRange, CategoricalParameterRange]): + """ + Initializes a new ParameterRangeOneOf. + :param Union[IntegerParameterRange, ContinuousParameterRange, CategoricalParameterRange] param: One of the + supported parameter ranges. + """ + self._integer_parameter_range = param if isinstance(param, IntegerParameterRange) else None + self._continuous_parameter_range = param if isinstance(param, ContinuousParameterRange) else None + self._categorical_parameter_range = param if isinstance(param, CategoricalParameterRange) else None + + @property + def integer_parameter_range(self) -> Optional[IntegerParameterRange]: + """ + Retrieves the integer parameter range if one is set. None otherwise. + :rtype: Optional[IntegerParameterRange] + """ + if self._integer_parameter_range: + return self._integer_parameter_range + + return None + + @property + def continuous_parameter_range(self) -> Optional[ContinuousParameterRange]: + """ + Retrieves the continuous parameter range if one is set. None otherwise. 
+ :rtype: Optional[ContinuousParameterRange] + """ + if self._continuous_parameter_range: + return self._continuous_parameter_range + + return None + + @property + def categorical_parameter_range(self) -> Optional[CategoricalParameterRange]: + """ + Retrieves the categorical parameter range if one is set. None otherwise. + :rtype: Optional[CategoricalParameterRange] + """ + if self._categorical_parameter_range: + return self._categorical_parameter_range + + return None + + def to_flyte_idl(self) -> _idl_parameter_ranges.ParameterRangeOneOf: + return _idl_parameter_ranges.ParameterRangeOneOf( + integer_parameter_range=self.integer_parameter_range.to_flyte_idl() + if self.integer_parameter_range + else None, + continuous_parameter_range=self.continuous_parameter_range.to_flyte_idl() + if self.continuous_parameter_range + else None, + categorical_parameter_range=self.categorical_parameter_range.to_flyte_idl() + if self.categorical_parameter_range + else None, + ) + + @classmethod + def from_flyte_idl( + cls, + pb_object: Union[ + _idl_parameter_ranges.ParameterRangeOneOf, + _idl_parameter_ranges.IntegerParameterRange, + _idl_parameter_ranges.ContinuousParameterRange, + _idl_parameter_ranges.CategoricalParameterRange, + ], + ): + param = None + if isinstance(pb_object, _idl_parameter_ranges.ParameterRangeOneOf): + if pb_object.HasField("continuous_parameter_range"): + param = ContinuousParameterRange.from_flyte_idl(pb_object.continuous_parameter_range) + elif pb_object.HasField("integer_parameter_range"): + param = IntegerParameterRange.from_flyte_idl(pb_object.integer_parameter_range) + elif pb_object.HasField("categorical_parameter_range"): + param = CategoricalParameterRange.from_flyte_idl(pb_object.categorical_parameter_range) + elif isinstance(pb_object, _idl_parameter_ranges.IntegerParameterRange): + param = IntegerParameterRange.from_flyte_idl(pb_object) + elif isinstance(pb_object, _idl_parameter_ranges.ContinuousParameterRange): + param = 
ContinuousParameterRange.from_flyte_idl(pb_object) + elif isinstance(pb_object, _idl_parameter_ranges.CategoricalParameterRange): + param = CategoricalParameterRange.from_flyte_idl(pb_object) + + return cls(param=param) diff --git a/flytekit/sdk/types.py b/flytekit/sdk/types.py index 924fc54733..7b64d1db49 100644 --- a/flytekit/sdk/types.py +++ b/flytekit/sdk/types.py @@ -163,7 +163,10 @@ def hundred_times_longer(wf_params, a, b): Generic = _helpers.get_sdk_type_from_literal_type(_primitives.Generic.to_flyte_literal_type()) """ - Use this to specify a simple JSON type. + Use this to specify a simple JSON type. The Generic type offers a flexible (but loose) extension to flyte's typing + system by allowing custom types/objects to be passed through. It's strongly recommended for producers & consumers of + entities that produce or consume a Generic type to perform their own expectation checks on the integrity of the + object. When used with an SDK-decorated method, expect this behavior from the default type engine: @@ -190,6 +193,22 @@ def operate(wf_params, a, b): elif a['operation'] == 'merge': a['value'].update(a['some']['nested'][0]['field']) b.set(a) + + # For better readability, it's strongly advised to leverage python's type aliasing. + MyTypeA = Types.Generic + MyTypeB = Types.Generic + + # This makes it clearer that it received a certain type and produces a different one. Other tasks that consume + # MyTypeB should do so in their input declaration. + @inputs(a=MyTypeA) + @outputs(b=MyTypeB) + @python_task + def operate(wf_params, a, b): + if a['operation'] == 'add': + a['value'] += a['operand'] # a['value'] is a number + elif a['operation'] == 'merge': + a['value'].update(a['some']['nested'][0]['field']) + b.set(a) """ Blob = _blobs.Blob @@ -357,7 +376,12 @@ def concat_then_split(wf_params, generic, typed,): Proto = staticmethod(_proto.create_protobuf) """ - Use this to specify a custom protobuf type. 
+ Proto type wraps a protobuf type to provide interoperability between protobuf and flyte typing system. Using this + type, you can define custom input/output variable types of flyte entities and continue to provide strong typing + syntax. Proto type serializes proto objects as binary (leveraging `flyteidl's Binary literal `_). + Binary serialization of protobufs is the most space-efficient serialization form. Because of the way protobufs are + designed, unmarshalling the serialized proto requires access to the corresponding type. In order to use/visualize + the serialized proto, you will generally need to write custom code in the corresponding component. .. note:: @@ -395,6 +419,52 @@ def assert_and_create(wf_params, a, b): ) """ + GenericProto = staticmethod(_proto.create_generic) + """ + GenericProto type wraps a protobuf type to provide interoperability between protobuf and flyte typing system. Using + this type, you can define custom input/output variable types of flyte entities and continue to provide strong typing + syntax. GenericProto type serializes proto objects as a generic struct (a protobuf Struct) rather than as binary. + A generic proto is a specialization of the Generic type with added convenience functions to support marshalling/ + unmarshalling of the underlying protobuf object using the protobuf official json marshaller. While GenericProto type + does not produce the most space-efficient representation of protobufs, it's a suitable solution for making protobufs + easily accessible (i.e. humanly readable) in other flyte components (e.g. console, cli... etc.). + + .. note:: + + The protobuf Python library should be installed on the PYTHONPATH to ensure the type engine can access the + appropriate Python code to deserialize the protobuf. + + When used with an SDK-decorated method, expect this behavior from the default type engine: + + As input: + 1) If set, a Python protobuf object of the type specified in the definition. + 2) If not set, a None value. 
+ + As output: + 1) A Python protobuf object matching the type specified by the users. + 2) Set None to null the output. + + From command-line: + A base-64 encoded string of the serialized protobuf. + + .. code-block:: python + + from protos import my_protos_pb2 + + @inputs(a=Types.GenericProto(my_protos_pb2.Custom)) + @outputs(b=Types.GenericProto(my_protos_pb2.Custom)) + @python_task + def assert_and_create(wf_params, a, b): + assert a.field1 == 1 + assert a.field2 == 'abc' + b.set( + my_protos_pb2.Custom( + field1=100, + field2='hello' + ) + ) + """ + List = staticmethod(_containers.List) """ Use this to specify a list of any type--including nested lists. diff --git a/flytekit/type_engines/default/flyte.py b/flytekit/type_engines/default/flyte.py index 4b58019003..b930a63d97 100644 --- a/flytekit/type_engines/default/flyte.py +++ b/flytekit/type_engines/default/flyte.py @@ -1,4 +1,5 @@ import importlib as _importer +from typing import Type from flytekit.common.exceptions import system as _system_exceptions from flytekit.common.exceptions import user as _user_exceptions @@ -13,11 +14,11 @@ from flytekit.models.core import types as _core_types -def _proto_sdk_type_from_tag(tag): +def _load_type_from_tag(tag: str) -> Type: """ - :param Text tag: - :rtype: _proto.Protobuf + Loads python type from tag """ + if "." not in tag: raise _user_exceptions.FlyteValueException( tag, "Protobuf tag must include at least one '.' 
to delineate package and object name.", @@ -34,11 +35,27 @@ def _proto_sdk_type_from_tag(tag): if not hasattr(pb_module, name): raise _user_exceptions.FlyteAssertion("Could not find the protobuf named: {} @ {}.".format(name, module)) - return _proto.create_protobuf(getattr(pb_module, name)) + return getattr(pb_module, name) -class FlyteDefaultTypeEngine(object): +def _proto_sdk_type_from_tag(tag): + """ + :param Text tag: + :rtype: _proto.Protobuf + """ + return _proto.create_protobuf(_load_type_from_tag(tag)) + +def _generic_proto_sdk_type_from_tag(tag: str) -> Type[_proto.GenericProtobuf]: + """ + :param Text tag: + :rtype: _proto.GenericProtobuf + """ + + return _proto.create_generic(_load_type_from_tag(tag)) + + +class FlyteDefaultTypeEngine(object): _SIMPLE_TYPE_LOOKUP_TABLE = { _literal_type_models.SimpleType.INTEGER: _primitive_types.Integer, _literal_type_models.SimpleType.FLOAT: _primitive_types.Float, @@ -95,6 +112,12 @@ def get_sdk_type_from_literal_type(self, literal_type): and _proto.Protobuf.PB_FIELD_KEY in literal_type.metadata ): return _proto_sdk_type_from_tag(literal_type.metadata[_proto.Protobuf.PB_FIELD_KEY]) + if ( + literal_type.simple == _literal_type_models.SimpleType.STRUCT + and literal_type.metadata + and _proto.Protobuf.PB_FIELD_KEY in literal_type.metadata + ): + return _generic_proto_sdk_type_from_tag(literal_type.metadata[_proto.Protobuf.PB_FIELD_KEY]) sdk_type = self._SIMPLE_TYPE_LOOKUP_TABLE.get(literal_type.simple) if sdk_type is None: raise NotImplementedError( diff --git a/tests/flytekit/common/workflows/sagemaker.py b/tests/flytekit/common/workflows/sagemaker.py new file mode 100644 index 0000000000..b0fe7d36d6 --- /dev/null +++ b/tests/flytekit/common/workflows/sagemaker.py @@ -0,0 +1,115 @@ +import os as _os + +from flytekit import configuration as _configuration +from flytekit.common.tasks.sagemaker import hpo_job_task +from flytekit.common.tasks.sagemaker.built_in_training_job_task import SdkBuiltinAlgorithmTrainingJobTask 
+from flytekit.common.tasks.sagemaker.types import HyperparameterTuningJobConfig +from flytekit.models.sagemaker.hpo_job import HyperparameterTuningJobConfig as _HyperparameterTuningJobConfig +from flytekit.models.sagemaker.hpo_job import ( + HyperparameterTuningObjective, + HyperparameterTuningObjectiveType, + HyperparameterTuningStrategy, + TrainingJobEarlyStoppingType, +) +from flytekit.models.sagemaker.parameter_ranges import ( + ContinuousParameterRange, + HyperparameterScalingType, + IntegerParameterRange, +) +from flytekit.models.sagemaker.training_job import ( + AlgorithmName, + AlgorithmSpecification, + InputContentType, + InputMode, + TrainingJobResourceConfig, +) +from flytekit.sdk.types import Types +from flytekit.sdk.workflow import Input, workflow_class + +example_hyperparams = { + "base_score": "0.5", + "booster": "gbtree", + "csv_weights": "0", + "dsplit": "row", + "grow_policy": "depthwise", + "lambda_bias": "0.0", + "max_bin": "256", + "max_leaves": "0", + "normalize_type": "tree", + "objective": "reg:linear", + "one_drop": "0", + "prob_buffer_row": "1.0", + "process_type": "default", + "rate_drop": "0.0", + "refresh_leaf": "1", + "sample_type": "uniform", + "scale_pos_weight": "1.0", + "silent": "0", + "sketch_eps": "0.03", + "skip_drop": "0.0", + "tree_method": "auto", + "tweedie_variance_power": "1.5", + "updater": "grow_colmaker,prune", +} + +builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask( + training_job_resource_config=TrainingJobResourceConfig( + instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + ), + algorithm_specification=AlgorithmSpecification( + input_mode=InputMode.FILE, + input_content_type=InputContentType.TEXT_CSV, + algorithm_name=AlgorithmName.XGBOOST, + algorithm_version="0.72", + ), +) + +simple_xgboost_hpo_job_task = hpo_job_task.SdkSimpleHyperparameterTuningJobTask( + training_job=builtin_algorithm_training_job_task2, + max_number_of_training_jobs=10, + 
max_parallel_training_jobs=5, + cache_version="1", + retries=2, + cacheable=True, + tunable_parameters=["num_round", "max_depth", "gamma"], +) + + +@workflow_class +class SageMakerHPO(object): + train_dataset = Input(Types.MultiPartCSV, default="s3://somelocation") + validation_dataset = Input(Types.MultiPartCSV, default="s3://somelocation") + static_hyperparameters = Input(Types.Generic, default=example_hyperparams) + hyperparameter_tuning_job_config = Input( + HyperparameterTuningJobConfig, + default=_HyperparameterTuningJobConfig( + tuning_strategy=HyperparameterTuningStrategy.BAYESIAN, + tuning_objective=HyperparameterTuningObjective( + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="validation:error", + ), + training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO, + ), + ) + + a = simple_xgboost_hpo_job_task( + train=train_dataset, + validation=validation_dataset, + static_hyperparameters=static_hyperparameters, + hyperparameter_tuning_job_config=hyperparameter_tuning_job_config, + num_round=IntegerParameterRange(min_value=2, max_value=8, scaling_type=HyperparameterScalingType.LINEAR), + max_depth=IntegerParameterRange(min_value=5, max_value=7, scaling_type=HyperparameterScalingType.LINEAR), + gamma=ContinuousParameterRange(min_value=0.0, max_value=0.3, scaling_type=HyperparameterScalingType.LINEAR), + ) + + +sagemaker_hpo_lp = SageMakerHPO.create_launch_plan() + +with _configuration.TemporaryConfiguration( + _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/local.config",), + internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, +): + print("Printing WF definition") + print(SageMakerHPO) + + print("Printing LP definition") + print(sagemaker_hpo_lp) diff --git a/tests/flytekit/unit/common_tests/types/test_proto.py b/tests/flytekit/unit/common_tests/types/test_proto.py index 959dd2d56d..074fb1dfd8 100644 --- 
a/tests/flytekit/unit/common_tests/types/test_proto.py +++ b/tests/flytekit/unit/common_tests/types/test_proto.py @@ -5,6 +5,7 @@ from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import proto as _proto +from flytekit.common.types.proto import ProtobufType from flytekit.models import types as _type_models @@ -26,7 +27,16 @@ def test_proto_to_literal_type(): def test_proto(): proto_type = _proto.create_protobuf(_errors_pb2.ContainerError) assert proto_type.short_class_string() == "Types.Proto(flyteidl.core.errors_pb2.ContainerError)" + run_test_proto_type(proto_type) + +def test_generic_proto(): + proto_type = _proto.create_generic(_errors_pb2.ContainerError) + assert proto_type.short_class_string() == "Types.GenericProto(flyteidl.core.errors_pb2.ContainerError)" + run_test_proto_type(proto_type) + + +def run_test_proto_type(proto_type: ProtobufType): pb = _errors_pb2.ContainerError(code="code", message="message") obj = proto_type.from_python_std(pb) obj2 = proto_type.from_flyte_idl(obj.to_flyte_idl()) diff --git a/tests/flytekit/unit/models/sagemaker/test_hpo_job.py b/tests/flytekit/unit/models/sagemaker/test_hpo_job.py index d426933100..eab5f209b4 100644 --- a/tests/flytekit/unit/models/sagemaker/test_hpo_job.py +++ b/tests/flytekit/unit/models/sagemaker/test_hpo_job.py @@ -1,4 +1,4 @@ -from flytekit.models.sagemaker import hpo_job, parameter_ranges, training_job +from flytekit.models.sagemaker import hpo_job, training_job def test_hyperparameter_tuning_objective(): @@ -12,14 +12,6 @@ def test_hyperparameter_tuning_objective(): def test_hyperparameter_job_config(): jc = hpo_job.HyperparameterTuningJobConfig( - hyperparameter_ranges=parameter_ranges.ParameterRanges( - parameter_range_map={ - "a": parameter_ranges.CategoricalParameterRange(values=["1", "2"]), - "b": parameter_ranges.IntegerParameterRange( - min_value=0, max_value=10, scaling_type=parameter_ranges.HyperparameterScalingType.LINEAR - ), - } - ), 
tuning_strategy=hpo_job.HyperparameterTuningStrategy.BAYESIAN, tuning_objective=hpo_job.HyperparameterTuningObjective( objective_type=hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE, metric_name="test_metric" @@ -28,7 +20,6 @@ def test_hyperparameter_job_config(): ) jc2 = hpo_job.HyperparameterTuningJobConfig.from_flyte_idl(jc.to_flyte_idl()) - assert jc2.hyperparameter_ranges == jc.hyperparameter_ranges assert jc2.tuning_strategy == jc.tuning_strategy assert jc2.tuning_objective == jc.tuning_objective assert jc2.training_job_early_stopping_type == jc.training_job_early_stopping_type diff --git a/tests/flytekit/unit/models/sagemaker/test_parameter_ranges.py b/tests/flytekit/unit/models/sagemaker/test_parameter_ranges.py index 2d5fe2acef..e10899bdb7 100644 --- a/tests/flytekit/unit/models/sagemaker/test_parameter_ranges.py +++ b/tests/flytekit/unit/models/sagemaker/test_parameter_ranges.py @@ -1,8 +1,15 @@ import unittest +import pytest + from flytekit.models.sagemaker import parameter_ranges +# assert statements cannot be written inside lambda expressions. This is a convenient function to work around that. 
+def assert_equal(a, b): + assert a == b + + def test_continuous_parameter_range(): pr = parameter_ranges.ContinuousParameterRange( max_value=10, min_value=0.5, scaling_type=parameter_ranges.HyperparameterScalingType.REVERSELOGARITHMIC @@ -55,3 +62,33 @@ def test_parameter_ranges(): ) pr2 = parameter_ranges.ParameterRanges.from_flyte_idl(pr.to_flyte_idl()) assert pr == pr2 + + +LIST_OF_PARAMETERS = [ + ( + parameter_ranges.IntegerParameterRange( + min_value=1, max_value=5, scaling_type=parameter_ranges.HyperparameterScalingType.LINEAR + ), + lambda param: assert_equal(param.integer_parameter_range.max_value, 5), + ), + ( + parameter_ranges.ContinuousParameterRange( + min_value=0.1, max_value=1.0, scaling_type=parameter_ranges.HyperparameterScalingType.LOGARITHMIC + ), + lambda param: assert_equal(param.continuous_parameter_range.max_value, 1), + ), + ( + parameter_ranges.CategoricalParameterRange(values=["a-1", "a-2"]), + lambda param: assert_equal(len(param.categorical_parameter_range.values), 2), + ), +] + + +@pytest.mark.parametrize("param_tuple", LIST_OF_PARAMETERS) +def test_parameter_ranges_oneof(param_tuple): + param, assertion = param_tuple + oneof = parameter_ranges.ParameterRangeOneOf(param=param) + oneof2 = parameter_ranges.ParameterRangeOneOf.from_flyte_idl(oneof.to_flyte_idl()) + assert oneof2 == oneof + assertion(oneof) + assertion(oneof2) diff --git a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py index 9b5c0ce947..ce95bf7024 100644 --- a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py @@ -1,15 +1,12 @@ import datetime as _datetime -import os +import os as _os import unittest from unittest import mock import retry.api -from flyteidl.plugins.sagemaker.hyperparameter_tuning_job_pb2 import HyperparameterTuningJobConfig as _pb2_HPOJobConfig from flyteidl.plugins.sagemaker.training_job_pb2 import TrainingJobResourceConfig as 
_pb2_TrainingJobResourceConfig from google.protobuf.json_format import ParseDict -import flytekit.common.tasks.sagemaker.distributed_training -import flytekit.models.core.types as _core_types from flytekit.common import constants as _common_constants from flytekit.common import utils as _utils from flytekit.common.core.identifier import WorkflowExecutionIdentifier @@ -18,13 +15,30 @@ from flytekit.common.tasks.sagemaker import hpo_job_task from flytekit.common.tasks.sagemaker.built_in_training_job_task import SdkBuiltinAlgorithmTrainingJobTask from flytekit.common.tasks.sagemaker.custom_training_job_task import CustomTrainingJobTask -from flytekit.common.tasks.sagemaker.hpo_job_task import SdkSimpleHyperparameterTuningJobTask +from flytekit.common.tasks.sagemaker.hpo_job_task import ( + HyperparameterTuningJobConfig, + SdkSimpleHyperparameterTuningJobTask, +) from flytekit.common.types import helpers as _type_helpers from flytekit.engines import common as _common_engine from flytekit.engines.unit.mock_stats import MockStats from flytekit.models import literals as _literals from flytekit.models import types as _idl_types from flytekit.models.core import identifier as _identifier +from flytekit.models.core import types as _core_types +from flytekit.models.sagemaker.hpo_job import HyperparameterTuningJobConfig as _HyperparameterTuningJobConfig +from flytekit.models.sagemaker.hpo_job import ( + HyperparameterTuningObjective, + HyperparameterTuningObjectiveType, + HyperparameterTuningStrategy, + TrainingJobEarlyStoppingType, +) +from flytekit.models.sagemaker.parameter_ranges import ( + ContinuousParameterRange, + HyperparameterScalingType, + IntegerParameterRange, + ParameterRangeOneOf, +) from flytekit.models.sagemaker.training_job import ( AlgorithmName, AlgorithmSpecification, @@ -37,6 +51,7 @@ from flytekit.sdk.sagemaker.task import custom_training_job_task from flytekit.sdk.tasks import inputs, outputs from flytekit.sdk.types import Types +from 
flytekit.sdk.workflow import Input, workflow_class example_hyperparams = { "base_score": "0.5", @@ -142,6 +157,7 @@ def test_builtin_algorithm_training_job_task(): cache_version="1", retries=2, cacheable=True, + tunable_parameters=["num_round", "gamma", "max_depth"], ) simple_xgboost_hpo_job_task._id = _identifier.Identifier( @@ -179,7 +195,7 @@ def test_simple_hpo_job_task(): assert simple_xgboost_hpo_job_task.interface.inputs["hyperparameter_tuning_job_config"].description == "" assert ( simple_xgboost_hpo_job_task.interface.inputs["hyperparameter_tuning_job_config"].type - == _sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type() + == HyperparameterTuningJobConfig.to_flyte_literal_type() ) assert simple_xgboost_hpo_job_task.interface.outputs["model"].description == "" assert simple_xgboost_hpo_job_task.interface.outputs["model"].type == _sdk_types.Types.Blob.to_flyte_literal_type() @@ -225,15 +241,47 @@ def my_task(wf_params, input_1, model): assert type(my_task) == CustomTrainingJobTask +def test_simple_hpo_job_task_interface(): + @workflow_class + class MyWf(object): + train_dataset = Input(Types.Blob) + validation_dataset = Input(Types.Blob) + static_hyperparameters = Input(Types.Generic) + hyperparameter_tuning_job_config = Input( + HyperparameterTuningJobConfig, + default=_HyperparameterTuningJobConfig( + tuning_strategy=HyperparameterTuningStrategy.BAYESIAN, + tuning_objective=HyperparameterTuningObjective( + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="validation:error", + ), + training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO, + ), + ) + + a = simple_xgboost_hpo_job_task( + train=train_dataset, + validation=validation_dataset, + static_hyperparameters=static_hyperparameters, + hyperparameter_tuning_job_config=hyperparameter_tuning_job_config, + num_round=ParameterRangeOneOf( + IntegerParameterRange(min_value=3, max_value=10, scaling_type=HyperparameterScalingType.LINEAR) + ), + 
max_depth=ParameterRangeOneOf( + IntegerParameterRange(min_value=5, max_value=7, scaling_type=HyperparameterScalingType.LINEAR) + ), + gamma=ParameterRangeOneOf( + ContinuousParameterRange(min_value=0.0, max_value=0.3, scaling_type=HyperparameterScalingType.LINEAR) + ), + ) + + assert MyWf.nodes[0].inputs[2].binding.scalar.generic is not None + + # Defining a output-persist predicate to indicate if the copy of the instance should persist its output def predicate(distributed_training_context): return ( - distributed_training_context[ - flytekit.common.tasks.sagemaker.distributed_training.DistributedTrainingContextKey.CURRENT_HOST - ] - == distributed_training_context[ - flytekit.common.tasks.sagemaker.distributed_training.DistributedTrainingContextKey.HOSTS - ][1] + distributed_training_context[_sm_distribution.DistributedTrainingContextKey.CURRENT_HOST] + == distributed_training_context[_sm_distribution.DistributedTrainingContextKey.HOSTS][1] ) @@ -245,7 +293,6 @@ class DistributedCustomTrainingJobTaskTests(unittest.TestCase): @mock.patch.dict("os.environ", {}) def setUp(self): with _utils.AutoDeletingTempDir("input_dir") as input_dir: - self._task_input = _literals.LiteralMap( {"input_1": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1)))} ) @@ -280,7 +327,7 @@ def my_distributed_task(wf_params, input_1, model): def test_missing_current_host_in_distributed_training_context_keys_lead_to_keyerrors(self): with mock.patch.dict( - os.environ, + _os.environ, { _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", @@ -293,7 +340,7 @@ def test_missing_current_host_in_distributed_training_context_keys_lead_to_keyer def test_missing_hosts_distributed_training_context_keys_lead_to_keyerrors(self): with mock.patch.dict( - os.environ, + _os.environ, { _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", @@ 
-306,7 +353,7 @@ def test_missing_hosts_distributed_training_context_keys_lead_to_keyerrors(self) def test_missing_network_interface_name_in_distributed_training_context_keys_lead_to_keyerrors(self): with mock.patch.dict( - os.environ, + _os.environ, { _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', @@ -319,7 +366,7 @@ def test_missing_network_interface_name_in_distributed_training_context_keys_lea def test_with_default_predicate_with_rank0_master(self): with mock.patch.dict( - os.environ, + _os.environ, { _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-0", _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', @@ -333,7 +380,7 @@ def test_with_default_predicate_with_rank0_master(self): def test_with_default_predicate_with_rank1_master(self): with mock.patch.dict( - os.environ, + _os.environ, { _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', @@ -346,7 +393,7 @@ def test_with_default_predicate_with_rank1_master(self): def test_with_custom_predicate_with_none_dist_context(self): with mock.patch.dict( - os.environ, + _os.environ, { _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', @@ -363,7 +410,7 @@ def test_with_custom_predicate_with_none_dist_context(self): def test_with_custom_predicate_with_valid_dist_context(self): with mock.patch.dict( - os.environ, + _os.environ, { _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', @@ -382,7 +429,7 @@ def test_with_custom_predicate_with_valid_dist_context(self): def test_if_wf_param_has_dist_context(self): with mock.patch.dict( - os.environ, + _os.environ, { _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', diff --git 
a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py index 90cbef2253..b37ce8981a 100644 --- a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py +++ b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py @@ -19,6 +19,17 @@ def test_proto_from_literal_type(): assert sdk_type.pb_type == _errors_pb2.ContainerError +def test_generic_proto_from_literal_type(): + sdk_type = _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( + _type_models.LiteralType( + simple=_type_models.SimpleType.STRUCT, + metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerError"}, + ) + ) + + assert sdk_type.pb_type == _errors_pb2.ContainerError + + def test_unloadable_module_from_literal_type(): with pytest.raises(_user_exceptions.FlyteAssertion): _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type(