
Corby's numerically more stable self attn version #118

Merged: 2 commits into tr8-104B from stas00-patch-1, Sep 29, 2021
Conversation

stas00 (Contributor) commented Sep 29, 2021

This PR implements @corbyrosset's suggestion for how to overcome the 104B numerical instability we have been experiencing. Quoting:

Re: 104B instability (https://huggingface.slack.com/archives/C01NHER1JLS/p1632801340055000). One thing I've encountered before is how the self-attention is computed. E.g. this line shows that the norm_factor may be multiplied after the Query * Key matrix multiplication. If the dims of Q and K are very large, the output may blow up and the norm_factor won't be able to save it.

Proposal: move the norm_factor inward, so Q and K are scaled down before matrix multiply:

        matmul_result = torch.baddbmm(
            matmul_result,
            1.0/math.sqrt(self.norm_factor) * query_layer.transpose(0, 1),   # [b * np, sq, hn]
            1.0/math.sqrt(self.norm_factor) * key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0 if alibi is None else 1.0, alpha=1.0)

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)
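
For contrast, here is a sketch of the formulation this replaces, reconstructed from the quoted description rather than copied verbatim, in which the norm factor is only applied via alpha after the Q * K matmul:

        # reconstruction of the pre-PR call (not the exact line): the scaling by
        # 1/norm_factor happens via alpha, i.e. only after Q * K^T is computed
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),                   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),     # [b * np, hn, sk]
            beta=0.0 if alibi is None else 1.0,
            alpha=(1.0 / self.norm_factor))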

To make the operation mathematically equivalent, moving the norm factor inward requires taking its square root, since for a scalar n and matrices A and B:

n * (A dot B) === (sqrt(n) * A) dot (sqrt(n) * B)
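
As a quick sanity check (not part of the PR), the identity can be verified numerically with plain torch matrices:

    import math
    import torch

    torch.manual_seed(0)
    n = 0.125                                     # arbitrary scalar norm factor
    A = torch.randn(64, 128, dtype=torch.float64)
    B = torch.randn(128, 64, dtype=torch.float64)

    scaled_after  = n * (A @ B)                              # n * (A dot B)
    scaled_before = (math.sqrt(n) * A) @ (math.sqrt(n) * B)  # (sqrt(n) * A) dot (sqrt(n) * B)

    assert torch.allclose(scaled_after, scaled_before)       # equal up to rounding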

Also, thanks to @RezaYazdaniAminabadi, who helped find where this function is defined in cuBLAS (https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmStridedBatchedEx), which includes the definition:

C + i*strideC = α·op(A + i*strideA)·op(B + i*strideB) + β·(C + i*strideC), for i ∈ [0, batchCount − 1]

The issue is that alpha is multiplied in only after the matrix-matrix multiplication is done, so the unscaled product can already blow up and cause instability.
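
torch.baddbmm(input, batch1, batch2, beta=b, alpha=a) computes b * input + a * (batch1 @ batch2), mirroring the cuBLAS definition above, so the alpha scaling indeed only touches the already-computed product. Below is a toy fp16 illustration of why that ordering matters (the values and shapes are made up, and norm_factor is assumed to be the usual sqrt of the head dim); it computes a single entry of Q * K^T both ways:

    import math
    import torch

    head_dim = 128
    norm_factor = math.sqrt(head_dim)        # assumed: the usual sqrt(head_dim) factor

    # exaggerated fp16 activations so one raw Q.K^T entry exceeds the fp16 max (~65504)
    q = torch.full((head_dim,), 30.0, dtype=torch.float16)
    k = torch.full((head_dim,), 30.0, dtype=torch.float16)

    # norm factor applied after the dot product: the raw score is ~115200 -> inf in fp16
    scaled_after = (1.0 / norm_factor) * (q * k).sum()

    # operands pre-scaled by 1/sqrt(norm_factor): same score, never leaves the fp16 range
    s = 1.0 / math.sqrt(norm_factor)
    scaled_before = ((s * q) * (s * k)).sum()

    print(scaled_after)    # inf: the product blew up before the norm factor could help
    print(scaled_before)   # roughly 1.0e4 and finite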

stas00 merged commit 48902f1 into tr8-104B on Sep 29, 2021
stas00 deleted the stas00-patch-1 branch on September 29, 2021 02:05