diff --git a/examples/cifarMetaformer.py b/examples/cifarMetaformer.py
index cf88f61c4a..11af03c4aa 100644
--- a/examples/cifarMetaformer.py
+++ b/examples/cifarMetaformer.py
@@ -89,6 +89,7 @@ def __init__(
             use_rotary_embeddings=use_rotary_embeddings,
             mlp_multiplier=4,
             dim_head=32,
+            feedforward="ConvMLP",
         )
 
         # Now instantiate the metaformer trunk
diff --git a/tests/test_feedforward.py b/tests/test_feedforward.py
index 80fa0e3679..3bf58cd330 100644
--- a/tests/test_feedforward.py
+++ b/tests/test_feedforward.py
@@ -12,7 +12,7 @@
 from xformers.helpers.test_utils import init_torch_distributed_local
 
 BATCH = 4
-SEQ = 512
+SEQ = 256
 EMBD = 16
 LATENT = 128
 DROPOUT = 0.5
diff --git a/xformers/components/feedforward/base.py b/xformers/components/feedforward/base.py
index 867276b038..3c483ce593 100644
--- a/xformers/components/feedforward/base.py
+++ b/xformers/components/feedforward/base.py
@@ -35,8 +35,13 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
+
+        # This feedforward requires a CUDA accelerator
         self.requires_cuda = False
 
+        # This feedforward requires a context length which is squared, often due to 2D pooling
+        self.requires_squared_context = False
+
     @classmethod
     def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self:
         # Generate the class inputs from the config
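
For context, a minimal sketch of how calling code might honor the two capability flags declared on the Feedforward base class above. Only the flag names (requires_cuda, requires_squared_context) come from this change; the helper functions and their names below are illustrative, not part of xformers.

import math

import torch


def pick_context_length(feedforward, seq_len: int) -> int:
    # ConvMLP-style blocks fold the flat sequence back into a 2D map,
    # so the context length must be a perfect square; round up otherwise.
    if getattr(feedforward, "requires_squared_context", False):
        side = math.isqrt(seq_len)
        if side * side != seq_len:
            side += 1
        return side * side
    return seq_len


def should_skip(feedforward) -> bool:
    # Some feedforwards (e.g. fused variants) can only run on a CUDA device.
    return getattr(feedforward, "requires_cuda", False) and not torch.cuda.is_available()
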
diff --git a/xformers/components/feedforward/conv_mlp.py b/xformers/components/feedforward/conv_mlp.py
new file mode 100644
index 0000000000..cd36bad5b9
--- /dev/null
+++ b/xformers/components/feedforward/conv_mlp.py
@@ -0,0 +1,105 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# CREDITS: Largely reusing the code from the reference VAN implementation
+# see https://github.com/Visual-Attention-Network
+
+import math
+from dataclasses import dataclass
+from typing import Optional
+
+import torch.nn as nn
+
+from xformers.components import Activation, build_activation
+from xformers.components.feedforward import Feedforward, FeedforwardConfig
+from xformers.factory.weight_init import _no_grad_trunc_normal_
+
+from . import register_feedforward
+
+
+@dataclass
+class ConvMlpConfig(FeedforwardConfig):
+    hidden_layer_multiplier: int
+    dim_model: int
+    dim_model_out: Optional[int]
+    act_layer: Activation
+    dropout: float
+
+
+@register_feedforward("ConvMLP", ConvMlpConfig)
+class ConvMLP(Feedforward):
+    """
+    A convolutional feed-forward network, as proposed in VAN_ (Visual Attention Network, Guo et al.)
+
+    .. _VAN: https://arxiv.org/pdf/2202.09741.pdf
+    """
+
+    def __init__(
+        self,
+        dim_model: int,
+        hidden_layer_multiplier: int = 1,
+        dim_model_out: Optional[int] = None,
+        activation: Activation = Activation.GeLU,
+        dropout=0.0,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        out_features = dim_model_out or dim_model
+        hidden_features = hidden_layer_multiplier * dim_model
+
+        self.conv_mlp = nn.Sequential(
+            nn.Conv2d(dim_model, hidden_features, 1),
+            nn.Conv2d(
+                hidden_features,
+                hidden_features,
+                3,
+                1,
+                1,
+                bias=True,
+                groups=hidden_features,
+            ),
+            build_activation(activation),
+            nn.Conv2d(hidden_features, out_features, 1),
+            nn.Dropout(dropout),
+        )
+
+        # This feedforward requires a context length which is squared, often due to 2D pooling
+        self.requires_squared_context = True
+
+    def init_weights(self, **kwargs):
+        # Follow the original init, but also make it possible to initialize from the outside
+        def init_module(m: nn.Module):
+            if isinstance(m, nn.Linear):
+                _no_grad_trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+            elif isinstance(m, nn.Conv2d):
+                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                fan_out //= m.groups
+                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+        self.apply(init_module)
+
+    def forward(self, x):
+        # The conv layers expect NCHW, we have NLC by default
+        B, L, C = x.shape
+        HW = int(math.sqrt(x.shape[-2]))
+        assert HW**2 == L, "ConvMLP is 2D by default, and it assumes square pictures"
+
+        x = x.reshape((B, HW, HW, C)).swapdims(1, -1)
+
+        # The actual FW, including the 2d convolutions
+        x = self.conv_mlp(x)
+
+        # back to NLC
+        x = x.transpose(1, -1)
+        return x.flatten(1, 2)
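
A quick usage sketch of the new block, assuming the file above imports as xformers.components.feedforward.conv_mlp. Shapes follow the NLC convention documented in forward(), with a perfect-square context length (16 * 16 = 256, the same SEQ the updated test file uses).

import torch

from xformers.components import Activation
from xformers.components.feedforward.conv_mlp import ConvMLP

ff = ConvMLP(
    dim_model=16,
    hidden_layer_multiplier=4,
    activation=Activation.GeLU,
    dropout=0.1,
)

# (batch, context, channels) with context = 16 ** 2, as required by requires_squared_context
x = torch.randn(4, 256, 16)
y = ff(x)
assert y.shape == (4, 256, 16)  # dim_model_out defaults to dim_model
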
diff --git a/xformers/factory/weight_init.py b/xformers/factory/weight_init.py
index 6da5d1dba4..90f92f7cdb 100644
--- a/xformers/factory/weight_init.py
+++ b/xformers/factory/weight_init.py
@@ -30,10 +30,6 @@ class xFormerWeightInit(str, Enum):
     Small = "small"
 
 
-# TODO: Check with a bunch of quick trainings whether all the inits are in the green
-# TODO: Check test coverage
-
-
 def get_weight_init_fn(init_choice: xFormerWeightInit):
     """
     Provide the xFormers factory with weight init routines.
@@ -92,35 +88,6 @@ def _small_init_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
     return _no_grad_uniform_(tensor, -a, a)
 
 
-def _variance_scaling(tensor, scale=1.0, mode="fan_in", distribution="normal"):
-    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
-    if mode == "fan_in":
-        denom = fan_in
-    elif mode == "fan_out":
-        denom = fan_out
-    elif mode == "fan_avg":
-        denom = (fan_in + fan_out) / 2
-
-    variance = scale / denom
-
-    if distribution == "truncated_normal":
-        # constant is stddev of standard normal truncated to (-2, 2)
-        _no_grad_trunc_normal_(
-            tensor,
-            mean=0.0,
-            std=math.sqrt(variance) / 0.87962566103423978,
-            a=-2.0,
-            b=2.0,
-        )
-    elif distribution == "normal":
-        tensor.normal_(std=math.sqrt(variance))
-    elif distribution == "uniform":
-        bound = math.sqrt(3 * variance)
-        tensor.uniform_(-bound, bound)
-    else:
-        raise ValueError(f"invalid distribution {distribution}")
-
-
 def _lecun_normal(tensor, gain=1.0):
     fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
     denom = fan_in
diff --git a/xformers/helpers/hierarchical_configs.py b/xformers/helpers/hierarchical_configs.py
index 2de8e2f424..84ce4c7b53 100644
--- a/xformers/helpers/hierarchical_configs.py
+++ b/xformers/helpers/hierarchical_configs.py
@@ -27,6 +27,7 @@ def get_hierarchical_configuration(
     use_rotary_embeddings: bool = True,
     mlp_multiplier: int = 4,
     dim_head=32,
+    feedforward="MLP",
 ):
     """
     A small helper to generate hierarchical xformers configurations,
@@ -49,7 +50,7 @@ def get_hierarchical_configuration(
                 },
             },
             "feedforward_config": {
-                "name": "MLP",
+                "name": feedforward,
                 "activation": "gelu",
                 "hidden_layer_multiplier": mlp_multiplier,
                 "dropout": 0.0,
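
For reference, the feedforward block that get_hierarchical_configuration now emits once the new feedforward argument is threaded through, with the values resolved as in the updated CIFAR Metaformer example above. Only the keys visible in the hunk are shown; dim_model and the remaining fields are still filled in elsewhere by the helper, and a ConvMLP trunk additionally needs image resolutions whose flattened sequence length is a perfect square (requires_squared_context).

feedforward_config = {
    "name": "ConvMLP",             # was hard-coded to "MLP" before this change
    "activation": "gelu",
    "hidden_layer_multiplier": 4,  # mlp_multiplier in the example
    "dropout": 0.0,
}
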