
Commit

Catch FF requiring squared context length
blefaudeux committed Jun 3, 2022
1 parent b19d544 commit 55119ab
Showing 3 changed files with 11 additions and 3 deletions.
5 changes: 4 additions & 1 deletion tests/test_block_factory.py
@@ -237,7 +237,10 @@ def test_xformer_decoder_block(
     )
 
     # Test different sequence lengths when encoding and decoding
-    if not decoder_block.requires_same_k_q_dimensions:
+    if (
+        not decoder_block.requires_same_k_q_dimensions
+        and not decoder_block.requires_squared_context_length
+    ):
         if not causal or not decoder_block.causal_attention:
             _ = decoder_block(inputs[:, :-16], encoded)
         else:
5 changes: 4 additions & 1 deletion tests/test_model_factory.py
@@ -199,16 +199,19 @@ def check_against_default(p):
 
 
 @pytest.mark.parametrize("weight_init", [w.value for w in xFormerWeightInit])
+@pytest.mark.parametrize("feedforward", ["MLP", "Conv2DFeedforward"])
 @pytest.mark.parametrize("deepnorm", [False, True])
 @pytest.mark.parametrize("device", DEVICES)
-def test_weight_init(weight_init, deepnorm, device):
+def test_weight_init(weight_init, feedforward, deepnorm, device):
     torch.cuda.manual_seed(42)
     torch.manual_seed(42)
 
     config = test_configs_dict
 
     if deepnorm:
         config["encoder"]["layer_norm_style"] = "deepnorm"
+        config["encoder"]["feedforward_config"]["name"] = feedforward
+
         config["decoder"]["layer_norm_style"] = "deepnorm"
 
     # Make sure that all the init methods catch all the weights
4 changes: 3 additions & 1 deletion xformers/factory/block_factory.py
@@ -275,9 +275,11 @@ def __init__(self, config: xFormerDecoderConfig, **kwargs):
         cross_mha = build_multi_head_attention(config.multi_head_config_cross)
         feedforward = build_feedforward(config.feedforward_config)
 
-        # Expose attention specific capabilities
+        # Expose attention or feedforward specific capabilities
        self.supports_attention_mask = mha.attention.supports_attention_mask
        self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
+        self.requires_squared_context_length = feedforward.requires_squared_context
+
         self.causal_attention = (
             mha.attention.causal if hasattr(mha.attention, "causal") else False
         )
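
The new flag mirrors the attention-side capability flags: a feedforward such as Conv2DFeedforward folds the sequence into a 2D grid, so it only accepts context lengths that are perfect squares. Below is a minimal sketch (not part of this commit, and not the library's API) of how a caller could guard on the flag before passing a shortened sequence; safe_decode and its arguments are hypothetical names.

import math

def safe_decode(decoder_block, inputs, encoded):
    # Hypothetical helper: reject context lengths the block cannot handle.
    if getattr(decoder_block, "requires_squared_context_length", False):
        seq_len = inputs.shape[1]
        side = math.isqrt(seq_len)
        # Conv2DFeedforward reshapes (B, L, C) onto a 2D grid, so L must be a perfect square.
        assert side * side == seq_len, f"Expected a squared context length, got {seq_len}"
    return decoder_block(inputs, encoded)

This is the same check that lets test_xformer_decoder_block above skip the shortened-sequence case (inputs[:, :-16]) when the feedforward requires a squared context length.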
