-
Notifications
You must be signed in to change notification settings - Fork 636
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] Adding Visual Attention #329
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ def __init__( | |
num_classes=10, | ||
dim=384, | ||
attention="scaled_dot_product", | ||
feedforward="MLP", | ||
layer_norm_style="pre", | ||
use_rotary_embeddings=True, | ||
linear_warmup_ratio=0.1, | ||
|
@@ -45,8 +46,7 @@ def __init__( | |
# Generate the skeleton of our hierarchical Transformer | ||
|
||
# This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32) | ||
# Any other related config would work, | ||
# and the attention mechanisms don't have to be the same across layers | ||
# Any other related config would work, and the attention mechanisms don't have to be the same across layers | ||
base_hierarchical_configs = [ | ||
BasicLayerConfig( | ||
embedding=64, | ||
|
@@ -121,8 +121,8 @@ def forward(self, x): | |
|
||
# Adjust batch depending on the available memory on your machine. | ||
# You can also use reversible layers to save memory | ||
REF_BATCH = 512 | ||
BATCH = 512 # lower if not enough GPU memory | ||
REF_BATCH = 768 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks like a classic default for Cifar10 |
||
BATCH = 256 # lower if not enough GPU memory | ||
|
||
MAX_EPOCHS = 50 | ||
NUM_WORKERS = 4 | ||
|
@@ -172,6 +172,7 @@ def forward(self, x): | |
num_classes=num_classes, | ||
attention="scaled_dot_product", | ||
layer_norm_style="pre", | ||
feedforward="MLP", | ||
use_rotary_embeddings=True, | ||
) | ||
trainer = pl.Trainer( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,10 @@ def __init__( | |
# This operator does not really handle q,k,v | ||
self.requires_same_k_q_dimensions = True | ||
|
||
# This attention requires the 2d structure out of the context, | ||
# implictly assumed to be a squared length | ||
self.requires_squared_context = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this was already true before, but not formalized like this, I think it's cleaner ? "pooling" (PoolingFormer) and "visual" both recover the 2d structure of and assume a squared context length for that |
||
|
||
def forward(self, q: torch.Tensor, *_, **__): | ||
# Expose the 2D token structure | ||
B, HW, C = q.shape | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import math | ||
from dataclasses import dataclass | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from xformers.components.attention import Attention, AttentionConfig, register_attention | ||
|
||
|
||
@dataclass | ||
class VisualAttentionConfig(AttentionConfig): | ||
dim_model: int # dimension of the input sequence | ||
|
||
|
||
class LKA(nn.Module): | ||
def __init__(self, dim: int): | ||
super().__init__() | ||
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) | ||
self.conv_spatial = nn.Conv2d( | ||
dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3 | ||
) | ||
self.conv1 = nn.Conv2d(dim, dim, 1) | ||
|
||
def forward(self, x: torch.Tensor): | ||
u = x.clone() | ||
attn = self.conv0(x) | ||
attn = self.conv_spatial(attn) | ||
attn = self.conv1(attn) | ||
|
||
return u * attn | ||
|
||
|
||
@register_attention("visual", VisualAttentionConfig) | ||
class Visual(Attention): | ||
def __init__( | ||
self, | ||
dim_model: int, | ||
*_, | ||
**__, | ||
): | ||
""" | ||
Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022). | ||
The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network | ||
for the reference implementation | ||
|
||
.. Note: compared to the paper, this block contains the LKA (Large Kernel Attention) | ||
and the prior and posterior transformations (Conv2d and activation) | ||
|
||
.. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf | ||
""" | ||
super().__init__() | ||
|
||
self.block = nn.Sequential( | ||
nn.Conv2d(dim_model, dim_model, 1), | ||
nn.GELU(), | ||
LKA(dim_model), | ||
nn.Conv2d(dim_model, dim_model, 1), | ||
) | ||
|
||
# MHA related flags: | ||
self.requires_same_k_q_dimensions = ( | ||
True # This mechanism only really supports self attention | ||
) | ||
self.supports_attention_mask = False | ||
self.requires_skip_multi_head = ( | ||
True # This mechanism skips the multihead attention altogether | ||
) | ||
self.requires_squared_context = ( | ||
True # Recovering the 2D structure from context assumes squared content | ||
) | ||
|
||
self.requires_input_projection = ( | ||
False # This mechanism does not require that the MHA projects inputs | ||
) | ||
|
||
def forward(self, q: torch.Tensor, *_, **__): | ||
# Expose the 2D token structure | ||
B, HW, C = q.shape | ||
H = int(math.sqrt(HW)) | ||
assert H * H == HW | ||
|
||
x = q.transpose(-2, -1).reshape(B, C, H, H) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've not benchmarked that, but maybe that it's beneficial to .contiguous() here, depending on the Conv2D kernels |
||
|
||
# Large kernel attention | ||
residual = x.clone() | ||
x = self.block(x) | ||
x = x + residual | ||
|
||
# Get back to B HW C | ||
return x.flatten(2, 3).transpose(-2, -1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about the defaults here, how to show that you can use these to repro "Visual Attention" for instance ? Should we show different presets ?