[feat] Add option to use (Scheduled) Huber Loss in all diffusion training pipelines to improve resilience to data corruption and better image quality #7488
Comments
Thank you for your contributions. But I think just releasing your code independently would be more justified here. If you want to add something to
You are welcome to open a PR for this, however.
Ccing @kashif and @patil-suraj for awareness.
Agree! I also found the Huber loss to work well in the time-series setting back in the day: https://github.com/zalandoresearch/pytorch-ts/blob/master/pts/modules/gaussian_diffusion.py#L251-L252
@sayakpaul Very well, then I'll make a repo with the modified training scripts and instructions to install Diffusers as a dependency, and a PR for the
Of course, thank you! Happy to help promote your work too!
@kashif By the way, using the bare Smooth L1 loss in my proposed changes won't help, for the same reason as with OpenAI's formula for the Pseudo-Huber loss. If you read the torch docs for Smooth L1 loss, it is torch's Huber loss divided by the constant delta: https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html#torch.nn.SmoothL1Loss However, the two losses have vastly different asymptotics:

As beta -> 0, Smooth L1 loss converges to L1Loss, while HuberLoss converges to a constant 0 loss. When beta is 0, Smooth L1 loss is equivalent to L1 loss.
As beta -> +∞, Smooth L1 loss converges to a constant 0 loss, while HuberLoss converges to MSELoss.
For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1. For HuberLoss, the slope of the L1 segment is beta.

This may not be noticeable with a constant parameter, but it has a profound impact on our scheduled Huber loss, which needs to be close to L2 on the first forward-diffusion timesteps. The Pseudo-Huber loss, derived from a square root (see Wikipedia), does satisfy both desired asymptotics. Additionally, torch's native Huber loss and Smooth L1 loss are piecewise-defined and not twice differentiable: differentiating the parabola yields a linear function, differentiating the absolute value yields a constant, and there is a cusp at their intersection. Having a loss that is at least twice continuously differentiable was one of the key assumptions of our attempted theorem (and may help others later).
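To make the asymptotics concrete, here is a small numeric check, a minimal sketch using torch's built-in `F.huber_loss` / `F.smooth_l1_loss` alongside the Wikipedia form of the Pseudo-Huber loss (the `pseudo_huber` helper below is purely illustrative, not code from the training scripts):

```python
import torch
import torch.nn.functional as F

def pseudo_huber(pred, target, delta):
    # Wikipedia form: delta^2 * (sqrt(1 + (a / delta)^2) - 1)
    a = pred - target
    return (delta**2 * (torch.sqrt(1 + (a / delta) ** 2) - 1)).mean()

pred = torch.tensor([0.0, 0.5, 2.0])
target = torch.zeros(3)

for d in (1e-3, 1.0, 1e3):
    huber = F.huber_loss(pred, target, delta=d)
    smooth_l1 = F.smooth_l1_loss(pred, target, beta=d)
    ph = pseudo_huber(pred, target, d)
    print(f"delta={d:g}  huber={huber.item():.4f}  "
          f"smooth_l1={smooth_l1.item():.4f}  pseudo_huber={ph.item():.4f}")

# For large delta, Huber and Pseudo-Huber approach 0.5 * MSE while Smooth L1 tends to 0;
# for small delta, Smooth L1 approaches the L1 loss while Huber and Pseudo-Huber vanish.
```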
There may be an even better schedule for delta depending on the SNR; see the discussion in kohya_ss.
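Purely as a hypothetical illustration (the exact SNR-based schedule discussed in kohya_ss may well differ), such a schedule could tie `huber_c` to the noise level of each sampled timestep rather than to the raw timestep index:

```python
import torch

def snr_scheduled_huber_c(timesteps, alphas_cumprod, huber_c_min=0.001):
    # Hypothetical sketch: keep huber_c near 1 (L2-like) where the noise level is
    # low (high SNR) and shrink it towards huber_c_min where the noise level is
    # high (low SNR), so robustness only kicks in on the noisy timesteps.
    ac = alphas_cumprod[timesteps]
    sigmas = ((1.0 - ac) / ac) ** 0.5  # per-sample noise level
    return (1.0 - huber_c_min) / (1.0 + sigmas) ** 2 + huber_c_min
```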
…ts (#7527)

* add scheduled pseudo-huber loss training scripts. See #7488
* add reduction modes to huber loss
* [DB Lora] *2 multiplier to huber loss cause of 1/2 a^2 conv. Pairing of kohya-ss/sd-scripts@c6495de
* [DB Lora] add option for smooth l1 (huber / delta). Pairing of kohya-ss/sd-scripts@dd22958
* [DB Lora] unify huber scheduling. Pairing of kohya-ss/sd-scripts@19a834c
* [DB Lora] add snr huber scheduler. Pairing of kohya-ss/sd-scripts@47fb1a6
* fixup examples link
* use snr schedule by default in DB
* update all huber scripts with snr
* code quality
* huber: make style && make quality

Co-authored-by: Sayak Paul <[email protected]>
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Is your feature request related to a problem? Please describe.
Diffusion models are known to be vulnerable to outliers in their training data. A relatively small number of corrupting samples can therefore "poison" the model and make it unable to produce the desired output, which has been exploited by programs such as Nightshade.
One of the reasons for this vulnerability may lie in the commonly used L2 (Mean Squared Error) loss, a fundamental part of diffusion/flow models, which is itself highly sensitive to outliers; see Anscombe's quartet for some examples.
Describe the solution you'd like.
In our new paper (also my first paper 🥳), "Improving Diffusion Models's Data-Corruption Resistance using Scheduled Pseudo-Huber Loss" https://arxiv.org/abs/2403.16728, we present a novel scheme to improve the resilience of score-matching models to corruption of parts of their datasets, introducing the Huber loss (long used in robust regression, for example when restoring a contour in highly noised computer vision tasks) and a Scheduled Huber loss. The Huber loss behaves exactly like L2 around zero and like L1 (Mean Absolute Error) as it tends towards infinity, so it punishes outliers less harshly than the quadratic MSE. A common concern, however, is that it may hinder the model's capacity to learn diverse concepts and small details. That is why we introduce the Scheduled Pseudo-Huber loss with a decreasing parameter, so that the loss behaves like the Huber loss on early reverse-diffusion timesteps, when the image is only beginning to form and is most vulnerable to being led astray, and like L2 on the final timesteps, to learn the fine details of the images.
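As a rough illustration of the idea (a minimal sketch, not the exact code from the paper or the merged training scripts; the exponential schedule and the `huber_c_min` parameter are assumptions), the loss could be computed like this in a Diffusers-style training loop:

```python
import math
import torch

def scheduled_pseudo_huber_loss(model_pred, target, timesteps,
                                num_train_timesteps, huber_c_min=0.001):
    # Decay the Pseudo-Huber parameter exponentially from 1 at t = 0 down to
    # huber_c_min at t = num_train_timesteps, so the loss stays close to L2 on
    # low-noise timesteps and becomes Huber-like (outlier-robust) on noisy ones.
    alpha = -math.log(huber_c_min) / num_train_timesteps
    huber_c = torch.exp(-alpha * timesteps.float()).view(-1, 1, 1, 1)  # broadcast over (B, C, H, W)

    # Pseudo-Huber loss, scaled by 2 * huber_c so that it approaches the squared
    # error as huber_c grows large (compensating for the 1/2 * a^2 limit).
    diff_sq = (model_pred.float() - target.float()) ** 2
    return (2 * huber_c * (torch.sqrt(diff_sq + huber_c**2) - huber_c)).mean()
```

In a training script this would stand in for the usual `F.mse_loss(model_pred, target)` call, with `timesteps` being the per-sample timesteps drawn for the batch.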
Describe alternatives you've considered.
We ran tests with the Pseudo-Huber loss, the Scheduled Pseudo-Huber loss and L2, and SPHL beats the rest in nearly all cases. (On the plots, Resilience is the similarity to clean pictures on partially corrupted runs minus the similarity to clean pictures on clean runs; see the paper for more details.)
Other alternatives are data filtration, image recaptioning (which may also be vulnerable to adversarial noise) and/or "diffusion purification". These require additional resources, may be impractical when training large models, and can produce false negatives, which may be drastic outliers with high corrupting potential.
👀 We also found that the Diffusers LCM training script uses the wrong coefficient proportionality for the Pseudo-Huber loss (a mistake that is also present in the original OpenAI article), which gives it the wrong asymptotics as its parameter tends to 0 or to infinity and has the most negative impact when the parameter is timestep-scheduled. This would be nice to fix as well (perhaps with a compatibility option for previously trained LCMs).
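For illustration only, and assuming the script computes the unscaled `sqrt(a^2 + c^2) - c` form, the two variants below show the difference in asymptotics (the exact shape of any fix is up to the maintainers):

```python
import torch

def pseudo_huber_unscaled(pred, target, c):
    # sqrt(a^2 + c^2) - c: tends to |a| as c -> 0 and to 0 as c -> inf
    # (Smooth-L1-like asymptotics), which is problematic once c is scheduled.
    return (torch.sqrt((pred - target) ** 2 + c**2) - c).mean()

def pseudo_huber_scaled(pred, target, c):
    # 2c * (sqrt(a^2 + c^2) - c): tends to 0 as c -> 0 and to a^2 (MSE-like)
    # as c -> inf, matching the Huber-style behaviour the schedule relies on.
    return (2 * c * (torch.sqrt((pred - target) ** 2 + c**2) - c)).mean()
```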
We show that our scheme works in the text-to-speech diffusion domain as well, further supporting the claims.
Additional context.
As a side effect (which I only remembered after publishing, when I was looking through the sampled pictures), the Huber loss also seems to improve the "vibrancy" of pictures on clean runs, though the mechanism behind it is unknown (maybe better concept disentanglement?). I think it would be nice to include it, if only for this effect 🙃
Since I was the one behind this idea and ran the experiments with a modified Diffusers library, I have all the code at hand and will open a PR soon.
We also tried extensively to prove a theorem claiming that, when corrupting samples are present in the dataset (i.e. the third moment, the skewness, of the distribution is greater than zero), using the Scheduled Pseudo-Huber loss with a timestep-decreasing parameter results in a smaller KL divergence between the clean data distribution and the distribution generated by an ideal score-matching (e.g. diffusion) model than using L2. However, there was a mistake in the proof and we got stuck. If you'd like to take a look at our proof attempt, PM me.