Skip to content

Commit

Permalink
Adding Visual Attention
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux committed Jun 8, 2022
1 parent b281307 commit 696d178
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Four blocksparsity layouts from DeepSpeed [#320]
- Support several initialization options [#312]
- Conv2DFeedforward feedforward part [#321]
- VisualAttention [#329]


## [0.0.11] - 2022-05-30
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
- [2D Pooling](xformers/components/attention/pooling.py)
- *[Metaformer is actually what you need for vision, Yu et al.](https://arxiv.org/pdf/2111.11418v1.pdf)*

- [Visual Attention](xformers/components/attention/visual.py)
- *[`Visual Attention Network`_, Guo et al](https://arxiv.org/pdf/2202.09741.pdf)*

- ... add a new one [see Contribution.md](CONTRIBUTING.md)

</p></details>
Expand Down Expand Up @@ -199,7 +202,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*

<details><summary>Initializations </summary><p>
This is completely optional, and will only occur when generating full models through xFormers, not when picking parts individually.

There are basically two initialization mechanisms exposed, but the user is free to initialize weights as he/she sees fit after the fact.
- Parts can expose a `init_weights()` method, which define sane defaults
- xFormers supports [specific init schemes](xformers/factory/weight_init.py) which *can take precedence* over the init_weights()
Expand Down
9 changes: 5 additions & 4 deletions examples/cifarMetaformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
num_classes=10,
dim=384,
attention="scaled_dot_product",
feedforward="MLP",
layer_norm_style="pre",
use_rotary_embeddings=True,
linear_warmup_ratio=0.1,
Expand All @@ -45,8 +46,7 @@ def __init__(
# Generate the skeleton of our hierarchical Transformer

# This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
# Any other related config would work,
# and the attention mechanisms don't have to be the same across layers
# Any other related config would work, and the attention mechanisms don't have to be the same across layers
base_hierarchical_configs = [
BasicLayerConfig(
embedding=64,
Expand Down Expand Up @@ -121,8 +121,8 @@ def forward(self, x):

# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 512
BATCH = 512 # lower if not enough GPU memory
REF_BATCH = 768
BATCH = 256 # lower if not enough GPU memory

MAX_EPOCHS = 50
NUM_WORKERS = 4
Expand Down Expand Up @@ -172,6 +172,7 @@ def forward(self, x):
num_classes=num_classes,
attention="scaled_dot_product",
layer_norm_style="pre",
feedforward="MLP",
use_rotary_embeddings=True,
)
trainer = pl.Trainer(
Expand Down
34 changes: 25 additions & 9 deletions tests/test_attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@
)

BATCH = 2
SEQ = 128 if torch.cuda.is_available() else 32
SEQ = 128 if torch.cuda.is_available() else 36
MODEL = 128 if torch.cuda.is_available() else 16
GLOBAL_ATTENTION_RATIO = (
_DENSITY_THRESHOLD * 0.9
) # Make sure that we test the sparse implementation, no matter the threshold

assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"

_non_order_invariant_attentions = ["visual", "pooling"]


def _get_multihead(
attention_name,
Expand Down Expand Up @@ -93,7 +95,9 @@ def noop(x):
@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize(
"attention_name", ATTENTION_REGISTRY.keys() - _non_order_invariant_attentions
)
@pytest.mark.parametrize("device", DEVICES)
def test_order_invariance(
attention_name: str,
Expand All @@ -104,9 +108,6 @@ def test_order_invariance(
device: torch.device,
):

if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "pooling":
pytest.skip(f"{attention_name} requires squared sequence lengths")

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

Expand All @@ -120,6 +121,12 @@ def test_order_invariance(
use_seperate_proj_weights=False,
)

if (
int(math.sqrt(SEQ)) ** 2 != SEQ
and multi_head.attention.requires_squared_context
):
pytest.skip(f"{attention_name} requires squared sequence lengths")

# Check that a shuffled input produces the same results
seqs = [SEQ, SEQ // 2] if (attention_name != "blocksparse") else [SEQ]

Expand Down Expand Up @@ -304,12 +311,15 @@ def test_broadcast_batch_dimension(
device: torch.device,
batch_sizes: Tuple[int, int, int],
):
if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "pooling":
pytest.skip(f"{attention_name} requires squared sequence lengths")

Q_BATCH, K_BATCH, V_BATCH = batch_sizes
multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

if (
int(math.sqrt(SEQ)) ** 2 != SEQ
and multi_head.attention.requires_squared_context
):
pytest.skip(f"{attention_name} requires squared sequence lengths")

if multi_head.attention.requires_same_k_q_dimensions:
# pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")
Expand Down Expand Up @@ -388,14 +398,20 @@ def test_torch_script_ability(
heads: int,
attn_dropout: float,
):
if attention_name in {"favor", "global", "local", "random", "pooling"}:
if attention_name in {"favor", "global", "local", "random"}:
# pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
pytest.skip(f"{attention_name} does not support scripting yet.")

device = torch.device("cpu")

multi_head = _get_multihead(attention_name, attn_dropout, 0.0, False, heads, device)

if (
int(math.sqrt(SEQ)) ** 2 != SEQ
and multi_head.attention.requires_squared_context
):
pytest.skip(f"{attention_name} requires squared sequence lengths")

# input for tracing the function
q = torch.rand((BATCH, SEQ, MODEL), device=device)
k = torch.rand((BATCH, SEQ, MODEL), device=device)
Expand Down
3 changes: 3 additions & 0 deletions xformers/components/attention/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
# so that the MHA wrapper should skip it
self.requires_skip_multi_head = False

# This attention requires a context length which is squared, often due to 2D pooling
self.requires_squared_context = False

# Whether this attention mechanism supports attention masks
self.supports_attention_mask = True
self.supports_key_padding_mask = False
Expand Down
4 changes: 4 additions & 0 deletions xformers/components/attention/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def __init__(
# This operator does not really handle q,k,v
self.requires_same_k_q_dimensions = True

# This attention requires the 2d structure out of the context,
# implictly assumed to be a squared length
self.requires_squared_context = True

def forward(self, q: torch.Tensor, *_, **__):
# Expose the 2D token structure
B, HW, C = q.shape
Expand Down
96 changes: 96 additions & 0 deletions xformers/components/attention/visual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import math
from dataclasses import dataclass

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention


@dataclass
class VisualAttentionConfig(AttentionConfig):
dim_model: int # dimension of the input sequence


class LKA(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
self.conv_spatial = nn.Conv2d(
dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
)
self.conv1 = nn.Conv2d(dim, dim, 1)

def forward(self, x: torch.Tensor):
u = x.clone()
attn = self.conv0(x)
attn = self.conv_spatial(attn)
attn = self.conv1(attn)

return u * attn


@register_attention("visual", VisualAttentionConfig)
class Visual(Attention):
def __init__(
self,
dim_model: int,
*_,
**__,
):
"""
Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022).
The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network
for the reference implementation
.. Note: compared to the paper, this block contains the LKA (Large Kernel Attention)
and the prior and posterior transformations (Conv2d and activation)
.. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf
"""
super().__init__()

self.block = nn.Sequential(
nn.Conv2d(dim_model, dim_model, 1),
nn.GELU(),
LKA(dim_model),
nn.Conv2d(dim_model, dim_model, 1),
)

# MHA related flags:
self.requires_same_k_q_dimensions = (
True # This mechanism only really supports self attention
)
self.supports_attention_mask = False
self.requires_skip_multi_head = (
True # This mechanism skips the multihead attention altogether
)
self.requires_squared_context = (
True # Recovering the 2D structure from context assumes squared content
)

self.requires_input_projection = (
False # This mechanism does not require that the MHA projects inputs
)

def forward(self, q: torch.Tensor, *_, **__):
# Expose the 2D token structure
B, HW, C = q.shape
H = int(math.sqrt(HW))
assert H * H == HW

x = q.transpose(-2, -1).reshape(B, C, H, H)

# Large kernel attention
residual = x.clone()
x = self.block(x)
x = x + residual

# Get back to B HW C
return x.flatten(2, 3).transpose(-2, -1)
5 changes: 4 additions & 1 deletion xformers/factory/block_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,10 @@ def __init__(self, config: xFormerDecoderConfig, **kwargs):
# Expose attention or feedforward specific capabilities
self.supports_attention_mask = mha.attention.supports_attention_mask
self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
self.requires_squared_context_length = feedforward.requires_squared_context
self.requires_squared_context_length = (
feedforward.requires_squared_context
or mha.attention.requires_squared_context
)

self.causal_attention = (
mha.attention.causal if hasattr(mha.attention, "causal") else False
Expand Down

0 comments on commit 696d178

Please sign in to comment.