
Please don't kill BetterTransformer — 1.88x faster inference than SDPA #2083

Open
umarbutler opened this issue Oct 28, 2024 · 1 comment

@umarbutler

Feature request

I would like to request that BetterTransformer not be deprecated.

Motivation

I have come to rely on BetterTransformer to accelerate RoBERTa and BERT models. I have found BetterTransformer-transformed RoBERTa and BERT models to be much faster than the same models using SDPA (previously I used my own implementation, and although SDPA has since been added to transformers's RoBERTa implementation, its speed still does not compare to BetterTransformer). I believe BetterTransformer's fusing of layers into a single BetterTransformer encoder block is probably responsible for the added performance gains.
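
To illustrate what I mean by "BetterTransformer-transformed", here is a minimal sketch of applying the transform via optimum ('roberta-base' is purely a placeholder checkpoint; I use my own RoBERTa-based models in practice):

from optimum.bettertransformer import BetterTransformer

from transformers import AutoModel

# Placeholder checkpoint for illustration only.
model = AutoModel.from_pretrained('roberta-base')

# Swap the stock encoder layers for fused BetterTransformer encoder blocks.
bt_model = BetterTransformer.transform(model)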

Removing BetterTransformer would be a net negative if those performance gains are not replicated in the native transformers implementation of RoBERTa and BERT.

To emphasise how important BetterTransformer is to me, if it were deprecated, I would create my own private library to preserve its features.

My preference, however, is for it to remain.

Your contribution

This request.

@umarbutler
Author

Here is an example I whipped up to illustrate just how valuable BetterTransformer is:

import torch

from transformers import RobertaModel, RobertaTokenizerFast

# BEGIN CONFIG #
MODEL_NAME = 'umarbutler/emubert'
EXAMPLE_INPUT = "\
The Parliament shall, subject to this Constitution,\
have power to make laws for the peace, order, and good\
government of the Commonwealth with respect to:\
    (i) trade and commerce with other countries, and among\
        the States;\
    (ii) taxation; but so as not to discriminate between"""
# END CONFIG #

sdpa_model = RobertaModel.from_pretrained(MODEL_NAME, attn_implementation='sdpa').to(torch.bfloat16).to('cuda').eval()
bettertransformer_model = RobertaModel.from_pretrained(MODEL_NAME).to(torch.bfloat16).to_bettertransformer().to('cuda').eval()
tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_NAME)
input_tokens = tokenizer(EXAMPLE_INPUT, return_tensors='pt').to('cuda')

with torch.inference_mode():
    # Do unbenched warm-up forward passes to control for potential caching effects.
    for _ in range(10):
        bettertransformer_model(**input_tokens)
        sdpa_model(**input_tokens)
    
    # Benchmark the models (note: %timeit requires running in IPython/Jupyter).
    %timeit bettertransformer_model(**input_tokens)
    %timeit sdpa_model(**input_tokens)

On my 4090, BetterTransformer achieves 1.93 ms ± 104 μs and SDPA achieves 3.64 ms ± 259 μs. BetterTransformer is almost 2x faster (1.88x)...
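
For anyone reproducing this outside IPython (where %timeit is unavailable), a plain-Python sketch of an equivalent timing loop, reusing the models and inputs defined above, could look like this:

import time

def benchmark(model, inputs, runs: int = 100) -> float:
    # Synchronise first so previously queued CUDA work does not leak into the measurement.
    torch.cuda.synchronize()
    start = time.perf_counter()

    with torch.inference_mode():
        for _ in range(runs):
            model(**inputs)

    # Synchronise again so all forward passes have finished before stopping the clock.
    torch.cuda.synchronize()

    return (time.perf_counter() - start) / runs * 1e3  # Mean milliseconds per forward pass.

print(f'BetterTransformer: {benchmark(bettertransformer_model, input_tokens):.2f} ms')
print(f'SDPA: {benchmark(sdpa_model, input_tokens):.2f} ms')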

@umarbutler umarbutler changed the title Please do not deprecate BetterTransformer Please don't kill BetterTransformer — 1.88x faster inference than SDPA Oct 29, 2024