Fuse MLP in attention mechanism #14

Open
SeanNaren opened this issue Jul 11, 2022 · 1 comment
Labels
enhancement New feature or request

Comments

@SeanNaren
Owner

Due to facebookresearch/xformers#286, we cannot currently fuse the bias and GeLU activation into a single kernel using Triton. This means we're just using a standard MLP and are probably taking a perf hit.

Megatron-DeepSpeed uses a TorchScript-compiled bias + GeLU function, with additional JIT flags set to fuse the operation: https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/master/megatron/model/fused_bias_gelu.py

I haven't managed to get this to work, as the global settings in this file cause the xFormers rotary embeddings to fail: https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py#L21

Combining this scripted function with standard Linear operations + Dropout may give us a slight performance boost. Waiting for Triton dropout support in BF16 seems like it might take a while (I think this is related to triton-lang/triton#431).
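
For reference, a minimal sketch of what that combination could look like. It assumes the tanh approximation of GeLU. The class name `ScriptedBiasGeluMLP` and its parameters are hypothetical and not taken from this repo. The global JIT flags mentioned above are left out on purpose, since they are what breaks the rotary embeddings, so how much actually gets fused here is untested.

```python
import torch
import torch.nn as nn


@torch.jit.script
def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Add the bias and apply the tanh approximation of GeLU in one scripted
    # function, so the JIT has a chance to fuse the element-wise ops.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))


class ScriptedBiasGeluMLP(nn.Module):
    """Hypothetical MLP block: Linear -> scripted bias+GeLU -> Dropout -> Linear."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        # The first projection has no bias of its own; the bias is kept as a
        # separate parameter and added inside the scripted function instead.
        self.fc1 = nn.Linear(dim, hidden_dim, bias=False)
        self.fc1_bias = nn.Parameter(torch.zeros(hidden_dim))
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = bias_gelu(self.fc1_bias, self.fc1(x))
        return self.fc2(self.dropout(h))
```

Dropout stays a regular `nn.Dropout` here, so this only targets the bias + activation part and does nothing for the BF16 dropout problem.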

cc @blefaudeux

@SeanNaren SeanNaren added the enhancement New feature or request label Jul 11, 2022
@SeanNaren
Owner Author

This may also be a viable solution: facebookresearch/xformers#352

I tried this briefly but ran into errors caused by the ZeRO Stage 3 hooks. I'm going to retry to see if I can get this to work!
