diff --git a/README.md b/README.md
index 300d0fe744..74d8936556 100644
--- a/README.md
+++ b/README.md
@@ -152,6 +152,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
 - [Sine](xformers/components/positional_embedding/sine.py)
 - [Vocabulary](xformers/components/positional_embedding/vocab.py)
+- [Rotary](xformers/components/positional_embedding/rotary.py)

diff --git a/tests/test_attentions.py b/tests/test_attentions.py
index 83db740262..ffa8bee783 100644
--- a/tests/test_attentions.py
+++ b/tests/test_attentions.py
@@ -237,17 +237,12 @@ def test_different_kq_dimensions(
     heads: int,
     device: torch.device,
 ):
-    if attention_name in {
-        "global",
-        "local",
-        "random",
-        "lambda",
-        "linformer",
-        "blocksparse",
-    }:
+
+    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)
+
+    if multi_head.attention.requires_same_k_q_dimensions:
         # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
         pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")
-    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)
 
     seq_q = SEQ - 16
     q = torch.rand((BATCH, seq_q, MODEL), device=device)
diff --git a/tests/test_block_factory.py b/tests/test_block_factory.py
index f5097c6b21..24fa0bd7b8 100644
--- a/tests/test_block_factory.py
+++ b/tests/test_block_factory.py
@@ -211,3 +211,16 @@ def test_xformer_decoder_block(
 
     encoded = encoder_block(inputs)
     _ = decoder_block(inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask)
+
+    # Test different sequence lengths when encoding and decoding
+    if not decoder_block.mha.attention.requires_same_k_q_dimensions:
+        if not causal or not hasattr(decoder_block.mha.attention, "causal"):
+            _ = decoder_block(inputs[:, :-16], encoded)
+        else:
+            # Check that we assert properly
+            with pytest.raises(AssertionError):
+                _ = decoder_block(inputs[:, :-16], encoded)
+    else:
+        # Check that we assert properly
+        with pytest.raises(AssertionError):
+            _ = decoder_block(inputs[:, :-16], encoded)
diff --git a/tests/test_rotary_embeddings.py b/tests/test_rotary_embeddings.py
index 70f6f14089..5c9e6ccbe5 100644
--- a/tests/test_rotary_embeddings.py
+++ b/tests/test_rotary_embeddings.py
@@ -43,6 +43,9 @@ def test_rotary_embeddings(device):
         0, 0, 0, 0
     ].clone()  # all diagonal elements will have the same value
     att_rot = (
-        att_rot <= 1e-5
+        att_rot <= 1e-4
     )  # all non diagonal elements had lower attention than diagonal (+ float tolerance)
     assert torch.all(att_rot)
+
+    # Test that different sequence lengths are ok
+    _, _ = rotary(q[:, :, :-16, :], k)
diff --git a/xformers/components/attention/base.py b/xformers/components/attention/base.py
index eba0afc957..4728a0db29 100644
--- a/xformers/components/attention/base.py
+++ b/xformers/components/attention/base.py
@@ -34,11 +34,19 @@ class Attention(nn.Module, metaclass=ABCMeta):
     @abstractmethod
     def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
         super().__init__()
+
+        # Requires the inputs to be projected
         self.requires_input_projection = True
+
+        # Whether the head dimension needs to be present (if not it can be folded into the batch dimension)
         self.requires_head_dimension = False
 
+        # key padding mask and attention mask must be passed in as separate arguments instead of a merged attention mask
         self.requires_separate_masks = False
 
+        # Requires that K and Q have the same sequence length
+        self.requires_same_k_q_dimensions = False
+
     @classmethod
     def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
         # Generate the class inputs from the config
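The boolean attributes added to the Attention base class above replace the hard-coded list of attention names that the test used to carry: implementations advertise their constraints and callers simply read the flag. A minimal, self-contained sketch of this capability-flag pattern (ToyAttention and supports_different_kq are made-up names, not part of this patch):

import torch


class ToyAttention(torch.nn.Module):
    # Stand-in attention module advertising its constraints, mirroring the
    # boolean attributes defined on the Attention base class above
    def __init__(self):
        super().__init__()
        self.requires_same_k_q_dimensions = True


def supports_different_kq(attention: torch.nn.Module) -> bool:
    # Callers read the flag (with a safe default) instead of keeping a
    # hard-coded list of attention names
    return not getattr(attention, "requires_same_k_q_dimensions", False)


assert not supports_different_kq(ToyAttention())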
diff --git a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py
index 130f56bc36..9dfeddd579 100644
--- a/xformers/components/attention/blocksparse.py
+++ b/xformers/components/attention/blocksparse.py
@@ -108,7 +108,9 @@ def __init__(
         # key padding mask and attention mask must be passed in separately
         self.requires_separate_masks = True
 
-    def update_mask_type(self, mask: torch.Tensor, to_dtype: torch.dtype):
+        self.requires_same_k_q_dimensions = True
+
+    def update_mask_type(self, mask: torch.Tensor):
         global _mask_type_warning
         if _mask_type_warning:
             logging.warning(
@@ -141,9 +143,9 @@ def forward(
 
         # initial attention setup
         if att_mask is not None and att_mask.dtype == torch.bool:
-            self.update_mask_type(att_mask, q.dtype)
+            self.update_mask_type(att_mask)
         if key_padding_mask is not None and key_padding_mask.dtype == torch.bool:
-            self.update_mask_type(key_padding_mask, q.dtype)
+            self.update_mask_type(key_padding_mask)
 
         assert (
             att_mask is None or att_mask.dim() == 2
diff --git a/xformers/components/attention/global_tokens.py b/xformers/components/attention/global_tokens.py
index 9ef7606fbe..0d188fde5d 100644
--- a/xformers/components/attention/global_tokens.py
+++ b/xformers/components/attention/global_tokens.py
@@ -78,6 +78,8 @@ def __init__(
             else maybe_sparsify(self.attention_mask)
         )
 
+        self.requires_same_k_q_dimensions = True
+
     def forward(
         self,
         q: torch.Tensor,
diff --git a/xformers/components/attention/lambda_layer.py b/xformers/components/attention/lambda_layer.py
index 1d145d6ef6..dc8130c902 100644
--- a/xformers/components/attention/lambda_layer.py
+++ b/xformers/components/attention/lambda_layer.py
@@ -44,6 +44,7 @@ def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
         )
         self.rel_pos = calc_rel_pos(seq_len)
         self.attn_drop = torch.nn.Dropout(dropout, inplace=True)
+        self.requires_same_k_q_dimensions = True
 
     def forward(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
diff --git a/xformers/components/attention/linformer.py b/xformers/components/attention/linformer.py
index 0458e9889b..8f6b181c24 100644
--- a/xformers/components/attention/linformer.py
+++ b/xformers/components/attention/linformer.py
@@ -42,6 +42,7 @@ def __init__(
         self.F = nn.Linear(seq_len, k, bias=False)
         self.attn_drop = nn.Dropout(dropout, inplace=False)
         self.seq_len = seq_len
+        self.requires_same_k_q_dimensions = True
 
     def forward(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
diff --git a/xformers/components/attention/local.py b/xformers/components/attention/local.py
index e00c730444..d3699becdf 100644
--- a/xformers/components/attention/local.py
+++ b/xformers/components/attention/local.py
@@ -75,6 +75,7 @@ def __init__(
         self.window_size = window_size
 
         self.attention_mask: Optional[torch.Tensor] = None
+        self.requires_same_k_q_dimensions = True
 
     def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
         window_size = self.window_size * 2 + 1 if self.causal else self.window_size
diff --git a/xformers/components/attention/random.py b/xformers/components/attention/random.py
index 4de0f52911..1b51182b12 100644
--- a/xformers/components/attention/random.py
+++ b/xformers/components/attention/random.py
@@ -68,6 +68,7 @@ def __init__(
         self.rand_attention_mask: Optional[torch.Tensor] = None
         self.constant_masking = constant_masking
         self.force_sparsity = force_sparsity
+        self.requires_same_k_q_dimensions = True
 
     def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor:
         sparsity = 1 - self.r
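The mechanisms above opt out of mismatched lengths because they size masks or projections from a fixed sequence length at construction time. As a hedged, plain-PyTorch illustration (hypothetical shapes, not code from this patch), a Linformer-style projection over the sequence axis fails as soon as the incoming sequence is shorter than the one it was built for:

import torch

seq_len, k, dim = 64, 16, 32
E = torch.nn.Linear(seq_len, k, bias=False)  # projects along the sequence axis

keys = torch.rand(2, seq_len, dim)
_ = E(keys.transpose(-2, -1)).transpose(-2, -1)  # fine: (2, k, dim)

short_keys = torch.rand(2, seq_len - 16, dim)
try:
    E(short_keys.transpose(-2, -1))  # sequence axis no longer matches the projection
except RuntimeError:
    print("sequence length mismatch: this is what requires_same_k_q_dimensions guards against")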
diff --git a/xformers/components/multi_head_dispatch.py b/xformers/components/multi_head_dispatch.py
index a576704e7a..8e83dd629e 100644
--- a/xformers/components/multi_head_dispatch.py
+++ b/xformers/components/multi_head_dispatch.py
@@ -146,6 +146,18 @@ def forward(
         B, S_Q, _ = query.size()  # Batch x Sequence x Embedding (latent)
         _, S_K, _ = key.size()  # K, Q's sequence length could differ
 
+        # Catch the case of different query and key lengths combined with causal attention
+        if S_Q != S_K:
+            assert (
+                not self.attention.requires_same_k_q_dimensions
+            ), "This attention mechanism requires query and key to have the same sequence (context) lengths"
+
+            if hasattr(self.attention, "causal"):
+                assert not self.attention.causal, (
+                    "Causal attention is not supported when key and query have different sequence lengths.\n"
+                    + "In that case causality is ill-determined. Please pad your sequences accordingly"
+                )
+
         # Calculate query, key, values for all heads in batch
         if self.attention.requires_input_projection:
             q, k, v = self.in_proj_container(query=query, key=key, value=value)
diff --git a/xformers/components/positional_embedding/rotary.py b/xformers/components/positional_embedding/rotary.py
index 7ece79543f..94bf5736f3 100644
--- a/xformers/components/positional_embedding/rotary.py
+++ b/xformers/components/positional_embedding/rotary.py
@@ -18,9 +18,14 @@ def rotate_half(x):
 
 
 @torch.jit.script
-def apply_rotary_pos_emb(q, k, cos, sin):
+def apply_rotary_pos_emb(x, cos, sin):
     # NOTE: This could probably be moved to Triton
-    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+    # Handle a possible sequence length mismatch between q and k
+    cos = cos[:, :, : x.shape[-2], :]
+    sin = sin[:, :, : x.shape[-2], :]
+
+    return (x * cos) + (rotate_half(x) * sin)
 
 
 class RotaryEmbedding(torch.nn.Module):
@@ -73,7 +78,10 @@ def forward(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
-            q, seq_dimension=-2
+            k, seq_dimension=-2
         )
 
-        return apply_rotary_pos_emb(q, k, self._cos_cached, self._sin_cached)
+        return (
+            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
+            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+        )
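Taken together, the rotary.py changes mean the cos/sin tables are built from the key sequence and sliced down for a shorter query, so q and k no longer need the same length. A hedged usage sketch, assuming the RotaryEmbedding(dim_model) constructor and the (batch, heads, seq, dim) layout used in tests/test_rotary_embeddings.py:

import torch
from xformers.components.positional_embedding.rotary import RotaryEmbedding

B, H, SEQ, DIM = 2, 4, 64, 32
rotary = RotaryEmbedding(DIM)

q = torch.rand(B, H, SEQ - 16, DIM)  # shorter query sequence
k = torch.rand(B, H, SEQ, DIM)

q_rot, k_rot = rotary(q, k)  # forward(q, k) returns the rotated pair
assert q_rot.shape == q.shape and k_rot.shape == k.shape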