
Albert fine tuning does not always converge #831

Closed
mattdangerw opened this issue Mar 11, 2023 · 9 comments · Fixed by #1786
Labels
type:Bug Something isn't working

Comments

@mattdangerw
Member

It appears that ALBERT classification sometimes fails catastrophically with our compilation defaults. On a standard Colab GPU, the following script will sometimes give good results, and sometimes hover at 50% accuracy.

import keras_nlp
import tensorflow_datasets as tfds

imdb_train, imdb_test = tfds.load(
    "imdb_reviews",
    split=["train", "test"],
    as_supervised=True,
    batch_size=16,
)

# Load an ALBERT model.
classifier = keras_nlp.models.AlbertClassifier.from_preset("albert_base_en_uncased")
# Fine-tune on IMDb movie reviews.
classifier.fit(imdb_train, validation_data=imdb_test)

This may be something we face with all our models to some extent, but the issue seems exacerbated for ALBERT in particular. We should experiment with different optimizers (e.g. AdamW) and lower learning rates. We will need to run training multiple times with a fresh, randomly initialized model to see whether we are making progress on this problem. A sketch of what I mean is below.
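Something like this untested sketch, for concreteness (the run_trials helper is hypothetical, and tf.keras.optimizers.AdamW assumes TF >= 2.11):

import keras_nlp
import tensorflow as tf

# Hypothetical trial loop: several fresh fine-tuning runs per setting, to
# measure how often training collapses to ~50% accuracy.
def run_trials(make_optimizer, num_trials=5):
    scores = []
    for _ in range(num_trials):
        # Recreate the classifier so each trial gets a freshly initialized
        # classification head on top of the pretrained backbone.
        classifier = keras_nlp.models.AlbertClassifier.from_preset(
            "albert_base_en_uncased"
        )
        classifier.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=make_optimizer(),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
        )
        history = classifier.fit(imdb_train, validation_data=imdb_test)
        scores.append(history.history["val_sparse_categorical_accuracy"][-1])
    return scores

# E.g., compare AdamW at a lower learning rate against the current default:
print(run_trials(lambda: tf.keras.optimizers.AdamW(learning_rate=5e-5)))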

@mattdangerw mattdangerw added the type:Bug Something isn't working label Mar 11, 2023
@mattdangerw
Member Author

If anyone is feeling motivated to check this out, feel free! But heads up, it may be a little tedious to debug with just a Colab GPU.

@abheesht17
Collaborator

It might be a better option to use a P100 on Kaggle. Also, we could reduce the sequence length to 128 to enable larger batch sizes? Something like the sketch below.
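An untested sketch of that setup, assuming the preprocessor-override pattern from the KerasNLP docs:

import keras_nlp
import tensorflow_datasets as tfds

# Re-batch the data at the larger batch size.
imdb_train, imdb_test = tfds.load(
    "imdb_reviews",
    split=["train", "test"],
    as_supervised=True,
    batch_size=32,
)

# Truncate inputs to 128 tokens via a custom preprocessor.
preprocessor = keras_nlp.models.AlbertPreprocessor.from_preset(
    "albert_base_en_uncased",
    sequence_length=128,
)
classifier = keras_nlp.models.AlbertClassifier.from_preset(
    "albert_base_en_uncased",
    preprocessor=preprocessor,
)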

@abheesht17
Copy link
Collaborator

Yep, takes 4 minutes/epoch on Kaggle with batch size 32 and sequence length 128!

@abheesht17
Collaborator

abheesht17 commented Mar 11, 2023

@mattdangerw, seems to give consistently good results with LR = 2e-5 and Adam

https://www.kaggle.com/code/penstrokes75/kerasnlp-sequence-cls-with-albert-on-imdb-dataset?scriptVersionId=121749538&cellId=5

Edit: Oh, right. Out of 5 runs, it fails in one run (gives an accuracy of 50%). Hmmm
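For reference, that configuration amounts to recompiling the classifier along these lines (a sketch; the loss/metric choices assume the classifier's two-class logit output):

import tensorflow as tf

classifier.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)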

@NiharJani2002

NiharJani2002 commented Mar 11, 2023

@mattdangerw
I have a possible solution for it. Is it correct?

It's not uncommon for machine learning models to show varying levels of performance on different runs, especially when dealing with small datasets or complex models. In the case of the script you provided, the issue may be related to the initialization of the model's parameters or the optimization algorithm used during training.

As you suggested, one way to tackle this problem is to experiment with different optimization algorithms and learning rates. AdamW, for example, is a variant of the Adam optimizer that decouples weight decay from the gradient update, which can improve generalization. You could also try using a different learning rate schedule, such as a cosine annealing schedule, to help the model converge to a good solution.

Another approach would be to use a different initialization strategy for the model's parameters. The ALBERT model uses factorized embedding parameterization and cross-layer parameter sharing, which may interact with initialization differently than traditional models do. You could try initializing the model's parameters with a different distribution, such as a truncated normal distribution, and see if it improves performance.

Finally, you mentioned running training multiple times with a freshly initialized model. This is a good idea to help ensure that any improvements in performance are not just due to chance. You may also want to consider using a technique such as k-fold cross-validation to get a better estimate of the model's performance.

Overall, the key to improving the performance of the ALBERT model on this task will be to experiment with different hyperparameters and initialization strategies and to carefully monitor the model's performance over multiple runs.

Code:

Importing libraries

import keras_nlp
import tensorflow as tf
import tensorflow_datasets as tfds

Load the IMDb movie reviews dataset.

imdb_train, imdb_test = tfds.load(
    "imdb_reviews",
    split=["train", "test"],
    as_supervised=True,
    batch_size=16,
)

Define the learning rate schedule and optimizer.

num_epochs = 3
# len() on a batched tf.data dataset gives the number of batches per epoch.
num_train_steps = len(imdb_train) * num_epochs
# Cosine decay over training. (Newer TF versions also support warmup via
# CosineDecay's warmup_target/warmup_steps arguments.)
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=2e-5,
    decay_steps=num_train_steps,
)
# AdamW requires TF >= 2.11 (earlier: tf.keras.optimizers.experimental.AdamW).
optimizer = tf.keras.optimizers.AdamW(
    learning_rate=lr_schedule,
    weight_decay=0.01,
    epsilon=1e-6,
)
# Skip weight decay for layer norm and bias variables.
optimizer.exclude_from_weight_decay(var_names=["layer_norm", "bias"])

Load and compile the model.

classifier = keras_nlp.models.AlbertClassifier.from_preset(
    "albert_base_en_uncased",
    num_classes=2,
)
classifier.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

Train the model.

history = classifier.fit(
    imdb_train,
    epochs=num_epochs,
    validation_data=imdb_test,
)

Evaluate the model.

test_loss, test_acc = classifier.evaluate(imdb_test)
print(f"Test loss: {test_loss}, Test accuracy: {test_acc}")

@abheesht17 Is it right?

@shivance
Collaborator

shivance commented Mar 16, 2023

Hey @mattdangerw and @chenmoneygithub, the issue pertains to AlbertMaskedLM models as well. I've been playing around with it for convergence as part of #833. Here is how ALBERT performs against four different LRs:

[image: ALBERT training curves for four different learning rates]

Training was on the IMDb dataset; the full training script can be found here: https://www.kaggle.com/code/shivanshuman/does-tensorflow-task-converge

@abheesht17
Collaborator

@shivance, did you try AdamW and some form of LR scheduler?

@mattdangerw
Member Author

@shivance thanks, that is helpful! Though I think the annoying part of this problem is that the failure state is random across entire training runs.

So the analysis we really need would include, say, 10 trials per optimizer/learning rate approach. It's definitely going to be compute-intensive to dig into this!

More general musings on this: the instability of fine-tuning is a somewhat well-known problem for all of these models. When most papers report a GLUE score, they are really taking the top score out of, say, 5 trials per individual GLUE task. Here's a whole paper on the problem -> https://arxiv.org/pdf/2006.04884.pdf; there are some proposed solutions in there we can look at, but I haven't dug too deeply.

Note that our goal should not be to remove all instability in fine-tuning (that's probably not feasible), but to provide a better default starting place than most users would find on their own. At the end of the day, if you really care about fine-tuning on a specific dataset, nothing will beat a hyper-parameter search for that specific dataset.

@mattdangerw
Member Author

Looks like lowering the default learning rate a bit can fix most of this. Doing that now.
