[feat] Adding a conv MLP, following VAN #321

Merged: 5 commits, Jun 6, 2022

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Four blocksparsity layouts from DeepSpeed [#320]
- Support several initialization options [#312]
- Conv2DFeedforward feedforward part [#321]


## [0.0.11] - 2022-05-30
1 change: 1 addition & 0 deletions README.md
@@ -176,6 +176,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
- [MLP](xformers/components/feedforward/mlp.py)
- [Fused](xformers/components/feedforward/fused_mlp.py)
- [Mixture of Experts](xformers/components/feedforward/mixture_of_experts.py)
- [Conv2DFeedforward](xformers/components/feedforward/conv_mlp.py)

</p></details>

1 change: 1 addition & 0 deletions examples/cifarMetaformer.py
@@ -89,6 +89,7 @@ def __init__(
use_rotary_embeddings=use_rotary_embeddings,
mlp_multiplier=4,
dim_head=32,
feedforward="Conv2DFeedforward",
)

# Now instantiate the metaformer trunk
5 changes: 4 additions & 1 deletion tests/test_block_factory.py
@@ -237,7 +237,10 @@ def test_xformer_decoder_block(
)

# Test different sequence lengths when encoding and decoding
if not decoder_block.requires_same_k_q_dimensions:
if (
not decoder_block.requires_same_k_q_dimensions
and not decoder_block.requires_squared_context_length
):
if not causal or not decoder_block.causal_attention:
_ = decoder_block(inputs[:, :-16], encoded)
else:
2 changes: 1 addition & 1 deletion tests/test_feedforward.py
@@ -12,7 +12,7 @@
from xformers.helpers.test_utils import init_torch_distributed_local

BATCH = 4
SEQ = 512
SEQ = 256
EMBD = 16
LATENT = 128
DROPOUT = 0.5
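Note: the shared test sequence length drops from 512 to 256, presumably so that the same constants also exercise Conv2DFeedforward, since 256 is a perfect square (16 x 16) while 512 is not. A quick check of that property (a standalone sketch, not part of the PR):

```python
import math

# 512 has no integer square root; 256 folds cleanly into a 16 x 16 grid
for seq in (512, 256):
    hw = int(math.sqrt(seq))
    print(f"{seq}: perfect square = {hw * hw == seq}")  # 512 -> False, 256 -> True
```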
5 changes: 4 additions & 1 deletion tests/test_model_factory.py
@@ -199,16 +199,19 @@ def check_against_default(p):


@pytest.mark.parametrize("weight_init", [w.value for w in xFormerWeightInit])
@pytest.mark.parametrize("feedforward", ["MLP", "Conv2DFeedforward"])
@pytest.mark.parametrize("deepnorm", [False, True])
@pytest.mark.parametrize("device", DEVICES)
def test_weight_init(weight_init, deepnorm, device):
def test_weight_init(weight_init, feedforward, deepnorm, device):
torch.cuda.manual_seed(42)
torch.manual_seed(42)

config = test_configs_dict

if deepnorm:
config["encoder"]["layer_norm_style"] = "deepnorm"
config["encoder"]["feedforward_config"]["name"] = feedforward

config["decoder"]["layer_norm_style"] = "deepnorm"

# Make sure that all the init methods catch all the weights
5 changes: 5 additions & 0 deletions xformers/components/feedforward/base.py
@@ -35,8 +35,13 @@ def __init__(
**kwargs,
):
super().__init__()

# This feedforward requires a CUDA accelerator
self.requires_cuda = False

# This feedforward requires a context length which is a perfect square, often due to 2D pooling
self.requires_squared_context = False

@classmethod
def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self:
# Generate the class inputs from the config
97 changes: 97 additions & 0 deletions xformers/components/feedforward/conv_mlp.py
@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: Largely reusing the code from the reference VAN implementation
# see https://github.com/Visual-Attention-Network

import math
from dataclasses import dataclass
from typing import Optional

import torch.nn as nn

from xformers.components import Activation, build_activation
from xformers.components.feedforward import Feedforward, FeedforwardConfig

from . import register_feedforward


@dataclass
class ConvMlpConfig(FeedforwardConfig):
    hidden_layer_multiplier: int
    dim_model: int
    dim_model_out: Optional[int]
    act_layer: Activation
    dropout: float


@register_feedforward("Conv2DFeedforward", ConvMlpConfig)
class Conv2DFeedforward(Feedforward):
    """
    A Convolutional feed-forward network, as proposed in VAN_ (Visual Attention Network, Guo et al.)

    .. _VAN: https://arxiv.org/pdf/2202.09741.pdf
    """

    def __init__(
        self,
        dim_model: int,
        hidden_layer_multiplier: int = 1,
        dim_model_out: Optional[int] = None,
        activation: Activation = Activation.GeLU,
        dropout=0.0,
        *args,
        **kwargs,
    ):
        super().__init__()
        out_features = dim_model_out or dim_model
        hidden_features = hidden_layer_multiplier * dim_model

        self.conv_mlp = nn.Sequential(
            nn.Conv2d(dim_model, hidden_features, 1),
            nn.Conv2d(
                hidden_features,
                hidden_features,
                3,
                1,
                1,
                bias=True,
                groups=hidden_features,
            ),
            build_activation(activation),
            nn.Conv2d(hidden_features, out_features, 1),
            nn.Dropout(dropout),
        )

        # This feedforward requires a context length which is a perfect square, often due to 2D pooling
        self.requires_squared_context = True
Inline review comment from the PR author:
This does 2D convolutions, meaning that the layer needs to be able to go from [Batch x Context x Embedding] to [Batch x H x W x Embedding]. A solution which is not too intrusive is to force sequence lengths to be perfect squares, which essentially means that we only work with square pictures. This is pretty common in vision codebases; another solution would be to keep track of the original H and W prior to flattening this dimension.


    def init_weights(self, **kwargs):
        # Follow the original init, but also make it possible to initialize from the outside
        def init_module(m: nn.Module):
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()

        self.apply(init_module)

    def forward(self, x):
        # The conv layers expect NCHW, we have NLC by default
        B, L, C = x.shape
        HW = int(math.sqrt(x.shape[-2]))
        assert HW**2 == L, "Conv2DFeedforward requires squared context lengths"

        x = x.reshape((B, HW, HW, C)).swapdims(1, -1)

        # The actual FW, including the 2d convolutions
        x = self.conv_mlp(x)

        # back to NLC
        x = x.transpose(1, -1)
        return x.flatten(1, 2)
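To make the shape contract of the new block concrete, here is a minimal usage sketch of the class exactly as added above; the batch size, embedding width, and hidden multiplier are illustrative values, not taken from the PR:

```python
import torch

from xformers.components import Activation
from xformers.components.feedforward.conv_mlp import Conv2DFeedforward

ff = Conv2DFeedforward(
    dim_model=64,
    hidden_layer_multiplier=4,
    activation=Activation.GeLU,
    dropout=0.0,
)

# The context length must be a perfect square: 256 = 16 * 16
x = torch.randn(2, 256, 64)  # [Batch x Context x Embedding] (NLC)
y = ff(x)                    # folded internally to (2, 64, 16, 16) for the 2D convolutions

# dim_model_out defaults to dim_model, so the input shape is preserved
assert y.shape == x.shape
```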
4 changes: 3 additions & 1 deletion xformers/factory/block_factory.py
@@ -275,9 +275,11 @@ def __init__(self, config: xFormerDecoderConfig, **kwargs):
cross_mha = build_multi_head_attention(config.multi_head_config_cross)
feedforward = build_feedforward(config.feedforward_config)

# Expose attention specific capabilities
# Expose attention or feedforward specific capabilities
self.supports_attention_mask = mha.attention.supports_attention_mask
self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
self.requires_squared_context_length = feedforward.requires_squared_context

self.causal_attention = (
mha.attention.causal if hasattr(mha.attention, "causal") else False
)
4 changes: 0 additions & 4 deletions xformers/factory/weight_init.py
@@ -30,10 +30,6 @@ class xFormerWeightInit(str, Enum):
Small = "small"


# TODO: Check with a bunch of quick trainings whether all the inits are in the green
# TODO: Check test coverage


def get_weight_init_fn(init_choice: xFormerWeightInit):
"""
Provide the xFormers factory with weight init routines.
3 changes: 2 additions & 1 deletion xformers/helpers/hierarchical_configs.py
@@ -27,6 +27,7 @@ def get_hierarchical_configuration(
use_rotary_embeddings: bool = True,
mlp_multiplier: int = 4,
dim_head=32,
feedforward="MLP",
):
"""
A small helper to generate hierarchical xformers configurations,
@@ -49,7 +50,7 @@
},
},
"feedforward_config": {
"name": "MLP",
"name": feedforward,
"activation": "gelu",
"hidden_layer_multiplier": mlp_multiplier,
"dropout": 0.0,
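Finally, a sketch of how the new feedforward argument could thread through the hierarchical helper; the BasicLayerConfig fields and every numeric value below are illustrative assumptions, and only the feedforward keyword itself comes from this diff:

```python
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

# Hypothetical single-stage setup; values are placeholders, not from the PR
base_configs = [
    BasicLayerConfig(
        embedding=64,
        attention_mechanism="scaled_dot_product",
        patch_size=7,
        stride=4,
        padding=2,
        seq_len=56 * 56,  # a perfect square, as Conv2DFeedforward requires
    ),
]

config = get_hierarchical_configuration(
    base_configs,
    use_rotary_embeddings=False,
    mlp_multiplier=4,
    dim_head=32,
    feedforward="Conv2DFeedforward",  # new argument introduced by this PR
)
```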