Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SageMaker custom types for ParameterRangeOneOf and HyperparameterConfig. TunableParams extraction #189

Merged
merged 42 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
a248037
Typing
EngHabu Sep 30, 2020
3883b03
Merge branch 'master' into sm-custom-type
EngHabu Sep 30, 2020
2198781
ParameterRangeOneOf model
EngHabu Oct 4, 2020
bfc9991
cleanup
EngHabu Oct 4, 2020
b83d810
lint
EngHabu Oct 4, 2020
9ebf5c9
Merge branch 'master' into sm-custom-type
EngHabu Oct 4, 2020
243d0e7
lint
EngHabu Oct 4, 2020
d32de2a
unittests
EngHabu Oct 4, 2020
ca2c8fb
unit
EngHabu Oct 4, 2020
9c9646d
unit
EngHabu Oct 4, 2020
105f58a
isort
EngHabu Oct 4, 2020
3cb7a0c
isort
EngHabu Oct 4, 2020
7851288
py 3.5
EngHabu Oct 4, 2020
6b7d89a
Merge branch 'master' into sm-custom-type
EngHabu Oct 4, 2020
29c4aa7
lint
EngHabu Oct 4, 2020
8830680
Merge branch 'master' into sm-custom-type
EngHabu Oct 4, 2020
454ec9d
lint
EngHabu Oct 4, 2020
e55ef83
re-add generic types to protos
EngHabu Oct 4, 2020
a5c44a5
lint
EngHabu Oct 4, 2020
9638801
Remove generic
EngHabu Oct 5, 2020
6b554cb
Merge branch 'master' into sm-custom-type
EngHabu Oct 5, 2020
defee18
remove T
EngHabu Oct 5, 2020
6876251
Merge branch 'sm-custom-type' of github.com:lyft/flytekit into sm-cus…
EngHabu Oct 5, 2020
93a1e29
Merge branch 'master' into sm-custom-type
EngHabu Oct 6, 2020
a9ee07c
Merge branch 'master' into sm-custom-type
EngHabu Oct 7, 2020
9af0fb4
reformat
EngHabu Oct 7, 2020
591e084
fix import
EngHabu Oct 7, 2020
d334211
PR Comments
EngHabu Oct 7, 2020
e09bd52
lint
EngHabu Oct 7, 2020
cd898f5
lint
EngHabu Oct 7, 2020
3eb21ee
PR Comments
EngHabu Oct 8, 2020
e876fb7
PR Comments
EngHabu Oct 8, 2020
b0cc0cb
lint
EngHabu Oct 8, 2020
e1fdaf6
unittest
EngHabu Oct 8, 2020
3b122c5
Merge master
EngHabu Oct 9, 2020
c026081
PR Comments
EngHabu Oct 10, 2020
23aa95c
lint
EngHabu Oct 10, 2020
c48d496
remove deprecated fields
EngHabu Oct 10, 2020
5c54947
Support converting raw protos through Types.*Proto classes
EngHabu Oct 15, 2020
800251f
lint
EngHabu Oct 15, 2020
c6b2667
revert notebook.py
EngHabu Oct 15, 2020
4807d5e
lint
EngHabu Oct 15, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions flytekit/common/tasks/sagemaker/hpo_job_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime as _datetime
import typing

from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job
from google.protobuf.json_format import MessageToDict

from flytekit import __version__
Expand All @@ -10,13 +9,13 @@
from flytekit.common.tasks import task as _sdk_task
from flytekit.common.tasks.sagemaker.built_in_training_job_task import SdkBuiltinAlgorithmTrainingJobTask
from flytekit.common.tasks.sagemaker.custom_training_job_task import CustomTrainingJobTask
from flytekit.common.tasks.sagemaker.types import HyperparameterTuningJobConfig, ParameterRange
from flytekit.models import interface as _interface_model
from flytekit.models import literals as _literal_models
from flytekit.models import task as _task_models
from flytekit.models import types as _types_models
from flytekit.models.core import types as _core_types
from flytekit.models.sagemaker import hpo_job as _hpo_job_model
from flytekit.sdk import types as _sdk_types


class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask):
Expand All @@ -28,17 +27,24 @@ def __init__(
retries: int = 0,
cacheable: bool = False,
cache_version: str = "",
tunable_parameters: typing.List[str] = None,
):
"""

:param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
:param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
hyperparameter tuning job
:param max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter
:param int max_parallel_training_jobs: The maximum number of training jobs that can be launched by this hyperparameter
tuning job in parallel
:param training_job: The reference to the training job definition
:param retries: Number of retries to attempt
:param cacheable: The flag to set if the user wants the output of the task execution to be cached
:param cache_version: String describing the caching version for task discovery purposes
:param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference to the training job definition
:param int retries: Number of retries to attempt
:param bool cacheable: The flag to set if the user wants the output of the task execution to be cached
:param str cache_version: String describing the caching version for task discovery purposes
:param typing.List[str] tunable_parameters: A list of parameters to tune. If you are tuning a built-in
algorithm, refer to the algorithm's documentation to understand the possible values for the tunable
parameters. E.g. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for the
list of hyperparameters for Image Classification built-in algorithm. If you are passing a custom
training job, the list of tunable parameters must be a strict subset of the list of inputs defined on
that job. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html
for the list of supported hyperparameter types.
"""
# Use the training job model as a measure of type checking
hpo_job = _hpo_job_model.HyperparameterTuningJob(
Expand All @@ -52,14 +58,25 @@ def __init__(
# TODO: Discuss whether this is a viable interface or contract
timeout = _datetime.timedelta(seconds=0)

inputs = {
"hyperparameter_tuning_job_config": _interface_model.Variable(
_sdk_types.Types.Proto(_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(), "",
),
}
inputs = {}
inputs.update(training_job.interface.inputs)
inputs.update(
{
"hyperparameter_tuning_job_config": _interface_model.Variable(
HyperparameterTuningJobConfig.to_flyte_literal_type(), "",
),
}
)

if tunable_parameters:
inputs.update(
{
param: _interface_model.Variable(ParameterRange.to_flyte_literal_type(), "")
for param in tunable_parameters
}
)

super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
super().__init__(
type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
metadata=_task_models.TaskMetadata(
runtime=_task_models.RuntimeMetadata(
Expand Down Expand Up @@ -87,3 +104,7 @@ def __init__(
),
custom=MessageToDict(hpo_job),
)

def __call__(self, *args, **kwargs):
# Pure pass-through to the parent implementation: all positional and keyword arguments are forwarded unchanged
# and the parent's result is returned as-is. The override exists only to clear up the docs and to avoid IDEs
# complaining about the signature.
return super().__call__(*args, **kwargs)
7 changes: 7 additions & 0 deletions flytekit/common/tasks/sagemaker/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from flytekit.models.sagemaker import hpo_job as _hpo_models
from flytekit.models.sagemaker import parameter_ranges as _parameter_range_models
from flytekit.sdk import types as _sdk_types

# SDK type wrapping the HyperparameterTuningJobConfig model. The hyperparameter tuning job task uses it as the type of
# its "hyperparameter_tuning_job_config" input.
HyperparameterTuningJobConfig = _sdk_types.Types.GenericProto(_hpo_models.HyperparameterTuningJobConfig)

# SDK type wrapping the ParameterRangeOneOf model. The hyperparameter tuning job task declares one input of this type
# for every name listed in its tunable_parameters argument.
ParameterRange = _sdk_types.Types.GenericProto(_parameter_range_models.ParameterRangeOneOf)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are wrapping a model instead, should we still call it Generic"Proto"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

models should be thought of as "glorified" protos (or pythonic-protos)... so same assumptions about them should apply...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense

4 changes: 3 additions & 1 deletion flytekit/common/types/base_sdk_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import abc as _abc

from flyteidl.core.literals_pb2 import Literal

from flytekit.common import sdk_bases as _sdk_bases
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.models import common as _common_models
Expand Down Expand Up @@ -54,7 +56,7 @@ def __hash__(cls):

class FlyteSdkValue(_literal_models.Literal, metaclass=FlyteSdkType):
@classmethod
def from_flyte_idl(cls, pb2_object):
def from_flyte_idl(cls, pb2_object: Literal):
"""
:param flyteidl.core.literals_pb2.Literal pb2_object:
:rtype: FlyteSdkValue
Expand Down
215 changes: 192 additions & 23 deletions flytekit/common/types/proto.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,29 @@
import base64 as _base64
from typing import Type, Union

import six as _six
from google.protobuf import reflection as _proto_reflection
from google.protobuf.json_format import Error
from google.protobuf.json_format import MessageToDict as _MessageToDict
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.reflection import GeneratedProtocolMessageType
from google.protobuf.struct_pb2 import Struct

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.types import base_sdk_types as _base_sdk_types
from flytekit.models import literals as _literals
from flytekit.models import types as _idl_types
from flytekit.models.common import FlyteIdlEntity, FlyteType
from flytekit.models.types import LiteralType


def create_protobuf(pb_type):
"""
:param T pb_type:
:rtype: ProtobufType
"""
if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType):
raise _user_exceptions.FlyteTypeException(
expected_type=_proto_reflection.GeneratedProtocolMessageType,
received_type=type(pb_type),
received_value=pb_type,
)

class _Protobuf(Protobuf):
_pb_type = pb_type

return _Protobuf
ProtobufT = Type[_proto_reflection.GeneratedProtocolMessageType]


class ProtobufType(_base_sdk_types.FlyteSdkType):
_pb_type = Struct

@property
def pb_type(cls):
def pb_type(cls) -> GeneratedProtocolMessageType:
"""
:rtype: GeneratedProtocolMessageType
"""
Expand All @@ -51,15 +45,34 @@ def tag(cls):


class Protobuf(_base_sdk_types.FlyteSdkValue, metaclass=ProtobufType):

PB_FIELD_KEY = "pb_type"
TAG_PREFIX = "{}=".format(PB_FIELD_KEY)

def __init__(self, pb_object):
def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]):
"""
:param T pb_object:
:param Union[T, FlyteIdlEntity] pb_object:
"""
data = pb_object.SerializeToString()
v = pb_object
# This section converts an existing proto object (or a subclass of) to the right type expected by this instance
# of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType). This makes it
# a bit tricky to figure out the right version of the underlying raw proto class to use to populate the final
# struct.
# If the provided object has to_flyte_idl(), call it to produce a raw proto.
if isinstance(pb_object, FlyteIdlEntity):
v = pb_object.to_flyte_idl()

# A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt to
# convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class
# is initialized with one.
expected_type = type(self).pb_type
if expected_type != type(v) and expected_type != type(pb_object):
if isinstance(type(self).pb_type, FlyteType):
v = expected_type.from_flyte_idl(v).to_flyte_idl()
else:
raise _user_exceptions.FlyteTypeException(
received_type=type(pb_object), expected_type=expected_type, received_value=pb_object
)
data = v.SerializeToString()
super(Protobuf, self).__init__(
scalar=_literals.Scalar(
binary=_literals.Binary(value=bytes(data) if _six.PY2 else data, tag=type(self).tag)
Expand Down Expand Up @@ -97,7 +110,7 @@ def from_python_std(cls, t_value):
"""
if t_value is None:
return _base_sdk_types.Void()
elif isinstance(t_value, cls.pb_type):
elif isinstance(t_value, cls.pb_type) or isinstance(t_value, FlyteIdlEntity):
return cls(t_value)
else:
raise _user_exceptions.FlyteTypeException(type(t_value), cls.pb_type, received_value=t_value)
Expand Down Expand Up @@ -148,3 +161,159 @@ def short_string(self):
:rtype: Text
"""
return "{}".format(self.to_python_std())


def create_protobuf(pb_type: Type[GeneratedProtocolMessageType]) -> Type[Protobuf]:
    """
    Builds a Protobuf subclass that is bound to the given generated protobuf message class.

    :param Type[GeneratedProtocolMessageType] pb_type: The generated protobuf message class to wrap.
    :rtype: Type[Protobuf]
    :raises: flytekit.common.exceptions.user.FlyteTypeException if pb_type is not a generated protobuf message class.
    """
    message_metaclass = _proto_reflection.GeneratedProtocolMessageType

    if isinstance(pb_type, message_metaclass):

        class _Protobuf(Protobuf):
            _pb_type = pb_type

        return _Protobuf

    # Anything that is not a generated message class is rejected.
    raise _user_exceptions.FlyteTypeException(
        expected_type=message_metaclass, received_type=type(pb_type), received_value=pb_type,
    )


class GenericProtobuf(_base_sdk_types.FlyteSdkValue, metaclass=ProtobufType):
    """
    SDK value that wraps a protobuf message (or a FlyteIdlEntity model) and stores it in the literal as a generic
    ``Struct`` scalar, i.e. the message is converted to a dictionary and loaded into a Struct.
    """

    PB_FIELD_KEY = "pb_type"
    TAG_PREFIX = "{}=".format(PB_FIELD_KEY)

    def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]):
        """
        :param Union[T, FlyteIdlEntity] pb_object: The raw protobuf message, or a model exposing to_flyte_idl(), to
            store in this value.
        :raises: flytekit.common.exceptions.user.FlyteTypeException if pb_object is not of (and cannot be converted
            to) the protobuf type this class was created for.
        """
        struct = Struct()
        v = pb_object

        # This section converts an existing proto object (or a subclass of) to the right type expected by this instance
        # of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType). This makes it
        # a bit tricky to figure out the right version of the underlying raw proto class to use to populate the final
        # struct.
        # If the provided object has to_flyte_idl(), call it to produce a raw proto.
        if isinstance(pb_object, FlyteIdlEntity):
            v = pb_object.to_flyte_idl()

        # A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt to
        # convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class
        # is initialized with one.
        expected_type = type(self).pb_type
        if expected_type != type(v) and expected_type != type(pb_object):
            if isinstance(expected_type, FlyteType):
                v = expected_type.from_flyte_idl(v).to_flyte_idl()
            else:
                raise _user_exceptions.FlyteTypeException(
                    received_type=type(pb_object), expected_type=expected_type, received_value=pb_object
                )

        struct.update(_MessageToDict(v))
        super().__init__(scalar=_literals.Scalar(generic=struct,))

    @classmethod
    def _generic_to_pb(cls, generic):
        """
        Parses the given Struct into a new instance of this class's pb_type.

        :param google.protobuf.struct_pb2.Struct generic: The struct holding the dictionary form of the message.
        :rtype: T
        :raises: flytekit.common.exceptions.user.FlyteTypeException if the struct cannot be parsed into pb_type.
        """
        pb_obj = cls.pb_type()
        try:
            dictionary = _MessageToDict(generic)
            return _ParseDict(dictionary, pb_obj)
        except Error as err:
            raise _user_exceptions.FlyteTypeException(
                received_type="generic",
                expected_type=cls.pb_type,
                # b64encode only accepts bytes-like objects, so the Struct has to be serialized first. Passing the
                # message object itself raises a TypeError that hides the real deserialization failure.
                received_value=_base64.b64encode(generic.SerializeToString()),
                additional_msg=f"Can not deserialize. Error: {err}",
            ) from err

    @classmethod
    def is_castable_from(cls, other):
        """
        :param flytekit.common.types.base_literal_types.FlyteSdkType other:
        :rtype: bool
        """
        # Only protobuf-backed SDK types that wrap the exact same pb_type are considered castable.
        return isinstance(other, ProtobufType) and other.pb_type is cls.pb_type

    @classmethod
    def from_python_std(cls, t_value: Union[GeneratedProtocolMessageType, FlyteIdlEntity]):
        """
        :param Union[T, FlyteIdlEntity] t_value: It is up to each individual object as to whether or not this value
            can be cast.
        :rtype: _base_sdk_types.FlyteSdkValue
        :raises: flytekit.common.exceptions.user.FlyteTypeException
        """
        if t_value is None:
            return _base_sdk_types.Void()
        elif isinstance(t_value, cls.pb_type) or isinstance(t_value, FlyteIdlEntity):
            # Any FlyteIdlEntity is let through here; the constructor validates/converts it to pb_type.
            return cls(t_value)
        else:
            raise _user_exceptions.FlyteTypeException(type(t_value), cls.pb_type, received_value=t_value)

    @classmethod
    def to_flyte_literal_type(cls) -> LiteralType:
        """
        :rtype: flytekit.models.types.LiteralType
        """
        # The literal type is a plain STRUCT; the wrapped protobuf type is recorded in the metadata.
        return _idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT, metadata={cls.PB_FIELD_KEY: cls.descriptor},)

    @classmethod
    def promote_from_model(cls, literal_model):
        """
        Creates an object of this type from the model primitive defining it.
        :param flytekit.models.literals.Literal literal_model:
        :rtype: GenericProtobuf
        :raises: flytekit.common.exceptions.user.FlyteTypeException if the generic scalar cannot be parsed.
        """
        return cls(cls._generic_to_pb(literal_model.scalar.generic))

    @classmethod
    def short_class_string(cls) -> str:
        """
        :rtype: Text
        """
        return "Types.GenericProto({})".format(cls.descriptor)

    def to_python_std(self):
        """
        :returns: The protobuf object as defined by the user.
        :rtype: T
        :raises: flytekit.common.exceptions.user.FlyteTypeException if the generic scalar cannot be parsed.
        """
        return type(self)._generic_to_pb(self.scalar.generic)

    def short_string(self) -> str:
        """
        :rtype: Text
        """
        return "{}".format(self.to_python_std())


def create_generic(pb_type: Type[GeneratedProtocolMessageType]) -> Type[GenericProtobuf]:
    """
    Creates a generic protobuf type that represents protobuf type ProtobufT and that will get serialized into a struct.

    :param Type[GeneratedProtocolMessageType] pb_type: Either a generated protobuf message class or a subclass of
        FlyteIdlEntity.
    :rtype: Type[GenericProtobuf]
    :raises: flytekit.common.exceptions.user.FlyteTypeException if pb_type is neither of the accepted kinds.
    """
    is_supported = isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType) or (
        # issubclass() raises a bare TypeError for non-class arguments (e.g. a message instance), so make sure we
        # are looking at a class first and report the problem as a FlyteTypeException instead.
        isinstance(pb_type, type)
        and issubclass(pb_type, FlyteIdlEntity)
    )
    if not is_supported:
        raise _user_exceptions.FlyteTypeException(
            expected_type=_proto_reflection.GeneratedProtocolMessageType,
            received_type=type(pb_type),
            received_value=pb_type,
        )

    class _Protobuf(GenericProtobuf):
        _pb_type = pb_type

    return _Protobuf
Loading