Packing in SFT #805
Hi @wdykas! Packing is a common practice and a trick to enable pre-training / fine-tuning on more sequences. I think that adding the EOS token is enough of a signal for the model. The trick seems to have been introduced in the T5 paper: https://arxiv.org/pdf/1910.10683.pdf
Can you elaborate more on this question? For the loss calculation we do not perform generation; we only compute the next token once.
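(For readers skimming the thread, here is a minimal sketch of the packing idea described above: concatenate the tokenized examples with an EOS separator and slice the buffer into fixed-length chunks. This is illustrative only, not TRL's actual `ConstantLengthDataset` code; the `eos_token_id` value and toy token ids are made up.)

```python
def pack_examples(tokenized_examples, seq_len, eos_token_id):
    # Concatenate all examples into one long buffer, separated by EOS.
    buffer = []
    for ids in tokenized_examples:
        buffer.extend(ids + [eos_token_id])
    # Slice the buffer into fixed-length chunks, dropping the trailing remainder.
    n_chunks = len(buffer) // seq_len
    return [buffer[i * seq_len : (i + 1) * seq_len] for i in range(n_chunks)]

# Toy usage (hypothetical token ids, eos_token_id=2):
print(pack_examples([[5, 6, 7], [8, 9], [10, 11, 12, 13]], seq_len=4, eos_token_id=2))
# [[5, 6, 7, 2], [8, 9, 2, 10], [11, 12, 13, 2]]
```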
Sorry, I meant more along the lines of: how can we calculate loss on the tokens without cross-contamination in attention between sequences? Is there code where you create special attention masks? This seems to be an issue that was opened here: huggingface/transformers#25452
Since the sequences are generally not correlated, there is no issue with cross-contamination. Otherwise you would need to pass a dedicated attention mask, but this is not currently done.
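(As a sketch of what such a dedicated attention mask could look like, block-diagonal and causal so that tokens only attend within their own packed sequence; this assumes nothing about TRL internals.)

```python
import torch

def block_diagonal_causal_mask(seq_lengths):
    """Boolean mask for one packed row: True where attention is allowed.

    seq_lengths are the lengths of the individual sequences packed into
    the row, e.g. [3, 2, 3] for a packed length of 8.
    """
    total = sum(seq_lengths)
    allowed = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for length in seq_lengths:
        # Each sequence only attends to itself...
        allowed[start : start + length, start : start + length] = True
        start += length
    # ...and only causally (no peeking at future tokens).
    return allowed & torch.tril(torch.ones(total, total, dtype=torch.bool))

print(block_diagonal_causal_mask([3, 2, 3]).int())
```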
Thanks for the answer! One last QQ: a problem I see in this implementation is that we create a long continuous buffer of text separated by [EOS], and then it seems we create examples of seq_len by naively looping over https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L579C35-L579C48 without considering EOS.
Yes, that's the idea of packing. If you want to avoid that, you can disable packing; then each sequence will start at the beginning and be filled with padding tokens up to the max sequence length in the batch.
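(For comparison, the packing-off path is roughly the standard dynamic-padding collator behavior, something like the sketch below; this is not TRL's exact collator, and `pad_token_id` is whatever the tokenizer defines.)

```python
def pad_batch(tokenized_examples, pad_token_id):
    # Without packing: each example keeps its own row, right-padded to the
    # longest sequence in the batch, with the padding masked out.
    max_len = max(len(ids) for ids in tokenized_examples)
    input_ids = [ids + [pad_token_id] * (max_len - len(ids)) for ids in tokenized_examples]
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in tokenized_examples]
    return input_ids, attention_mask

ids, mask = pad_batch([[5, 6, 7], [8, 9]], pad_token_id=0)
# ids  -> [[5, 6, 7], [8, 9, 0]]
# mask -> [[1, 1, 1], [1, 1, 0]]
```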
I think it would be better to add a feature that does sequence packing without any cross-contamination, like the figure above, specifically in the supervised fine-tuning stage. In the SFT stage we usually have short conversations (e.g. << 4096 tokens), which may cause a "training-inference gap" if we don't add the attention mask. Notice that there is also an implicit effect of very long sequence attention (if we do not add the attention mask) -- the entropy [1].
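(Besides the block-diagonal mask sketched earlier, implementations that avoid cross-contamination typically also reset position ids at each packed-sequence boundary, so training positions match what each conversation would see at inference time. A minimal illustrative sketch, not tied to any particular library:)

```python
import torch

def reset_position_ids(seq_lengths):
    # Position ids restart at 0 for every packed sequence.
    return torch.cat([torch.arange(length) for length in seq_lengths])

print(reset_position_ids([3, 2, 3]))  # tensor([0, 1, 2, 0, 1, 0, 1, 2])
```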
Thank you for the answer!
I am very confused about the current SFT packing implementation. My dataset is small, less than 1k rows, and very correlated. Based on what I have gathered in this issue, the current packing code causes my full finetune training results to go off the rails in terms of quality versus packing=False. So my question is:
Thanks.
Packing is mainly a throughput/efficiency technique at scale. For 1000 samples, why don't you just turn packing off?
Because GPU waste is waste at any scale. Packing lets us complete small/micro training runs in half the time compared to packing off. Imagine a 1-minute training run (packing=on) vs 2 minutes (packing=off), but repeat that N times over the course of a day for a developer who is trying to maximize work/time efficiency. We have verified the training time difference with packing=on, but result quality is an issue for us. I don't think packing should be reserved for large data. In many cases, a 1000-row dataset for finetuning is actually overkill; quality of data trumps dataset row count. So 1k may be tiny, but for a special laser-focused training run it's actually quite large. It's just that the current implementation is not optimized for correlated/small datasets, and/or we are not using it properly, hence my questions on the two points that are causing confusion/problems for us.
Thanks for sharing this very important topic. Another likely very relevant consideration is that Hugging Face's BetterTransformer (the Fast Kernel in LLaMA-recipes) does not support attention_mask:
You can use the correct implementation of packing without cross-contamination attention from this repo: https://github.com/MeetKai/functionary/tree/main/functionary/train/packing
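(Without speaking for that repo's exact implementation, the usual pattern for contamination-free packing with FlashAttention-style variable-length kernels is to flatten the packed sequences and pass cumulative sequence lengths, often called `cu_seqlens`, that mark the boundaries, instead of materializing a dense block-diagonal mask. A sketch of that bookkeeping:)

```python
import torch

def cu_seqlens_from_lengths(seq_lengths):
    # e.g. [3, 2, 3] -> tensor([0, 3, 5, 8], dtype=torch.int32)
    lengths = torch.tensor(seq_lengths, dtype=torch.int64)
    return torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)]).to(torch.int32)

print(cu_seqlens_from_lengths([3, 2, 3]))
```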
@khai-meetkai will it work with Unsloth? https://github.com/unslothai/unsloth
Not correct. You are using padding...
It is padding packed entries that are shorter than the max. Or are you referring to something else? Packing does not mean you can't use padding.
I think this thread is about the ConstantLengthDatasetLoader, but correct me if I'm wrong. One of the advantages of using CLDL is that you don't use any padding, thus training could be more stable in some circumstances / downstream tasks. Yes, you can pack with padding too, but it wouldn't be the same thing.
Ok, I see your point. I would argue that constant-length packing with no padding, which naively splits just to fill the max length, is throwing some dataset rows out the window. If your train set is millions of rows this may not matter, but for smaller datasets you most definitely want to pack as much as possible into the max length with no naive splitting and pad the rest. Also, can you point to papers or articles showing that padding in general affects quality/stability of training? I have never trained without padding, so I have no experience in this regard.
Is this done now? I'd like to know if TRL's implementation attends to all samples at once. I know packages such as Axolotl support multipacking. Thanks!
Any updates? I am also interested in using packing for instruction finetuning on long-context examples. In that case, I would combine both short and long context examples. I think packing is useful to maximize GPU memory utilization for each batch. However, the current implementation of packing might not be optimal for SFT datasets:
Anything new on a solution? Thanks in advance for your help.
I understand how packing is allowed in pretraining, but I was looking for some clarification on how we are allowed to pack samples for SFT with ConstantLengthDataset. I see that an EOS token is put between samples (https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L567), but how are attention masks handled to make sure that the samples don't attend to each other's context, or is just putting EOS enough signal? Could you also point me to the code for how multiple answers are handled in generation for loss calculations?