grad_norm is NaN when fine-tuning with S^2 Attention in half precision #4160

Closed
1 task done
fffffq99 opened this issue Jun 8, 2024 · 1 comment
Labels
solved This problem has been already solved

Comments

fffffq99 commented Jun 8, 2024

Reminder

  • I have read the README and searched the existing issues.

System Info

I tried fine-tuning an LLM with the LongLoRA method. When I enable S^2 Attention and set the precision to fp16, the terminal always reports grad_norm as nan; even when fine-tuning on the example dataset "identity.json", the loss is shown as 0 but grad_norm is still nan.
However, without changing anything else, switching the precision to fp32 makes grad_norm normal again.
My PyTorch version is 2.2.2.

Reproduction

Train

Expected behavior

No response

Others

This issue has been resolved; the problem is in longlora.py#cat:

attn_output = torch.cat(
    (
        attn_output[:, :, : self.num_heads // 2],
        attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
    )
)

This torch.cat call does not specify dim, so it defaults to dim=0, whereas logically it should be dim=2. Although the subsequent code performs a reshape, under 16-bit precision the backward pass still does not seem to map gradients back to the correct positions.
After testing, specifying dim=2 resolves the problem:

attn_output = torch.cat(
    (
        attn_output[:, :, : self.num_heads // 2],
        attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
    ),
    dim=2,
)

Alternatively, use the in-place assignment from the original LongLoRA code:

attn_output[:, :, self.num_heads // 2 :] = attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1)
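
For illustration, here is a minimal, standalone sketch of the two cat calls above (not the actual longlora.py code): it assumes attn_output is laid out as (bsz, q_len, num_heads, head_dim) and uses made-up sizes, and shows that the default dim=0 stacks the two head halves along the batch axis instead of re-joining them along the head axis:

import torch

# Hypothetical sizes for illustration only; assumes attn_output has shape
# (bsz, q_len, num_heads, head_dim) as in the snippets above.
bsz, q_len, num_heads, head_dim, groupsz = 2, 8, 4, 16, 4
attn_output = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16)

# Default torch.cat (dim=0): the two head halves are stacked along the batch axis.
wrong = torch.cat(
    (
        attn_output[:, :, : num_heads // 2],
        attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
    )
)
print(wrong.shape)   # torch.Size([4, 8, 2, 16]) -- batch doubled, heads halved

# dim=2: the two head halves are re-joined along the head axis.
right = torch.cat(
    (
        attn_output[:, :, : num_heads // 2],
        attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
    ),
    dim=2,
)
print(right.shape)   # torch.Size([2, 8, 4, 16]) -- same layout as attn_output

Both results contain the same number of elements, so the subsequent reshape succeeds in either case, which may be why the wrong dim went unnoticed until the gradients misbehaved.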

@hiyouga hiyouga added the pending This problem is yet to be addressed label Jun 8, 2024
@hiyouga hiyouga added solved This problem has been already solved and removed pending This problem is yet to be addressed labels Jun 10, 2024
hiyouga (Owner) commented Jun 10, 2024

Fixed, thank you for helping us to identify this critical issue.

Before a793e84: [image]

After a793e84: [image]
