Adding a conv MLP, following VAN
blefaudeux committed Jun 3, 2022
1 parent 52d1dd0 commit 0f5d2a9
Showing 6 changed files with 114 additions and 35 deletions.
1 change: 1 addition & 0 deletions examples/cifarMetaformer.py
@@ -89,6 +89,7 @@ def __init__(
use_rotary_embeddings=use_rotary_embeddings,
mlp_multiplier=4,
dim_head=32,
feedforward="ConvMLP",
)

# Now instantiate the metaformer trunk
2 changes: 1 addition & 1 deletion tests/test_feedforward.py
@@ -12,7 +12,7 @@
from xformers.helpers.test_utils import init_torch_distributed_local

BATCH = 4
SEQ = 512
SEQ = 256
EMBD = 16
LATENT = 128
DROPOUT = 0.5
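
Note: the shared SEQ constant drops from 512 to 256 most likely because 256 is a perfect square (16 x 16) while 512 is not, so the common test shapes satisfy the squared-context constraint that the new ConvMLP block enforces (this rationale is inferred from the diff, not stated in the commit). A quick check:

import math

assert int(math.sqrt(256)) ** 2 == 256   # 16 * 16 — valid for ConvMLP
assert int(math.sqrt(512)) ** 2 != 512   # no integer square root — would trip the assert in ConvMLP.forward
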
5 changes: 5 additions & 0 deletions xformers/components/feedforward/base.py
@@ -35,8 +35,13 @@ def __init__(
**kwargs,
):
super().__init__()

# Whether this feedforward requires a CUDA accelerator
self.requires_cuda = False

# Whether this feedforward requires a context length which is a perfect square, often due to 2D pooling
self.requires_squared_context = False

@classmethod
def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self:
# Generate the class inputs from the config
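
For context, a minimal sketch of how calling code could honor the two new flags before wiring a feedforward into a block; the guard below is hypothetical and not part of this commit:

import math

import torch

from xformers.components.feedforward import Feedforward


def check_feedforward_requirements(ff: Feedforward, seq_len: int, device: torch.device) -> None:
    # Hypothetical guard, only illustrating the intended contract of the new flags
    if ff.requires_cuda and device.type != "cuda":
        raise ValueError(f"{type(ff).__name__} requires a CUDA device")
    if ff.requires_squared_context:
        side = int(math.sqrt(seq_len))
        if side * side != seq_len:
            raise ValueError(
                f"{type(ff).__name__} expects a perfect-square context length, got {seq_len}"
            )
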
105 changes: 105 additions & 0 deletions xformers/components/feedforward/conv_mlp.py
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: Largely reusing the code from the reference VAN implementation
# see https://github.com/Visual-Attention-Network

import math
from dataclasses import dataclass
from typing import Optional

import torch.nn as nn

from xformers.components import Activation, build_activation
from xformers.components.feedforward import Feedforward, FeedforwardConfig
from xformers.factory.weight_init import _no_grad_trunc_normal_

from . import register_feedforward


@dataclass
class ConvMlpConfig(FeedforwardConfig):
hidden_layer_multiplier: int
dim_model: int
dim_model_out: Optional[int]
activation: Activation
dropout: float


@register_feedforward("ConvMLP", ConvMlpConfig)
class ConvMLP(Feedforward):
"""
A convolutional feed-forward network, as proposed in VAN_ (Visual Attention Network, Guo et al.)
.. _VAN: https://arxiv.org/pdf/2202.09741.pdf
"""

def __init__(
self,
dim_model: int,
hidden_layer_multiplier: int = 1,
dim_model_out: Optional[int] = None,
activation: Activation = Activation.GeLU,
dropout=0.0,
*args,
**kwargs,
):
super().__init__()
out_features = dim_model_out or dim_model
hidden_features = hidden_layer_multiplier * dim_model

self.conv_mlp = nn.Sequential(
nn.Conv2d(dim_model, hidden_features, 1),
nn.Conv2d(
hidden_features,
hidden_features,
3,
1,
1,
bias=True,
groups=hidden_features,
),
build_activation(activation),
nn.Conv2d(hidden_features, out_features, 1),
nn.Dropout(dropout),
)

# This feedforward requires a context length which is a perfect square, often due to 2D pooling
self.requires_squared_context = True

def init_weights(self, **kwargs):
# Follow the original init, but also make it possible to initialize from the outside
def init_module(m: nn.Module):
if isinstance(m, nn.Linear):
_no_grad_trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()

self.apply(init_module)

def forward(self, x):
# The conv layers expect NCHW, we have NLC by default
B, L, C = x.shape
HW = int(math.sqrt(x.shape[-2]))
assert HW**2 == L, "ConvMLP is 2D by default, and it assumes square pictures"

x = x.reshape((B, HW, HW, C)).swapdims(1, -1)

# The actual FW, including the 2d convolutions
x = self.conv_mlp(x)

# back to NLC
x = x.transpose(1, -1)
return x.flatten(1, 2)
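
A minimal usage sketch for the class above (shapes and hyper-parameters here are illustrative, not taken from the commit): the input keeps the usual NLC layout, and the sequence length has to be a perfect square.

import torch

from xformers.components import Activation
from xformers.components.feedforward.conv_mlp import ConvMLP

ff = ConvMLP(
    dim_model=96,
    hidden_layer_multiplier=4,
    activation=Activation.GeLU,
    dropout=0.1,
)
ff.init_weights()

x = torch.randn(2, 14 * 14, 96)  # B=2, L=196 (14 x 14 tokens), C=96
y = ff(x)                        # NLC in, NLC out
assert y.shape == (2, 196, 96)
assert ff.requires_squared_context and not ff.requires_cuda
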
33 changes: 0 additions & 33 deletions xformers/factory/weight_init.py
@@ -30,10 +30,6 @@ class xFormerWeightInit(str, Enum):
Small = "small"


# TODO: Check with a bunch of quick trainings whether all the inits are in the green
# TODO: Check test coverage


def get_weight_init_fn(init_choice: xFormerWeightInit):
"""
Provide the xFormers factory with weight init routines.
@@ -92,35 +88,6 @@ def _small_init_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
return _no_grad_uniform_(tensor, -a, a)


def _variance_scaling(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2

variance = scale / denom

if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
_no_grad_trunc_normal_(
tensor,
mean=0.0,
std=math.sqrt(variance) / 0.87962566103423978,
a=-2.0,
b=2.0,
)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")


def _lecun_normal(tensor, gain=1.0):
fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
denom = fan_in
3 changes: 2 additions & 1 deletion xformers/helpers/hierarchical_configs.py
@@ -27,6 +27,7 @@ def get_hierarchical_configuration(
use_rotary_embeddings: bool = True,
mlp_multiplier: int = 4,
dim_head=32,
feedforward="MLP",
):
"""
A small helper to generate hierarchical xformers configurations,
@@ -49,7 +50,7 @@
},
},
"feedforward_config": {
"name": "MLP",
"name": feedforward,
"activation": "gelu",
"hidden_layer_multiplier": mlp_multiplier,
"dropout": 0.0,
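
To tie the pieces together, here is a sketch of the feedforward section that the helper now emits when feedforward="ConvMLP" is passed, and how it could be turned into a module through the library's registry; build_feedforward and the exact set of required keys are assumptions based on the usual xformers registry pattern, not something this commit shows.

from xformers.components.feedforward import build_feedforward

feedforward_config = {
    "name": "ConvMLP",             # previously hard-coded to "MLP"
    "dim_model": 96,               # illustrative embedding width
    "activation": "gelu",
    "hidden_layer_multiplier": 4,
    "dropout": 0.0,
}

conv_mlp = build_feedforward(feedforward_config)
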
