-
Notifications
You must be signed in to change notification settings - Fork 636
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[bug] Pre-norm wrapper only normalizing the first input #233
Conversation
As usual, reversible is the pain point.. fixing that |
4c84b92
to
fb0b479
Compare
@@ -34,21 +26,37 @@ class LayerNormStyle(str, Enum): | |||
Post = "post" | |||
|
|||
|
|||
class RequiresWrappedInputs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
classes which derive from this class only accept a single input list (makes it impossible to subtly footgun)
@@ -16,14 +16,6 @@ | |||
from xformers.triton.layer_norm import FusedLayerNorm | |||
|
|||
|
|||
def _to_tensor_list( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was supposed to be a helper, but in the end it was masking bugs (in that layer(x, y, z) could have a different behaviour depending on the residual wraps). I think that it's better to force inputs in a single fashion
|
||
return inputs[0] + self.layer(*inputs, *args, **kwargs) | ||
def forward(self, inputs: List[torch.Tensor], **kwargs): | ||
if self.wrap_inputs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the trick here is that these residual/norm wrapper can wrap themselves at times. When they wrap an external layer, then the inputs are unrolled, when the sublayer is another wrap then we maintain inputs=List[Tensor] to prevent bugs like this one
@@ -335,8 +335,8 @@ def forward( | |||
q, k, v = x, x, x | |||
|
|||
# Pre/Post norms and residual paths are already handled | |||
x = self.wrap_att(q, k, v, att_mask=att_mask) | |||
x = self.wrap_ff(x) | |||
x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all the wraps require a single input list + kwargs, which I believe is more future proof (this normalizing bug cannot happen, or at least not as easily)
Codecov Report
@@ Coverage Diff @@
## main #233 +/- ##
==========================================
- Coverage 92.22% 92.17% -0.05%
==========================================
Files 60 60
Lines 3228 3247 +19
==========================================
+ Hits 2977 2993 +16
- Misses 251 254 +3
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch and nice fix!
@@ -15,7 +15,11 @@ | |||
from .in_proj_container import InProjContainer, InProjParams # noqa | |||
from .multi_head_dispatch import MultiHeadDispatch # noqa | |||
from .multi_head_dispatch import MultiHeadDispatchConfig | |||
from .residual import LayerNormStyle, PostNorm, PreNorm, Residual # noqa | |||
from .residual import LayerNormStyle # noqa; noqa |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: extra noqa typo?
What does this PR do?
Could well explain #219 (cc @jramapuram), I stumbled upon this bug by luck while rewriting some of the input projection part and being more strict with self-attention (in that case the tensors need to be the exact same objects, with an incoming PR, makes it easier to catch a bug if the intent is self attention).
This showed that the pre-norm wrapper was (a) only normalizing the first input (b) creating new objects because it would normalize the tensors one by one, even if it was the same tensor to begin with. This is a bug which already existed in the past and was fixed, not sure how it came back (long lived branch + botched merge maybe), hence the unit test and the change of interface to make sure that this does not happen anymore. Consequence was both correctness + speed (after the pre-norm, the tensors were not the same anymore, so the self attention speed up was off)
(a) I changed the interface to make it compulsory to pass the inputs as a list for all these wrappers, so that there's no more confusion
(b) is unit tested in this PR + the incoming PR will add another test
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.