Adding distil-whisper model support to TensorRT-LLM #1061
Closed
What this PR Does:
This PR adds support for converting Hugging Face distil-whisper model weights into PyTorch `.pt` files, which can then be used to build TensorRT-LLM engines.
How the PR Does it:
distil-whisper's architecture is quite similar to that of the openai-whisper models, except that it has fewer decoder layers. Despite this, it was previously not possible to use TensorRT-LLM with distil-whisper, because the two frameworks use different naming conventions for parameters and weights.
This PR addresses that issue by identifying the pattern differences between the two naming conventions. The script uses regex substitutions to convert model weights from the Hugging Face format to the openai-whisper format.
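To illustrate the idea, here is a minimal sketch of such regex-based renaming. The rules below are an assumed, illustrative subset; the actual script in this PR may use different patterns and cover more cases.

```python
import re

# Illustrative subset of the Hugging Face -> openai-whisper key renaming.
# Rules are applied in order, so the more specific "*_layer_norm" patterns
# must run before the bare "self_attn"/"encoder_attn" ones.
RULES = [
    (r"^model\.", ""),
    (r"\.layers\.", ".blocks."),
    (r"self_attn_layer_norm", "attn_ln"),
    (r"encoder_attn_layer_norm", "cross_attn_ln"),
    (r"final_layer_norm", "mlp_ln"),
    (r"self_attn", "attn"),
    (r"encoder_attn", "cross_attn"),
    (r"q_proj", "query"),
    (r"k_proj", "key"),
    (r"v_proj", "value"),
    (r"out_proj", "out"),
    (r"fc1", "mlp.0"),
    (r"fc2", "mlp.2"),
    (r"embed_tokens", "token_embedding"),
    (r"embed_positions\.weight", "positional_embedding"),
    (r"^encoder\.layer_norm", "encoder.ln_post"),
    (r"^decoder\.layer_norm", "decoder.ln"),
]

def rename_key(hf_key: str) -> str:
    """Map one Hugging Face parameter name to its openai-whisper equivalent."""
    for pattern, replacement in RULES:
        hf_key = re.sub(pattern, replacement, hf_key)
    return hf_key

# e.g. "model.decoder.layers.0.self_attn.q_proj.weight"
#   -> "decoder.blocks.0.attn.query.weight"
print(rename_key("model.decoder.layers.0.self_attn.q_proj.weight"))
```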
You can do so by running the script `convert_from_distil_whisper.py` as follows:
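(The flags below are illustrative assumptions; check the script's `--help` for the actual options.)

```bash
# Convert the Hugging Face checkpoint to an openai-whisper-style .pt file
# that the TensorRT-LLM engine build step can consume.
python3 convert_from_distil_whisper.py \
    --model_name distil-whisper/distil-large-v2 \
    --output_dir ./assets \
    --output_name distil-large-v2
```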
Results and figures:
This might be useful when one wants low-latency models for short audio clips. Using TensorRT-LLM with distil-whisper provides up to a 3x improvement in latency over Hugging Face.
For instance, for a mean audio length of $3.5$ seconds, we are able to get a mean latency of $0.086$ seconds on an A10G.
Distribution of wav sizes vs latency: