
Adding distil-whisper model support to TensorRT-LLM #1061

Closed

Conversation

Bhuvanesh09
Contributor

What this PR Does:

This PR adds support for converting huggingface's distil-whisper model weights to compatible pytorch .pt files which can be further used to build TensorRT-LLM engines.

How the PR Does it:

distil-whisper's architecture is quite similar to that of the openai-whisper models, except that it has fewer decoder layers. Despite this, it was previously not possible to use TensorRT-LLM with distil-whisper, because the two frameworks use different naming conventions for parameters and weights.

This PR addresses this issue by identifying the mapping between the two naming conventions. The script uses regex substitutions to convert model weights from the huggingface format to the openai-whisper format.

You can do so by running the script convert_from_distil_whisper.py as follows:

```shell
# install requirements first
pip install -r requirements.txt

# download the model weights from huggingface and convert them to
# openai-whisper's pytorch format; the model is saved to ./assets/ by default
python3 convert_from_distil_whisper.py --model_name distil-whisper/distil-large-v2

# now we can build the engine as before:
python3 build.py --model_name distil-large-v2 --output_dir compiled-distil-large-v2 --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin
```
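The regex-based renaming described above can be sketched as follows. This is a minimal illustration with a hypothetical subset of rules, not the PR's actual mapping table; the real script handles the full set of parameter names.

```python
import re

# Illustrative (pattern, replacement) rules rewriting huggingface parameter
# names into openai-whisper style. This subset is an assumption for the
# sketch, not the PR's exact table.
RULES = [
    (r"^model\.", ""),                       # drop the HF top-level prefix
    (r"layers\.(\d+)\.", r"blocks.\1."),     # HF "layers.N" -> whisper "blocks.N"
    (r"self_attn\.q_proj", "attn.query"),
    (r"self_attn\.k_proj", "attn.key"),
    (r"self_attn\.v_proj", "attn.value"),
    (r"self_attn\.out_proj", "attn.out"),
]

def rename(hf_key: str) -> str:
    """Apply every regex rule in order to one huggingface parameter name."""
    for pattern, repl in RULES:
        hf_key = re.sub(pattern, repl, hf_key)
    return hf_key

print(rename("model.encoder.layers.0.self_attn.q_proj.weight"))
# -> encoder.blocks.0.attn.query.weight
```

In practice the renamed keys are collected into a state dict and saved with `torch.save` as the `.pt` checkpoint that `build.py` consumes.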

Results and figures:

This is useful when low-latency models are needed for short audio clips. Using TensorRT-LLM with distil-whisper provides up to a 3x latency improvement over huggingface.

For instance, for a mean audio length of $3.5$ seconds, we measured a mean latency of $0.086$ seconds on an A10G.
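A mean-latency figure like the one above can be measured with a simple wall-clock harness. This is a generic sketch: `transcribe` stands in for whatever callable wraps the built engine, and the warm-up runs are an assumption (they exclude one-time initialization cost from the average), not part of the PR.

```python
import time

def mean_latency(transcribe, clips, warmup=2):
    """Average wall-clock latency of `transcribe` over a list of audio clips.

    `transcribe` is any callable that runs one clip end to end; the name is a
    placeholder, not an actual TensorRT-LLM API.
    """
    for clip in clips[:warmup]:          # warm-up runs, excluded from timing
        transcribe(clip)
    start = time.perf_counter()
    for clip in clips:
        transcribe(clip)
    return (time.perf_counter() - start) / len(clips)
```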

Distribution of wav sizes vs. latency: (scatter plot attached in the original PR)

@Bhuvanesh09
Contributor Author

@symphonylyh Kindly take a look when possible.

@yuekaizhang

Hi @Bhuvanesh09, great work! I will take this PR into our internal gitlab to run some CI tests. We will credit your work in the release notes for the distil-whisper model.

@kaiyux mentioned this pull request Feb 27, 2024
@kaiyux
Member

kaiyux commented Feb 27, 2024

Hi @Bhuvanesh09, thanks very much for your great work. The update including your changes has been merged into the main branch (see #1168), and we've credited you as a co-author.

We are going to close this PR. Thanks again for your support, and please let us know if you have any questions.
