-
Notifications
You must be signed in to change notification settings - Fork 304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TypeTransformer for Keras #1242
TypeTransformer for Keras #1242
Conversation
9bb6a88
to
1a99a2f
Compare
Signed-off-by: Ryan Nazareth <[email protected]>
Signed-off-by: Ryan Nazareth <[email protected]>
Signed-off-by: Ryan Nazareth <[email protected]>
Signed-off-by: Ryan Nazareth <[email protected]>
1a99a2f
to
3216e8b
Compare
dev-requirements.in
Outdated
@@ -13,3 +13,4 @@ google-cloud-bigquery | |||
google-cloud-bigquery-storage | |||
IPython | |||
torch | |||
tensorflow<=2.8.1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see my comment in #1240 (comment) for reasons for pinning tf versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ignore this - ive pinned grpcio-status<1.49.0
instead based on suggestion from @pingsutw in another PR, which fixed it !
Signed-off-by: Ryan Nazareth <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is an error in test_to_python_value_and_literal
def test_to_python_value_and_literal(transformer, python_type, format, python_val):
ctx = context_manager.FlyteContext.current_context()
tf = transformer
lt = tf.get_literal_type(python_type)
lv = tf.to_literal(ctx, python_val, type(python_val), lt) # type: ignore
assert lv.scalar.blob.metadata == BlobMetadata(
type=BlobType(
format=format,
dimensionality=BlobType.BlobDimensionality.SINGLE,
)
)
assert lv.scalar.blob.uri is not None
output = tf.to_python_value(ctx, lv, python_type)
if isinstance(python_val, keras.Sequential):
for p1, p2 in zip(output.weights, python_val.weights):
np.testing.assert_array_equal(p1.numpy(), p2.numpy())
assert True
else:
> assert isinstance(output, dict)
E assert False
E + where False = isinstance(<keras.engine.functional.Functional object at 0x000001AE37320108>, dict)
Signed-off-by: Ryan Nazareth <[email protected]>
Ahh yes, forgot to account for keras.Model in the check, after the refactor ....... just pushed fix now |
Codecov Report
@@ Coverage Diff @@
## master #1242 +/- ##
==========================================
+ Coverage 68.68% 68.71% +0.02%
==========================================
Files 288 292 +4
Lines 26333 26482 +149
Branches 2486 2494 +8
==========================================
+ Hits 18087 18196 +109
- Misses 7768 7805 +37
- Partials 478 481 +3
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
Signed-off-by: Ryan Nazareth [email protected]
TL;DR
Adds a typetransformer to support
keras.Model
andkeras.Sequential
as native typesType
Are all requirements met?
Complete description
__init__.py
file handles case where the user doesn't have keras installed.Tracking Issue
flyteorg/flyte#2759
Follow-up issue
N/A