support gradient checkpointing for training #2531
Hello! I'm afraid that gradient checkpointing is not supported by the
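For context, gradient checkpointing trades compute for memory: intermediate activations are discarded during the forward pass and recomputed during backward, lowering peak memory at the cost of extra compute. A minimal PyTorch sketch of the idea (the toy block and tensor sizes here are made up for illustration):

```python
import torch
from torch.utils.checkpoint import checkpoint

# A toy block; in a real transformer this would be an encoder layer.
block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())

x = torch.randn(4, 16, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed
# during the backward pass, reducing peak memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```

Gradients still flow to `x` as usual; only the memory/compute trade-off of the forward pass changes.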
Thanks for the reply! Good news first: after some initial fiddling where I almost gave up, I did get your PR to train a non-stock model and dataset with gradient checkpointing. You can see my runs on wandb here; each run also saves its .py scripts for details. If you find it helpful, I can provide "qualitative feedback" in that PR thread on some features that do and don't work, but I'm not sure I'll have time to hunt down exactly where in the source code they originate, as they aren't "errors" per se but things like running OOM, the loss always being 0, etc.
Alright, I didn't mention it explicitly, so it's a fair suggestion, but I've been on that. BTW, since the implementation would be tied to the PR, feel free to close this issue or link it to the PR to be closed in the future; either works for me.
I'll link them :) Well done on getting that PR working! It doesn't have much documentation to guide you right now 😄 Please feel free to share any feedback in that PR. The refactor will be so big that I won't be able to test all combinations and settings; e.g., gradient checkpointing was likely not going to be thoroughly tested. Also, that Spearman cosine looks great (0.93 Spearman correlation based on cosine similarity). Is that from a held-out test set or from the training set?
Great! I'll share some points in the PR later when I have a chance to write things up a bit more. A tl;dr (I've tried to validate all of these with multiple base models):
That was longer than I meant, but you get the idea.
Yeah, I was surprised by this, but it could make sense. It is using held-out validation samples split from the source data (n=500), but there are a couple of spicy things I'm testing at the same time, so I'm going to wait until I can validate this a bit better.

Bonus (mostly unsure where to put this): I have a question related to controlling tokenizer padding during both training and inference. I'm not sure if it's something you are specifically adjusting in the PR, but it is a problem in "old" SBERT and still doesn't work in the v3 PR code. Should I make a new issue, ask in your PR conversation, or something else?
Those all sound good to know! I'll be glad to learn more about them in the PR comment. By "base loss", do you mean with Matryoshka? I admittedly have only tested that with CoSENTLoss, AnglELoss, CosineSimilarityLoss, and MultipleNegativesRankingLoss, but it should work with all.
I think a new issue would be best suited. I'm aware that the tokenizer settings are harshly overridden by some old code. The big problem is giving people the freedom to change them without also breaking a lot of old models.
Great, will add details to the PR sometime this weekend.
I added the "base" bit because, for me, using Matryoshka or not doesn't seem to 'break' training. So basically, all of those losses you listed (except Cosine) cause the train loss to be 0. wandb examples (the relevant code is there for each run): MultipleNegativesRankingLoss, CoSENTLoss
Ty, will do!
Will experiment with these on Monday. Those are some very clean scripts you've created, I'm impressed 😄
Hi, I am trying to use the `model.fit` method with `gradient_checkpointing=True` due to memory usage/total memory constraints. After having tried several variations on the idea, neither myself nor Claude 3/GPT-4 can figure out where to put it such that it doesn't cause errors. Most errors are typically due to the model becoming `None`.

An example attempt at adding it in the convoluted `models.Transformer` model class prior to creating a SentenceTransformer results in:

The same code works without trying to enable gradient checkpointing.

I also tried passing it to `model_args`, but it doesn't accept it. Any input on whether I am doing something wrong, or whether this missing feature can be added to this package, would be appreciated; I cannot use the library to train the intended models without it.
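One hedged guess at the "model becoming `None`" errors: Hugging Face's `gradient_checkpointing_enable()` mutates the model in place and returns `None`, so assigning its return value back to a variable silently replaces the model with `None`. A sketch of the difference against a config-only BERT (built locally, no weights downloaded; the tiny config values are arbitrary):

```python
from transformers import BertConfig, BertModel

# Tiny config so the model is constructed locally without downloading weights.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertModel(config)

# Correct: the call mutates the model in place...
model.gradient_checkpointing_enable()
assert model.is_gradient_checkpointing

# ...and returns None, so this tempting pattern loses the model:
broken = model.gradient_checkpointing_enable()
assert broken is None
```

Whether this is what the failing attempts hit is only a guess, but it is a common way to end up with `None` where a model was expected.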