
Commit

Catch FF requiring squared context length
blefaudeux committed Jun 3, 2022
1 parent 5c61ad0 commit 5fe5ec3
Showing 2 changed files with 7 additions and 2 deletions.
5 changes: 4 additions & 1 deletion tests/test_block_factory.py
@@ -237,7 +237,10 @@ def test_xformer_decoder_block(
     )
 
     # Test different sequence lengths when encoding and decoding
-    if not decoder_block.requires_same_k_q_dimensions:
+    if (
+        not decoder_block.requires_same_k_q_dimensions
+        and not decoder_block.requires_squared_context_length
+    ):
         if not causal or not decoder_block.causal_attention:
             _ = decoder_block(inputs[:, :-16], encoded)
         else:
4 changes: 3 additions & 1 deletion xformers/factory/block_factory.py
@@ -275,9 +275,11 @@ def __init__(self, config: xFormerDecoderConfig, **kwargs):
         cross_mha = build_multi_head_attention(config.multi_head_config_cross)
         feedforward = build_feedforward(config.feedforward_config)
 
-        # Expose attention specific capabilities
+        # Expose attention or feedforward specific capabilities
         self.supports_attention_mask = mha.attention.supports_attention_mask
         self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
+        self.requires_squared_context_length = feedforward.requires_squared_context
+
         self.causal_attention = (
             mha.attention.causal if hasattr(mha.attention, "causal") else False
         )
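
For context, the flag exists because some feedforward blocks mix information across the token dimension, so their weights are tied to one specific context length (a seq_len x seq_len weight). Below is a toy, self-contained sketch of that failure mode. It is not xFormers code: the class name and shapes are illustrative only, and only the requires_squared_context attribute mirrors what the block factory reads above.

    # Toy illustration (not xFormers code): a feedforward that mixes across the
    # token dimension owns a (seq_len x seq_len) weight, so it only accepts the
    # exact context length it was built for. This is the capability that
    # `requires_squared_context` advertises and that the block factory now exposes.
    import torch
    import torch.nn as nn


    class TokenMixingFeedforward(nn.Module):  # hypothetical name, for illustration
        requires_squared_context = True  # same attribute the block factory reads

        def __init__(self, seq_len: int, dim: int):
            super().__init__()
            self.mix = nn.Linear(seq_len, seq_len)  # weight shape: (seq_len, seq_len)
            self.proj = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, seq_len, dim) -> mix across tokens, then project channels
            x = self.mix(x.transpose(1, 2)).transpose(1, 2)
            return self.proj(x)


    ff = TokenMixingFeedforward(seq_len=32, dim=16)
    _ = ff(torch.randn(2, 32, 16))   # works: matches the length it was built for
    try:
        _ = ff(torch.randn(2, 16, 16))  # shorter context -> shape mismatch
    except RuntimeError as err:
        print("different context length fails:", err)

This is why the updated test only runs decoder_block(inputs[:, :-16], encoded) when requires_squared_context_length is not set: a shorter decode sequence is simply invalid for such feedforwards.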
