Skip to content

Commit

Permalink
SageMaker custom types for ParameterRangeOneOf and HyperparameterConf…
Browse files Browse the repository at this point in the history
…ig. TunableParams extraction (flyteorg#189)

Creates custom types for ParameterRangeOneOf and HyperparameterJobConfig.

Use generics to serialize/deserialize protos to make UX visualization easier

Move tunable hyperparameters out of HyperparameterJobConfig to make it easier to bind (and consistent with Custom Container Training Job)


* Typing

* ParameterRangeOneOf model

* cleanup

* lint

* lint

* unittests

* unit

* unit

* isort

* isort

* py 3.5

* lint

* lint

* re-add generic types to protos

* lint

* Remove generic

* remove T

* reformat

* fix import

* PR Comments

* lint

* lint

* PR Comments

* PR Comments

* lint

* unittest

* PR Comments

* lint

* remove deprecated fields

* Support converting raw protos through Types.*Proto classes

* lint

* revert notebook.py

* lint
  • Loading branch information
EngHabu authored Oct 16, 2020
1 parent 10d4b0e commit ebcf501
Show file tree
Hide file tree
Showing 15 changed files with 797 additions and 209 deletions.
51 changes: 36 additions & 15 deletions flytekit/common/tasks/sagemaker/hpo_job_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime as _datetime
import typing

from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job
from google.protobuf.json_format import MessageToDict

from flytekit import __version__
Expand All @@ -10,13 +9,13 @@
from flytekit.common.tasks import task as _sdk_task
from flytekit.common.tasks.sagemaker.built_in_training_job_task import SdkBuiltinAlgorithmTrainingJobTask
from flytekit.common.tasks.sagemaker.custom_training_job_task import CustomTrainingJobTask
from flytekit.common.tasks.sagemaker.types import HyperparameterTuningJobConfig, ParameterRange
from flytekit.models import interface as _interface_model
from flytekit.models import literals as _literal_models
from flytekit.models import task as _task_models
from flytekit.models import types as _types_models
from flytekit.models.core import types as _core_types
from flytekit.models.sagemaker import hpo_job as _hpo_job_model
from flytekit.sdk import types as _sdk_types


class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask):
Expand All @@ -28,17 +27,24 @@ def __init__(
retries: int = 0,
cacheable: bool = False,
cache_version: str = "",
tunable_parameters: typing.List[str] = None,
):
"""
:param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
:param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
hyperparameter tuning job
:param max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter
:param int max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter
tuning job in parallel
:param training_job: The reference to the training job definition
:param retries: Number of retries to attempt
:param cacheable: The flag to set if the user wants the output of the task execution to be cached
:param cache_version: String describing the caching version for task discovery purposes
:param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference to the training job definition
:param int retries: Number of retries to attempt
:param bool cacheable: The flag to set if the user wants the output of the task execution to be cached
:param str cache_version: String describing the caching version for task discovery purposes
:param typing.List[str] tunable_parameters: A list of parameters that to tune. If you are tuning a built-int
algorithm, refer to the algorithm's documentation to understand the possible values for the tunable
parameters. E.g. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for the
list of hyperparameters for Image Classification built-in algorithm. If you are passing a custom
training job, the list of tunable parameters must be a strict subset of the list of inputs defined on
that job. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html
for the list of supported hyperparameter types.
"""
# Use the training job model as a measure of type checking
hpo_job = _hpo_job_model.HyperparameterTuningJob(
Expand All @@ -52,14 +58,25 @@ def __init__(
# TODO: Discuss whether this is a viable interface or contract
timeout = _datetime.timedelta(seconds=0)

inputs = {
"hyperparameter_tuning_job_config": _interface_model.Variable(
_sdk_types.Types.Proto(_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(), "",
),
}
inputs = {}
inputs.update(training_job.interface.inputs)
inputs.update(
{
"hyperparameter_tuning_job_config": _interface_model.Variable(
HyperparameterTuningJobConfig.to_flyte_literal_type(), "",
),
}
)

if tunable_parameters:
inputs.update(
{
param: _interface_model.Variable(ParameterRange.to_flyte_literal_type(), "")
for param in tunable_parameters
}
)

super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
super().__init__(
type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
metadata=_task_models.TaskMetadata(
runtime=_task_models.RuntimeMetadata(
Expand Down Expand Up @@ -87,3 +104,7 @@ def __init__(
),
custom=MessageToDict(hpo_job),
)

def __call__(self, *args, **kwargs):
# Overriding the call function just so we clear up the docs and avoid IDEs complaining about the signature.
return super().__call__(*args, **kwargs)
7 changes: 7 additions & 0 deletions flytekit/common/tasks/sagemaker/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from flytekit.models.sagemaker import hpo_job as _hpo_models
from flytekit.models.sagemaker import parameter_ranges as _parameter_range_models
from flytekit.sdk import types as _sdk_types

HyperparameterTuningJobConfig = _sdk_types.Types.GenericProto(_hpo_models.HyperparameterTuningJobConfig)

ParameterRange = _sdk_types.Types.GenericProto(_parameter_range_models.ParameterRangeOneOf)
4 changes: 3 additions & 1 deletion flytekit/common/types/base_sdk_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import abc as _abc

from flyteidl.core.literals_pb2 import Literal

from flytekit.common import sdk_bases as _sdk_bases
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.models import common as _common_models
Expand Down Expand Up @@ -54,7 +56,7 @@ def __hash__(cls):

class FlyteSdkValue(_literal_models.Literal, metaclass=FlyteSdkType):
@classmethod
def from_flyte_idl(cls, pb2_object):
def from_flyte_idl(cls, pb2_object: Literal):
"""
:param flyteidl.core.literals_pb2.Literal pb2_object:
:rtype: FlyteSdkValue
Expand Down
215 changes: 192 additions & 23 deletions flytekit/common/types/proto.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,29 @@
import base64 as _base64
from typing import Type, Union

import six as _six
from google.protobuf import reflection as _proto_reflection
from google.protobuf.json_format import Error
from google.protobuf.json_format import MessageToDict as _MessageToDict
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.reflection import GeneratedProtocolMessageType
from google.protobuf.struct_pb2 import Struct

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.types import base_sdk_types as _base_sdk_types
from flytekit.models import literals as _literals
from flytekit.models import types as _idl_types
from flytekit.models.common import FlyteIdlEntity, FlyteType
from flytekit.models.types import LiteralType


def create_protobuf(pb_type):
"""
:param T pb_type:
:rtype: ProtobufType
"""
if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType):
raise _user_exceptions.FlyteTypeException(
expected_type=_proto_reflection.GeneratedProtocolMessageType,
received_type=type(pb_type),
received_value=pb_type,
)

class _Protobuf(Protobuf):
_pb_type = pb_type

return _Protobuf
ProtobufT = Type[_proto_reflection.GeneratedProtocolMessageType]


class ProtobufType(_base_sdk_types.FlyteSdkType):
_pb_type = Struct

@property
def pb_type(cls):
def pb_type(cls) -> GeneratedProtocolMessageType:
"""
:rtype: GeneratedProtocolMessageType
"""
Expand All @@ -51,15 +45,34 @@ def tag(cls):


class Protobuf(_base_sdk_types.FlyteSdkValue, metaclass=ProtobufType):

PB_FIELD_KEY = "pb_type"
TAG_PREFIX = "{}=".format(PB_FIELD_KEY)

def __init__(self, pb_object):
def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]):
"""
:param T pb_object:
:param Union[T, FlyteIdlEntity] pb_object:
"""
data = pb_object.SerializeToString()
v = pb_object
# This section converts an existing proto object (or a subclass of) to the right type expected by this instance
# of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType). This makes it
# a bit tricky to figure out the right version of the underlying raw proto class to use to populate the final
# struct.
# If the provided object has to_flyte_idl(), call it to produce a raw proto.
if isinstance(pb_object, FlyteIdlEntity):
v = pb_object.to_flyte_idl()

# A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt to
# convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class
# is initialized with one.
expected_type = type(self).pb_type
if expected_type != type(v) and expected_type != type(pb_object):
if isinstance(type(self).pb_type, FlyteType):
v = expected_type.from_flyte_idl(v).to_flyte_idl()
else:
raise _user_exceptions.FlyteTypeException(
received_type=type(pb_object), expected_type=expected_type, received_value=pb_object
)
data = v.SerializeToString()
super(Protobuf, self).__init__(
scalar=_literals.Scalar(
binary=_literals.Binary(value=bytes(data) if _six.PY2 else data, tag=type(self).tag)
Expand Down Expand Up @@ -97,7 +110,7 @@ def from_python_std(cls, t_value):
"""
if t_value is None:
return _base_sdk_types.Void()
elif isinstance(t_value, cls.pb_type):
elif isinstance(t_value, cls.pb_type) or isinstance(t_value, FlyteIdlEntity):
return cls(t_value)
else:
raise _user_exceptions.FlyteTypeException(type(t_value), cls.pb_type, received_value=t_value)
Expand Down Expand Up @@ -148,3 +161,159 @@ def short_string(self):
:rtype: Text
"""
return "{}".format(self.to_python_std())


def create_protobuf(pb_type: Type[GeneratedProtocolMessageType]) -> Type[Protobuf]:
"""
:param Type[GeneratedProtocolMessageType] pb_type:
:rtype: Type[Protobuf]
"""
if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType):
raise _user_exceptions.FlyteTypeException(
expected_type=_proto_reflection.GeneratedProtocolMessageType,
received_type=type(pb_type),
received_value=pb_type,
)

class _Protobuf(Protobuf):
_pb_type = pb_type

return _Protobuf


class GenericProtobuf(_base_sdk_types.FlyteSdkValue, metaclass=ProtobufType):
PB_FIELD_KEY = "pb_type"
TAG_PREFIX = "{}=".format(PB_FIELD_KEY)

def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]):
"""
:param Union[T, FlyteIdlEntity] pb_object:
"""
struct = Struct()
v = pb_object

# This section converts an existing proto object (or a subclass of) to the right type expected by this instance
# of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType). This makes it
# a bit tricky to figure out the right version of the underlying raw proto class to use to populate the final
# struct.
# If the provided object has to_flyte_idl(), call it to produce a raw proto.
if isinstance(pb_object, FlyteIdlEntity):
v = pb_object.to_flyte_idl()

# A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt to
# convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class
# is initialized with one.
expected_type = type(self).pb_type
if expected_type != type(v) and expected_type != type(pb_object):
if isinstance(type(self).pb_type, FlyteType):
v = expected_type.from_flyte_idl(v).to_flyte_idl()
else:
raise _user_exceptions.FlyteTypeException(
received_type=type(pb_object), expected_type=expected_type, received_value=pb_object
)

struct.update(_MessageToDict(v))
super().__init__(scalar=_literals.Scalar(generic=struct,))

@classmethod
def is_castable_from(cls, other):
"""
:param flytekit.common.types.base_literal_types.FlyteSdkType other:
:rtype: bool
"""
return isinstance(other, ProtobufType) and other.pb_type is cls.pb_type

@classmethod
def from_python_std(cls, t_value: Union[GeneratedProtocolMessageType, FlyteIdlEntity]):
"""
:param Union[T, FlyteIdlEntity] t_value: It is up to each individual object as to whether or not this value can be cast.
:rtype: _base_sdk_types.FlyteSdkValue
:raises: flytekit.common.exceptions.user.FlyteTypeException
"""
if t_value is None:
return _base_sdk_types.Void()
elif isinstance(t_value, cls.pb_type) or isinstance(t_value, FlyteIdlEntity):
return cls(t_value)
else:
raise _user_exceptions.FlyteTypeException(type(t_value), cls.pb_type, received_value=t_value)

@classmethod
def to_flyte_literal_type(cls) -> LiteralType:
"""
:rtype: flytekit.models.types.LiteralType
"""
return _idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT, metadata={cls.PB_FIELD_KEY: cls.descriptor},)

@classmethod
def promote_from_model(cls, literal_model):
"""
Creates an object of this type from the model primitive defining it.
:param flytekit.models.literals.Literal literal_model:
:rtype: Protobuf
"""
pb_obj = cls.pb_type()
try:
dictionary = _MessageToDict(literal_model.scalar.generic)
pb_obj = _ParseDict(dictionary, pb_obj)
except Error as err:
raise _user_exceptions.FlyteTypeException(
received_type="generic",
expected_type=cls.pb_type,
received_value=_base64.b64encode(literal_model.scalar.generic),
additional_msg=f"Can not deserialize. Error: {err.__str__()}",
)

return cls(pb_obj)

@classmethod
def short_class_string(cls) -> str:
"""
:rtype: Text
"""
return "Types.GenericProto({})".format(cls.descriptor)

def to_python_std(self):
"""
:returns: The protobuf object as defined by the user.
:rtype: T
"""
pb_obj = type(self).pb_type()
try:
dictionary = _MessageToDict(self.scalar.generic)
pb_obj = _ParseDict(dictionary, pb_obj)
except Error as err:
raise _user_exceptions.FlyteTypeException(
received_type="generic",
expected_type=type(self).pb_type,
received_value=_base64.b64encode(self.scalar.generic),
additional_msg=f"Can not deserialize. Error: {err.__str__()}",
)
return pb_obj

def short_string(self) -> str:
"""
:rtype: Text
"""
return "{}".format(self.to_python_std())


def create_generic(pb_type: Type[GeneratedProtocolMessageType]) -> Type[GenericProtobuf]:
"""
Creates a generic protobuf type that represents protobuf type ProtobufT and that will get serialized into a struct.
:param Type[GeneratedProtocolMessageType] pb_type:
:rtype: Type[GenericProtobuf]
"""
if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType) and not issubclass(
pb_type, FlyteIdlEntity
):
raise _user_exceptions.FlyteTypeException(
expected_type=_proto_reflection.GeneratedProtocolMessageType,
received_type=type(pb_type),
received_value=pb_type,
)

class _Protobuf(GenericProtobuf):
_pb_type = pb_type

return _Protobuf
Loading

0 comments on commit ebcf501

Please sign in to comment.