From 65a353771a1173dd831c9ade2d5ea619c4ba089a Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Mon, 30 Sep 2024 02:38:26 -0700
Subject: [PATCH 1/2] using sdpa if available

---
 whisper/model.py  | 49 ++++++++++++++++++++++++++++++++++++++---------
 whisper/timing.py |  4 +++-
 2 files changed, 43 insertions(+), 10 deletions(-)

diff --git a/whisper/model.py b/whisper/model.py
index a67828397..9e09a6d91 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -1,5 +1,6 @@
 import base64
 import gzip
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, Iterable, Optional
 
@@ -12,6 +13,14 @@
 from .decoding import detect_language as detect_language_function
 from .transcribe import transcribe as transcribe_function
 
+try:
+    from torch.nn.functional import scaled_dot_product_attention
+
+    SDPA_AVAILABLE = True
+except (ImportError, RuntimeError, OSError):
+    scaled_dot_product_attention = None
+    SDPA_AVAILABLE = False
+
 
 @dataclass
 class ModelDimensions:
@@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
 
 
+@contextmanager
+def disable_sdpa():
+    prev_state = MultiHeadAttention.use_sdpa
+    try:
+        MultiHeadAttention.use_sdpa = False
+        yield
+    finally:
+        MultiHeadAttention.use_sdpa = prev_state
+
+
 class MultiHeadAttention(nn.Module):
+    use_sdpa = True
+
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
         self.n_head = n_head
@@ -92,20 +113,30 @@ def forward(
 
     def qkv_attention(
         self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
-    ):
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
-        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
-        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
 
-        qk = q @ k
-        if mask is not None:
-            qk = qk + mask[:n_ctx, :n_ctx]
-        qk = qk.float()
+        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
+            a = scaled_dot_product_attention(
+                q, k, v, is_causal=mask is not None and n_ctx > 1
+            )
+            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = None
+        else:
+            qk = (q * scale) @ (k * scale).transpose(-1, -2)
+            if mask is not None:
+                qk = qk + mask[:n_ctx, :n_ctx]
+            qk = qk.float()
+
+            w = F.softmax(qk, dim=-1).to(q.dtype)
+            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = qk.detach()
 
-        w = F.softmax(qk, dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+        return out, qk
 
 
 class ResidualAttentionBlock(nn.Module):
diff --git a/whisper/timing.py b/whisper/timing.py
index b695ead0a..e5634142b 100644
--- a/whisper/timing.py
+++ b/whisper/timing.py
@@ -191,7 +191,9 @@ def find_alignment(
         for i, block in enumerate(model.decoder.blocks)
     ]
 
-    with torch.no_grad():
+    from .model import disable_sdpa
+
+    with torch.no_grad(), disable_sdpa():
         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
         token_probs = sampled_logits.softmax(dim=-1)

From 3211024b5386e1e4c191029272a682d904cb2267 Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Mon, 30 Sep 2024 10:23:39 -0700
Subject: [PATCH 2/2] Update model.py

---
 whisper/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/whisper/model.py b/whisper/model.py
index 9e09a6d91..e53744738 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -2,7 +2,7 @@
 import gzip
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, Optional, Tuple
 
 import numpy as np
 import torch
@@ -113,7 +113,7 @@ def forward(
 
     def qkv_attention(
         self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
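
Note (usage sketch, not part of the patches above): a minimal example of the
SDPA fallback these patches introduce. The layer size, tensor shapes, and the
query/key/value projections of the existing MultiHeadAttention class are
assumptions for illustration; the only behavior exercised from the diff is
MultiHeadAttention.use_sdpa, disable_sdpa(), and the (out, qk) return of
qkv_attention().

    import torch
    from whisper.model import MultiHeadAttention, disable_sdpa

    # Hypothetical layer size: 384-dim states split across 6 heads.
    mha = MultiHeadAttention(n_state=384, n_head=6)
    x = torch.randn(1, 10, 384)
    q, k, v = mha.query(x), mha.key(x), mha.value(x)

    # Default path: when torch provides scaled_dot_product_attention, the fused
    # kernel is used and the attention matrix is never materialized, so qk is None.
    out, qk = mha.qkv_attention(q, k, v)
    print(out.shape, qk)  # torch.Size([1, 10, 384]) None

    # Callers that need the attention weights can temporarily force the
    # manual softmax path:
    with disable_sdpa():
        out, qk = mha.qkv_attention(q, k, v)
    print(qk.shape)  # torch.Size([1, 6, 10, 10]), detached float logits

This mirrors the change to find_alignment() in whisper/timing.py, which wraps
its forward pass in disable_sdpa() so that the attention matrices it reads are
actually computed.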