diff --git a/whisper/model.py b/whisper/model.py
index a67828397..e53744738 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -1,7 +1,8 @@
 import base64
 import gzip
+from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,6 +13,14 @@
 from .decoding import detect_language as detect_language_function
 from .transcribe import transcribe as transcribe_function
 
+try:
+    from torch.nn.functional import scaled_dot_product_attention
+
+    SDPA_AVAILABLE = True
+except (ImportError, RuntimeError, OSError):
+    scaled_dot_product_attention = None
+    SDPA_AVAILABLE = False
+
 
 @dataclass
 class ModelDimensions:
@@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
 
 
+@contextmanager
+def disable_sdpa():
+    prev_state = MultiHeadAttention.use_sdpa
+    try:
+        MultiHeadAttention.use_sdpa = False
+        yield
+    finally:
+        MultiHeadAttention.use_sdpa = prev_state
+
+
 class MultiHeadAttention(nn.Module):
+    use_sdpa = True
+
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
         self.n_head = n_head
@@ -92,20 +113,30 @@ def forward(
 
     def qkv_attention(
         self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
-    ):
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
-        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
-        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
 
-        qk = q @ k
-        if mask is not None:
-            qk = qk + mask[:n_ctx, :n_ctx]
-        qk = qk.float()
+        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
+            a = scaled_dot_product_attention(
+                q, k, v, is_causal=mask is not None and n_ctx > 1
+            )
+            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = None
+        else:
+            qk = (q * scale) @ (k * scale).transpose(-1, -2)
+            if mask is not None:
+                qk = qk + mask[:n_ctx, :n_ctx]
+            qk = qk.float()
+
+            w = F.softmax(qk, dim=-1).to(q.dtype)
+            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = qk.detach()
 
-        w = F.softmax(qk, dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+        return out, qk
 
 
 class ResidualAttentionBlock(nn.Module):
diff --git a/whisper/timing.py b/whisper/timing.py
index b695ead0a..e5634142b 100644
--- a/whisper/timing.py
+++ b/whisper/timing.py
@@ -191,7 +191,9 @@ def find_alignment(
         for i, block in enumerate(model.decoder.blocks)
     ]
 
-    with torch.no_grad():
+    from .model import disable_sdpa
+
+    with torch.no_grad(), disable_sdpa():
         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
         token_probs = sampled_logits.softmax(dim=-1)
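
For context, a minimal usage sketch of the behavior this diff introduces (not part of the patch): with the patched `whisper` package importable, attention uses `scaled_dot_product_attention` by default and returns `qk=None`, while `disable_sdpa()` forces the manual softmax path that still returns the QK matrix, which `find_alignment` in `timing.py` relies on. The `"tiny"` model choice, CPU device, and encoder-block access path below are illustrative assumptions, not part of the change.

```python
# Illustrative sketch only -- assumes the patched whisper package is installed.
import torch
import whisper
from whisper.model import SDPA_AVAILABLE, disable_sdpa

model = whisper.load_model("tiny", device="cpu")
attn = model.encoder.blocks[0].attn  # a MultiHeadAttention instance
x = torch.randn(1, 10, model.dims.n_audio_state)

with torch.no_grad():
    # Default path: when SDPA is available, qkv_attention takes the
    # scaled_dot_product_attention branch and returns qk=None.
    _, qk_default = attn(x)

    # Forcing the manual softmax path restores the QK matrix, as
    # find_alignment needs for word-level timestamps.
    with disable_sdpa():
        _, qk_manual = attn(x)

print(SDPA_AVAILABLE, qk_default is None, qk_manual is not None)
```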