support gradient checkpointing for training #2531

Open
pszemraj opened this issue Mar 7, 2024 · 7 comments · Fixed by #2449
Comments


pszemraj commented Mar 7, 2024

Hi, I am trying to use the model.fit method with gradient_checkpointing=True due to memory constraints. After trying several variations on the idea, neither I nor claude3/gpt4 can figure out where to put it so that it doesn't cause errors; most errors are due to the model becoming None.

An example attempt at adding it in the models.Transformer module prior to creating the SentenceTransformer:

model_name = "distilroberta-base" # supports gradient_checkpointing in transformers 
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)

# Enable gradient checkpointing on the underlying transformer model
if gradient_checkpointing:
    word_embedding_model.auto_model.gradient_checkpointing_enable()

pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# ... (rest of the code for dataloading etc)
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    evaluation_steps=int(len(train_dataloader) * 0.05),
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    checkpoint_path=model_save_path,
    use_amp=torch.cuda.is_available(),
    checkpoint_save_total_limit=1,
)

results in:

# ...
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/content/train_sbert_from_encoder.py", line 350, in main
    model.fit(
AttributeError: 'NoneType' object has no attribute 'fit'

The same code works if I don't try to enable gradient checkpointing.


I also tried passing it to model_args, but that isn't accepted. Any input on whether I am doing something wrong, or whether this missing feature can be added to the package, would be appreciated; I cannot use the library to train the intended models without it.

@tomaarsen (Collaborator)

Hello!

I'm afraid that gradient checkpointing is not supported by the fit method right now. #2449 should introduce it, but that PR is still undergoing testing.
If you have memory issues, then the best solution would be to reduce the batch_size in the DataLoader. This should result in a linear reduction in memory as well.
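
For reference, with the current fit API that just means shrinking the batch_size passed to the DataLoader, roughly like the sketch below (a minimal sketch; train_examples is assumed to be the usual list of InputExample objects):

from torch.utils.data import DataLoader

# A smaller batch_size trades throughput for a roughly linear drop in activation memory.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)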

  • Tom Aarsen


pszemraj commented Mar 7, 2024

Thanks for the reply!

So, good news first: after some initial fiddling where I almost gave up, I did get your PR to train a non-stock model and dataset with gradient checkpointing. You can see my runs on wandb here; the .py scripts are also saved to each run for details.
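
For reference, enabling it through the trainer-style API in that PR looks roughly like the sketch below; the SentenceTransformerTrainer / SentenceTransformerTrainingArguments names and the gradient_checkpointing flag are assumptions about the PR's API, based on the transformers TrainingArguments it builds on, not something confirmed in this thread:

from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)

model = SentenceTransformer("distilroberta-base")

# Tiny illustrative dataset with (sentence1, sentence2, score) columns.
train_dataset = Dataset.from_dict({
    "sentence1": ["A plane is taking off."],
    "sentence2": ["An air plane is taking off."],
    "score": [1.0],
})

args = SentenceTransformerTrainingArguments(
    output_dir="output/grad-ckpt-test",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,  # assumed to be forwarded to the underlying transformer
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=losses.CosineSimilarityLoss(model),
)
trainer.train()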

If you find it helpful, I can provide "qualitative feedback" in that PR thread on some features that do and don't work, but I'm not sure I'll have time to hunt down exactly where in the source code the problems originate, since they aren't "errors" per se but rather things like running OOM or the loss always being 0.

  • I can point to the wandb runs they relate to, though.
  • I wanted to check whether this would be useful first before cluttering your PR further :)

re: batch size

Alright, I didn't mention it explicitly, so it's a fair suggestion, but I've been on that batch_size=1 game for a while now.



BTW, since the implementation would be tied to the PR, feel free to close this issue or link it to the PR to be closed in the future - either works for me.

tomaarsen linked a pull request Mar 8, 2024 that will close this issue
@tomaarsen (Collaborator)

I'll link them :)

Well done on getting that PR working! It doesn't have much documentation to guide you right now 😄 Please feel free to share any feedback in that PR. The refactor will be so big that I won't be able to test all combinations and settings; gradient checkpointing, for example, was not likely to be thoroughly tested.

Also, that Spearman cosine looks great (0.93 Spearman correlation based on cosine similarity). Is that from a held-out test set or from the training set?

  • Tom Aarsen


pszemraj commented Mar 8, 2024

Great! I'll share some points in the PR later, when I have a chance to write things up a bit more. A tl;dr (I've tried to validate all of these with multiple base models):

  • If I use any base loss besides CosineSimilarityLoss, the training loss is constantly 0.
  • If I set preprocessing_num_workers to anything, the process hangs when the second evaluation starts.
  • There is a weird increase in GPU VRAM usage at the end of the first eval (i.e. after all eval steps are completed).

That was longer than I meant but you get the idea.

re: eval scores

Yeah, I was surprised by this, but it could make sense. It is using held-out validation samples split from the source data (n=500), but there are a couple of spicy things I'm testing at the same time, so I'm going to wait until I can validate this a bit better.

Bonus, mostly unsure where to put this: I have a question about controlling tokenizer padding during both training and inference. I'm not sure if it's something you are specifically adjusting in the PR, but it is a problem in the "old" SBERT and still doesn't work in the v3 PR code. Should I make a new issue, ask in your PR conversation, or something else?

@tomaarsen (Collaborator)

Those all sound good to know! I'll be glad to learn more about them in the PR comments. By "base loss", do you mean with Matryoshka? I admittedly have only tested that with CoSENTLoss, AnglELoss, CosineSimilarityLoss, and MultipleNegativesRankingLoss, but it should work with all of them.
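
For context, wrapping a base loss with Matryoshka looks roughly like the sketch below (a minimal sketch; the model is assumed to already be assembled as in the snippet above, and the dims are illustrative):

from sentence_transformers import losses

# MatryoshkaLoss applies the wrapped base loss at several truncated embedding dimensions.
base_loss = losses.CoSENTLoss(model)
train_loss = losses.MatryoshkaLoss(model, base_loss, matryoshka_dims=[768, 512, 256, 128, 64])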

> Bonus, mostly unsure where to put this: I have a question about controlling tokenizer padding during both training and inference. I'm not sure if it's something you are specifically adjusting in the PR, but it is a problem in the "old" SBERT and still doesn't work in the v3 PR code. Should I make a new issue, ask in your PR conversation, or something else?

I think a new issue would be best suited. I'm aware that the tokenizer settings are just harshly overridden by some old code. The big problem is trying to give people freedom to change it without also breaking a lot of old models.


pszemraj commented Mar 8, 2024

Great, will add details to the PR sometime this weekend.

By "base loss", do you mean with Matryoshka? I admittedly have only tested that with CoSENTLoss, AnglELoss, CosineSimilarityLoss and MultipleNegativeRankingLoss, but it should work with all.

I added the "base" bit because for me, using Matryoshka or not doesn't seem to 'break' training. So basically, all of those losses you listed (except Cosine) cause the train loss to be 0. wandb examples, relevant code is there for each run: MultipleNegativeRankingLoss, CoSENTLoss

> I think a new issue would be best suited. I'm aware that the tokenizer settings are just harshly overridden by some old code. The big problem is trying to give people freedom to change it without also breaking a lot of old models.

ty, will do!

@tomaarsen (Collaborator)

Will experiment with these on Monday. Those are some very clean scripts you've created; I'm impressed 😄
Obviously, the losses shouldn't all just be 0, haha.
