Support converting raw protos through Types.*Proto classes
EngHabu committed Oct 15, 2020
1 parent c48d496 commit 5c54947
Showing 5 changed files with 81 additions and 50 deletions.
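In short: an SDK type created via Types.GenericProto can now be declared against a flytekit model class (a FlyteIdlEntity subclass) and can now be handed a raw protobuf message, which gets coerced to the expected type before serialization. A minimal usage sketch, not part of the commit; it assumes the flytekit 0.x SDK and the flyteidl SageMaker protos, and the min_value/max_value field names on the raw message are taken from flyteidl's parameter_ranges proto as I understand it:

from flyteidl.plugins.sagemaker import parameter_ranges_pb2
from flytekit.models.sagemaker import parameter_ranges as _params
from flytekit.sdk import types as _sdk_types

# Declare the SDK type against the flytekit model class (previously only raw pb2 classes were accepted here).
ParameterRange = _sdk_types.Types.GenericProto(_params.ParameterRangeOneOf)

# Instantiate it from a raw proto; the new coercion path round-trips the message through
# ParameterRangeOneOf.from_flyte_idl(...).to_flyte_idl() before serializing it into a struct.
raw = parameter_ranges_pb2.IntegerParameterRange(min_value=2, max_value=8)
wrapped = ParameterRange(raw)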
8 changes: 3 additions & 5 deletions flytekit/common/tasks/sagemaker/types.py
@@ -1,8 +1,6 @@
-from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job
-from flyteidl.plugins.sagemaker import parameter_ranges_pb2 as _pb2_parameter_ranges
-
+from flytekit.models.sagemaker import parameter_ranges as _parameter_range_models, hpo_job as _hpo_models
 from flytekit.sdk import types as _sdk_types
 
-HyperparameterTuningJobConfig = _sdk_types.Types.GenericProto(_pb2_hpo_job.HyperparameterTuningJobConfig)
+HyperparameterTuningJobConfig = _sdk_types.Types.GenericProto(_hpo_models.HyperparameterTuningJobConfig)
 
-ParameterRange = _sdk_types.Types.GenericProto(_pb2_parameter_ranges.ParameterRangeOneOf)
+ParameterRange = _sdk_types.Types.GenericProto(_parameter_range_models.ParameterRangeOneOf)
41 changes: 38 additions & 3 deletions flytekit/common/types/proto.py
@@ -13,7 +13,7 @@
 from flytekit.common.types import base_sdk_types as _base_sdk_types
 from flytekit.models import literals as _literals
 from flytekit.models import types as _idl_types
-from flytekit.models.common import FlyteIdlEntity
+from flytekit.models.common import FlyteIdlEntity, FlyteType
 from flytekit.models.types import LiteralType

ProtobufT = Type[_proto_reflection.GeneratedProtocolMessageType]
@@ -53,8 +53,24 @@ def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity
         :param Union[T, FlyteIdlEntity] pb_object:
         """
         v = pb_object
+        # This section converts an existing proto object (or a subclass of it) to the right type expected by this
+        # instance of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType).
+        # This makes it a bit tricky to figure out the right version of the underlying raw proto class to use to
+        # populate the final struct.
+        # If the provided object has to_flyte_idl(), call it to produce a raw proto.
         if isinstance(pb_object, FlyteIdlEntity):
             v = pb_object.to_flyte_idl()
 
+        # A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt
+        # to convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case
+        # this class is initialized with one.
+        expected_type = type(self).pb_type
+        if expected_type != type(v) and expected_type != type(pb_object):
+            if isinstance(type(self).pb_type, FlyteType):
+                v = expected_type.from_flyte_idl(v).to_flyte_idl()
+            else:
+                raise _user_exceptions.FlyteTypeException(received_type=type(pb_object), expected_type=expected_type,
+                                                          received_value=pb_object)
         data = v.SerializeToString()
         super(Protobuf, self).__init__(
             scalar=_literals.Scalar(
@@ -174,10 +190,28 @@ def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity
         """
         struct = Struct()
         v = pb_object
+
+        # This section converts an existing proto object (or a subclass of it) to the right type expected by this
+        # instance of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType).
+        # This makes it a bit tricky to figure out the right version of the underlying raw proto class to use to
+        # populate the final struct.
+        # If the provided object has to_flyte_idl(), call it to produce a raw proto.
         if isinstance(pb_object, FlyteIdlEntity):
             v = pb_object.to_flyte_idl()
 
+        # A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt
+        # to convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case
+        # this class is initialized with one.
+        expected_type = type(self).pb_type
+        if expected_type != type(v) and expected_type != type(pb_object):
+            if isinstance(type(self).pb_type, FlyteType):
+                v = expected_type.from_flyte_idl(v).to_flyte_idl()
+            else:
+                raise _user_exceptions.FlyteTypeException(received_type=type(pb_object), expected_type=expected_type,
+                                                          received_value=pb_object)
+
         struct.update(_MessageToDict(v))
-        super().__init__(scalar=_literals.Scalar(generic=struct,))
+        super().__init__(scalar=_literals.Scalar(generic=struct, ))
 
     @classmethod
     def is_castable_from(cls, other):
@@ -268,7 +302,8 @@ def create_generic(pb_type: Type[GeneratedProtocolMessageType]) -> Type[GenericP
     :param Type[GeneratedProtocolMessageType] pb_type:
     :rtype: Type[GenericProtobuf]
     """
-    if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType):
+    if not isinstance(pb_type, _proto_reflection.GeneratedProtocolMessageType) and not issubclass(pb_type,
+                                                                                                   FlyteIdlEntity):
        raise _user_exceptions.FlyteTypeException(
            expected_type=_proto_reflection.GeneratedProtocolMessageType,
            received_type=type(pb_type),
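To make the coercion added in both __init__ methods above easier to follow, here is a distilled, standalone sketch. It is not part of the commit; the helper name _normalize_to_expected_proto is hypothetical, and the imports mirror the ones proto.py already uses:

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.models.common import FlyteIdlEntity, FlyteType


def _normalize_to_expected_proto(pb_object, expected_type):
    """Return a raw proto of expected_type, converting a FlyteIdlEntity model or a compatible raw proto."""
    v = pb_object
    if isinstance(pb_object, FlyteIdlEntity):
        # Model wrapper -> raw proto.
        v = pb_object.to_flyte_idl()
    if expected_type != type(v) and expected_type != type(pb_object):
        if isinstance(expected_type, FlyteType):
            # expected_type is a flytekit model class, so let it parse the raw proto and emit
            # the exact message type this literal expects.
            v = expected_type.from_flyte_idl(v).to_flyte_idl()
        else:
            raise _user_exceptions.FlyteTypeException(
                received_type=type(pb_object), expected_type=expected_type, received_value=pb_object
            )
    return v

The FlyteType branch is what allows, for example, a raw IntegerParameterRange to be turned into the ParameterRangeOneOf message that the literal schema expects.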
24 changes: 17 additions & 7 deletions flytekit/models/sagemaker/parameter_ranges.py
@@ -267,13 +267,23 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.ParameterRangeOneOf:
         )
 
     @classmethod
-    def from_flyte_idl(cls, pb_object: _idl_parameter_ranges.ParameterRangeOneOf):
+    def from_flyte_idl(cls, pb_object: Union[_idl_parameter_ranges.ParameterRangeOneOf,
+                                             _idl_parameter_ranges.IntegerParameterRange,
+                                             _idl_parameter_ranges.ContinuousParameterRange,
+                                             _idl_parameter_ranges.CategoricalParameterRange]):
         param = None
-        if pb_object.HasField("continuous_parameter_range"):
-            param = ContinuousParameterRange.from_flyte_idl(pb_object.continuous_parameter_range)
-        elif pb_object.HasField("integer_parameter_range"):
-            param = IntegerParameterRange.from_flyte_idl(pb_object.integer_parameter_range)
-        elif pb_object.HasField("categorical_parameter_range"):
-            param = CategoricalParameterRange.from_flyte_idl(pb_object.categorical_parameter_range)
+        if isinstance(pb_object, _idl_parameter_ranges.ParameterRangeOneOf):
+            if pb_object.HasField("continuous_parameter_range"):
+                param = ContinuousParameterRange.from_flyte_idl(pb_object.continuous_parameter_range)
+            elif pb_object.HasField("integer_parameter_range"):
+                param = IntegerParameterRange.from_flyte_idl(pb_object.integer_parameter_range)
+            elif pb_object.HasField("categorical_parameter_range"):
+                param = CategoricalParameterRange.from_flyte_idl(pb_object.categorical_parameter_range)
+        elif isinstance(pb_object, _idl_parameter_ranges.IntegerParameterRange):
+            param = IntegerParameterRange.from_flyte_idl(pb_object)
+        elif isinstance(pb_object, _idl_parameter_ranges.ContinuousParameterRange):
+            param = ContinuousParameterRange.from_flyte_idl(pb_object)
+        elif isinstance(pb_object, _idl_parameter_ranges.CategoricalParameterRange):
+            param = CategoricalParameterRange.from_flyte_idl(pb_object)
 
         return cls(param=param)
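A hedged usage sketch of the widened signature, not part of the commit. It assumes flyteidl's parameter_ranges_pb2 module, and the min_value/max_value field names follow the proto definition as I recall it:

from flyteidl.plugins.sagemaker import parameter_ranges_pb2 as _idl_parameter_ranges
from flytekit.models.sagemaker.parameter_ranges import ParameterRangeOneOf

# Previously only the wrapping oneof message was accepted.
oneof = _idl_parameter_ranges.ParameterRangeOneOf(
    integer_parameter_range=_idl_parameter_ranges.IntegerParameterRange(min_value=2, max_value=8)
)
from_oneof = ParameterRangeOneOf.from_flyte_idl(oneof)

# Now a bare member message is accepted as well and gets wrapped automatically.
from_member = ParameterRangeOneOf.from_flyte_idl(
    _idl_parameter_ranges.IntegerParameterRange(min_value=2, max_value=8)
)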
45 changes: 20 additions & 25 deletions tests/flytekit/common/workflows/notebook.py
@@ -1,25 +1,20 @@
-from flytekit.contrib.notebook.tasks import python_notebook, spark_notebook
-from flytekit.sdk.tasks import inputs, outputs
-from flytekit.sdk.types import Types
-from flytekit.sdk.workflow import Input, workflow_class
-
-interactive_python = python_notebook(
-    notebook_path="../../../../notebook-task-examples/python-notebook.ipynb",
-    inputs=inputs(pi=Types.Float),
-    outputs=outputs(out=Types.Float),
-    cpu_request="1",
-    memory_request="1G",
-)
-
-interactive_spark = spark_notebook(
-    notebook_path="../../../../notebook-task-examples/spark-notebook-pi.ipynb",
-    inputs=inputs(partitions=Types.Integer),
-    outputs=outputs(pi=Types.Float),
-)
-
-
-@workflow_class
-class FlyteNotebookSparkWorkflow(object):
-    partitions = Input(Types.Integer, default=10)
-    out1 = interactive_spark(partitions=partitions)
-    out2 = interactive_python(pi=out1.outputs.pi)
+# interactive_python = python_notebook(
+#     notebook_path="../../../../notebook-task-examples/python-notebook.ipynb",
+#     inputs=inputs(pi=Types.Float),
+#     outputs=outputs(out=Types.Float),
+#     cpu_request="1",
+#     memory_request="1G",
+# )
+#
+# interactive_spark = spark_notebook(
+#     notebook_path="../../../../notebook-task-examples/spark-notebook-pi.ipynb",
+#     inputs=inputs(partitions=Types.Integer),
+#     outputs=outputs(pi=Types.Float),
+# )
+#
+#
+# @workflow_class
+# class FlyteNotebookSparkWorkflow(object):
+#     partitions = Input(Types.Integer, default=10)
+#     out1 = interactive_spark(partitions=partitions)
+#     out2 = interactive_python(pi=out1.outputs.pi)
13 changes: 3 additions & 10 deletions tests/flytekit/common/workflows/sagemaker.py
@@ -15,7 +15,6 @@
     ContinuousParameterRange,
     HyperparameterScalingType,
     IntegerParameterRange,
-    ParameterRangeOneOf,
 )
 from flytekit.models.sagemaker.training_job import (
     AlgorithmName,
@@ -97,15 +96,9 @@ class SageMakerHPO(object):
     validation=validation_dataset,
     static_hyperparameters=static_hyperparameters,
     hyperparameter_tuning_job_config=hyperparameter_tuning_job_config,
-    num_round=ParameterRangeOneOf(
-        IntegerParameterRange(min_value=2, max_value=8, scaling_type=HyperparameterScalingType.LINEAR)
-    ),
-    max_depth=ParameterRangeOneOf(
-        IntegerParameterRange(min_value=5, max_value=7, scaling_type=HyperparameterScalingType.LINEAR)
-    ),
-    gamma=ParameterRangeOneOf(
-        ContinuousParameterRange(min_value=0.0, max_value=0.3, scaling_type=HyperparameterScalingType.LINEAR)
-    ),
+    num_round=IntegerParameterRange(min_value=2, max_value=8, scaling_type=HyperparameterScalingType.LINEAR),
+    max_depth=IntegerParameterRange(min_value=5, max_value=7, scaling_type=HyperparameterScalingType.LINEAR),
+    gamma=ContinuousParameterRange(min_value=0.0, max_value=0.3, scaling_type=HyperparameterScalingType.LINEAR),
 )
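Although the ParameterRangeOneOf import is dropped from this test, the wrapper form removed above should presumably still work: the GenericProto coercion only kicks in when the object is not already of the expected model type. A hedged sketch of the equivalence, assuming these classes live in flytekit.models.sagemaker.parameter_ranges as the changes above suggest:

from flytekit.models.sagemaker.parameter_ranges import (
    HyperparameterScalingType,
    IntegerParameterRange,
    ParameterRangeOneOf,
)

num_round = IntegerParameterRange(min_value=2, max_value=8, scaling_type=HyperparameterScalingType.LINEAR)

# Both of these are expected to serialize to the same ParameterRangeOneOf proto when passed to the task:
bare = num_round
wrapped = ParameterRangeOneOf(param=num_round)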


