-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Performance 2/6] Replace einops.rearrange with torch native ops #15804
Conversation
""" | ||
b, n, _ = t.shape # Get the batch size (b) and sequence length (n) | ||
d = t.shape[2] // h # Determine the depth per head | ||
return t.reshape(b, n, h, d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
t.reshape(b,n,h,-1)
should achieve similar result without having to explicitly calculate d
|
||
q = _reshape(q_in) | ||
k = _reshape(k_in) | ||
v = _reshape(v_in) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be done in 1 line with q, k, v = (_reshape(t) for t in (q_in, k_in, v_in))
* replace rearrange to view AUTOMATIC1111#15804 * see also lllyasviel/stable-diffusion-webui-forge@79adfa8 * conditional use torch.rms_norm for torch 2.4 * fix RMSNorm() for clear: use torch.ones()
* replace rearrange to view AUTOMATIC1111#15804 * see also lllyasviel/stable-diffusion-webui-forge@79adfa8 * conditional use torch.rms_norm for torch 2.4 * fix RMSNorm() for clear: use torch.ones()
Description
According to lllyasviel/stable-diffusion-webui-forge#716 (comment),
einops.rearrange
calls in crossattn is causing extra overhead. Replacing it with torch native ops can save ~55ms/it.Screenshots/videos:
TODO
There are other places where
einops.rearrange
can be replaced by torch native ops, but this one in CrossAttn is the most critical one. Instrument the usage ofeinops.rearrange
elsewhere might also yield some improvements.Checklist: