
Light-weight benchmarking script #664

Merged

Conversation

NusretOzates
Contributor

Implemented sentiment analysis benchmark for Classifiers using IMDB review dataset for #634

@google-cla

google-cla bot commented Jan 15, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@chenmoneygithub chenmoneygithub self-requested a review January 15, 2023 23:40
@chenmoneygithub
Contributor

@NusretOzates Thank you for the PR! Will take a closer look tomorrow.

@jbischof jbischof self-requested a review January 17, 2023 15:14
Member

@mattdangerw mattdangerw left a comment

LGTM! Just a few comments.

keras_nlp/benchmarks/README.md (resolved)
```python
def create_model():
    for name, symbol in keras_nlp.models.__dict__.items():
        if inspect.isclass(symbol) and issubclass(symbol, keras.Model):
            if FLAGS.model and name != f"{FLAGS.model.capitalize()}Classifier":
```
Member

Rather than all this, I would just take in the symbol name directly e.g. --model=BertClassifier. This will be a little more obvious in usage.
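The suggestion can be sketched as a direct symbol lookup by exact class name. This is an illustrative sketch, not the PR's code: a plain namespace stands in for `keras_nlp.models` so the snippet is self-contained, and `resolve_model` is a hypothetical helper name.

```python
import types

# Stand-in for keras_nlp.models; the actual classes are irrelevant to the lookup.
models = types.SimpleNamespace(BertClassifier=object, RobertaClassifier=object)

def resolve_model(symbol_name):
    """Resolve e.g. --model=BertClassifier to the class of that exact name."""
    symbol = getattr(models, symbol_name, None)
    if symbol is None:
        raise ValueError(f"Unknown model symbol: {symbol_name}")
    return symbol

cls = resolve_model("BertClassifier")
```

Taking the symbol name verbatim avoids the fragile `capitalize()` mapping (which would break for names like `RobertaClassifier` vs. `roberta`).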

Contributor

Yea, this is my bad honestly, I put "bert" in the problem description.

```python
    .prefetch(tf.data.AUTOTUNE)
)
val_dataset = (
    test_dataset.take(10000)
```
Member

Maybe rather than a hardcoded split like this, you could run tfds.load with_info. Use that to get the size of the test set, and use a fractional split here. E.g. int(test_dataset_cardinality / 2)

Contributor

I think you can do:

```python
test_ds_size = test_dataset.cardinality()
val_dataset = test_dataset.take(test_ds_size // 2)
...
```


```python
# End time
end_time = time.time()
print(f"Total time: {end_time - start_time}")
```
Member

Maybe "Wall time", so it's clear we are just measuring elapsed time?
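The rename is purely about labeling: `time.time()` measures elapsed wall-clock time, not CPU time. A minimal sketch of the pattern, with a cheap computation standing in for `model.fit(...)`:

```python
import time

start_time = time.time()
_ = sum(range(1_000_000))  # stand-in for model.fit(...)
end_time = time.time()
elapsed = end_time - start_time
print(f"Wall time: {elapsed:.3f}s")
```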

keras_nlp/benchmarks/README.md (resolved)
Contributor

@chenmoneygithub chenmoneygithub left a comment

Looks great overall! A few comments.

from the root of the repository:

```sh
python3 .keras_nlp/benchmarks/sentiment_analysis.py \
```
Contributor

remove the extra ".": .keras_nlp/ => keras_nlp/

```
@@ -0,0 +1,136 @@
# Copyright 2022 The KerasNLP Authors
```
Contributor

We can use 2023 now, time flies!


```python
FLAGS = flags.FLAGS
flags.DEFINE_string(
    "model", None, "The name of the classifier such as BertClassifier."
```
Contributor

Add a trailing comma so that it's formatted across multiple lines.
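The trailing comma matters because auto-formatters such as black keep a call exploded across lines when its last argument ends with a comma. Presumably the definition would then be formatted like this (non-runnable fragment; `flags` is absl's module):

```python
flags.DEFINE_string(
    "model",
    None,
    "The name of the classifier such as BertClassifier.",
)
```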


@NusretOzates
Contributor Author

Will do all of that tomorrow! Thanks for the review @chenmoneygithub @mattdangerw

…t dataset size set automatically using dataset info
@NusretOzates
Contributor Author

@chenmoneygithub @mattdangerw, I made the necessary changes. I would like to add a TensorFlow profiler too but tensorboard and tensorboard_profiler_plugin are not in the requirements and I didn't want to touch that part 😄

Member

@mattdangerw mattdangerw left a comment

LGTM! Just a few last comments, but after those are in I think this is all set to merge.

Agreed the tensorboard profiler bit can be a follow up.

Thanks so much for this! This is high quality code.



```python
def check_flags():
    if not FLAGS.model:
```
Member

I just remembered there is actually a way to do this with absl directly. flags.mark_flag_as_required("flag").

https://github.com/keras-team/keras-nlp/blob/master/examples/bert_pretraining/bert_pretrain.py#L454
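`flags.mark_flag_as_required("model")` makes absl fail at flag-parsing time when `--model` is missing, replacing the manual `check_flags()`. Here is a self-contained stand-in for that behavior (absl itself is not imported; `parse_flags` is a hypothetical helper, not absl's API):

```python
def parse_flags(argv, required=("model",)):
    """Parse --key=value args; fail fast if a required flag is missing."""
    parsed = {}
    for arg in argv:
        if arg.startswith("--") and "=" in arg:
            key, value = arg[2:].split("=", 1)
            parsed[key] = value
    missing = [name for name in required if name not in parsed]
    if missing:
        raise SystemExit(f"Missing required flags: {missing}")
    return parsed

flags_dict = parse_flags(["--model=BertClassifier"])
```

Failing at parse time gives the user a usage error before any work starts, which is what the absl call provides for free.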

```python
import keras_nlp

# Use mixed precision for optimal performance
keras.mixed_precision.set_global_policy("mixed_float16")
```
Member

Actually given that this is a benchmarking script, better to leave this as a flag probably. Totally valid to benchmark a model under full precision or mixed.

Can we just make this a string flag?

```python
flags.DEFINE_string(
    "mixed_precision_policy",
    "mixed_float16",
    "The global mixed precision policy to use. E.g. 'mixed_float16' or 'float32'.",
)
```

```python
    .prefetch(tf.data.AUTOTUNE)
)

test_dataset_size = info.splits['test'].num_examples // 2
```
Member

Maybe drop a comment here.

```python
# We split the test data evenly into validation and test sets.
```

@NusretOzates
Contributor Author

@mattdangerw All done👌 I'm happy to contribute and I actually would like/plan to do more!

@mattdangerw
Member

@NusretOzates thanks! I see one small issue: we now need to make sure to only set the mixed precision policy after flags are parsed, or the script won't run. I will push a small fix.
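The ordering issue is the standard absl pitfall: module-level code runs at import time, before `app.run` has parsed the flags, so the policy must be read inside `main()`. A self-contained sketch of the fix, with a tiny dict standing in for absl's FLAGS and the keras call shown only as a comment:

```python
FLAGS = {}

def parse(argv):
    """Stand-in for absl's flag parsing inside app.run()."""
    FLAGS.clear()
    for arg in argv:
        key, value = arg[2:].split("=", 1)
        FLAGS[key] = value

def main(argv):
    parse(argv)
    # Safe: flags are parsed by the time we read them here.
    policy = FLAGS.get("mixed_precision_policy", "mixed_float16")
    # keras.mixed_precision.set_global_policy(policy)
    return policy

result = main(["--mixed_precision_policy=float32"])
```

Had the `set_global_policy` call stayed at module level, it would run before parsing and absl would raise an unparsed-flag error.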

@NusretOzates
Contributor Author

That makes sense, thanks for the fix 😄

Contributor

@chenmoneygithub chenmoneygithub left a comment

The code looks great! We want to thank you for the high-quality PR again! XD

@NusretOzates
Contributor Author

You are welcome! 😄

@chenmoneygithub chenmoneygithub merged commit fb1eeb9 into keras-team:master Jan 19, 2023
@NusretOzates NusretOzates deleted the sentinement_analysis_benchmark branch January 19, 2023 21:47