diff --git a/CHANGELOG.md b/CHANGELOG.md
index 60041f0679..223ab72b29 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - MLP benchmark
 - Move all triton kernels to triton v2 [#272]
 - Mem efficient attention, BW pass [#281]
+- Metaformer support [#294]
 
 ## [0.0.10] - 2022-03-14
 ### Fixed
diff --git a/HOWTO.md b/HOWTO.md
index eb636aa419..a56c3725fd 100644
--- a/HOWTO.md
+++ b/HOWTO.md
@@ -30,6 +30,8 @@ Let's present here a couple of code snippets on how to solve a couple of questio
   - [Intro](#intro)
   - [Transformer](#transformer)
   - [In practice](#in-practice)
+  - [Hierarchical Transformers](#hierarchical-transformers)
+
 
 ## Understanding the dimension conventions
 
@@ -749,3 +751,58 @@ class xFormerStackConfig:
 [2]: Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). Reformer: The Efficient Transformer.
 
 [3]: Vaswani et al., Attention is all you need, 2017
+
+
+### Hierarchical Transformers
+
+The original Transformer proposal processes ("transforms") sequences of tokens, across possibly many layers. Crucially, the number of tokens is unchanged across the depth of the model, and this proves to be really efficient in many domains.
+
+It seems that some domains could however benefit from an architecture more typical of CNNs, where there is a tradeoff across the depth of the model between the spatial extent (i.e. the number of tokens) and their expressiveness (i.e. the model or embedding dimension). These architectures are handled in xformers through the "patch_embedding" element, which translates the sequence of tokens from one layer to the next.
+
+A small helper is provided to make it easier to generate matching configurations, as follows. This example presents a truncated version of a small [Metaformer](https://arxiv.org/abs/2111.11418v1).
+
+```python
+    from xformers.factory import xFormer, xFormerConfig
+    from xformers.helpers.hierarchical_configs import (
+        BasicLayerConfig,
+        get_hierarchical_configuration,
+    )
+
+
+    base_hierarchical_configs = [
+        BasicLayerConfig(
+            embedding=64,  # the dimensions just have to match along the layers
+            attention_mechanism="scaled_dot_product",  # anything you like
+            patch_size=7,
+            stride=4,
+            padding=2,
+            seq_len=image_size * image_size // 16,
+        ),
+        BasicLayerConfig(
+            embedding=128,
+            attention_mechanism="scaled_dot_product",
+            patch_size=3,
+            stride=2,
+            padding=1,
+            seq_len=image_size * image_size // 64,
+        ),
+        BasicLayerConfig(
+            embedding=320,
+            attention_mechanism="scaled_dot_product",
+            patch_size=3,
+            stride=2,
+            padding=1,
+            seq_len=image_size * image_size // 256,
+        ),
+    ]
+
+    # Fill in the gaps in the config
+    xformer_config = get_hierarchical_configuration(
+        base_hierarchical_configs,
+        layernorm_style="pre",
+        use_rotary_embeddings=False,
+        mlp_multiplier=4,
+        dim_head=32,
+    )
+    config = xFormerConfig(xformer_config)
+```
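For completeness, a minimal sketch of how the `config` built above can be instantiated and exercised, mirroring `tests/test_hierarchical_transformer.py` further down in this diff; `image_size` is assumed to be 32 here, as in the CIFAR examples:

```python
import torch

from xformers.factory import xFormer

image_size = 32  # assumed; any size consistent with the seq_len values above works

# `config` is the xFormerConfig produced by the snippet above
model = xFormer.from_config(config)

# NCHW images go in, a (batch, sequence, embedding) tensor comes out of the last stage
dummy = torch.rand((2, 3, image_size, image_size))
out = model(dummy)
print(out.shape)
```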
diff --git a/README.md b/README.md
index 654b2fc63d..f884d55e77 100644
--- a/README.md
+++ b/README.md
@@ -190,6 +190,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
    2. transformer block benchmark
    3. [LRA](xformers/benchmarks/LRA/README.md), with SLURM support
 4. Programmatic and sweep-friendly layer and model construction
+   1. Compatible with hierarchical Transformers, like Swin or Metaformer
 5. Hackable
    1. Not using monolithic CUDA kernels, composable building blocks
    2. Using [Triton](https://triton-lang.org/) for some optimized parts, explicit, pythonic and user-accessible
diff --git a/docs/assets/metaformer.png b/docs/assets/metaformer.png
new file mode 100644
index 0000000000..692e80935b
Binary files /dev/null and b/docs/assets/metaformer.png differ
diff --git a/docs/source/tutorials/hierarchical.rst b/docs/source/tutorials/hierarchical.rst
new file mode 100644
index 0000000000..b56d182875
--- /dev/null
+++ b/docs/source/tutorials/hierarchical.rst
@@ -0,0 +1,56 @@
+Hierarchical Transformers
+=========================
+
+The original Transformer proposal processes ("transforms") sequences of tokens, across possibly many layers. Crucially, the number of tokens is unchanged across the depth of the model, and this proves to be really efficient in many domains.
+
+It seems that some domains could however benefit from an architecture more typical of CNNs, where there is a tradeoff across the depth of the model between the spatial extent (i.e. the number of tokens) and their expressiveness (i.e. the model or embedding dimension). These architectures are handled in xformers through the "patch_embedding" element, which translates the sequence of tokens from one layer to the next.
+
+A small helper is provided to make it easier to generate matching configurations, as follows. This example presents a truncated version of a small Metaformer_.
+
+.. _Metaformer: https://arxiv.org/abs/2111.11418v1
+
+.. code-block:: python
+
+    from xformers.factory import xFormer, xFormerConfig
+    from xformers.helpers.hierarchical_configs import (
+        BasicLayerConfig,
+        get_hierarchical_configuration,
+    )
+
+
+    base_hierarchical_configs = [
+        BasicLayerConfig(
+            embedding=64,  # the dimensions just have to match along the layers
+            attention_mechanism="scaled_dot_product",  # anything you like
+            patch_size=7,
+            stride=4,
+            padding=2,
+            seq_len=image_size * image_size // 16,
+        ),
+        BasicLayerConfig(
+            embedding=128,
+            attention_mechanism="scaled_dot_product",
+            patch_size=3,
+            stride=2,
+            padding=1,
+            seq_len=image_size * image_size // 64,
+        ),
+        BasicLayerConfig(
+            embedding=320,
+            attention_mechanism="scaled_dot_product",
+            patch_size=3,
+            stride=2,
+            padding=1,
+            seq_len=image_size * image_size // 256,
+        ),
+    ]
+
+    # Fill in the gaps in the config
+    xformer_config = get_hierarchical_configuration(
+        base_hierarchical_configs,
+        layernorm_style="pre",
+        use_rotary_embeddings=False,
+        mlp_multiplier=4,
+        dim_head=32,
+    )
+    config = xFormerConfig(xformer_config)
diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst
index fc59744f3d..37af922a98 100644
--- a/docs/source/tutorials/index.rst
+++ b/docs/source/tutorials/index.rst
@@ -11,3 +11,4 @@ Tutorials
    pytorch_encoder
    reversible
    triton
+   hierarchical
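As a sanity check on the `seq_len` values used in these configurations: each stage's sequence length follows the standard convolution output-size arithmetic, assuming the patch embedding behaves as the strided convolution described by `patch_size`/`stride`/`padding`. A small illustrative sketch, not part of this diff:

```python
def stage_seq_len(side: int, patch_size: int, stride: int, padding: int) -> int:
    # standard convolution output size, squared because the token grid stays 2D
    out_side = (side + 2 * padding - patch_size) // stride + 1
    return out_side * out_side

image_size = 32  # CIFAR-sized inputs, as in the examples

# stage 1 consumes the 32x32 image, stages 2 and 3 consume the 8x8 and 4x4 token grids
assert stage_seq_len(32, patch_size=7, stride=4, padding=2) == image_size * image_size // 16  # 64 tokens
assert stage_seq_len(8, patch_size=3, stride=2, padding=1) == image_size * image_size // 64   # 16 tokens
assert stage_seq_len(4, patch_size=3, stride=2, padding=1) == image_size * image_size // 256  # 4 tokens
```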
diff --git a/examples/README.md b/examples/README.md
index 05a134081d..4867d8d49a 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -25,3 +25,10 @@ If your current machine does not expose enough RAM and the example reports an `O
 This is meant to be an easy introduction to using xformers in practice, mirroring closely [this Pytorch Lightning](https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html) tutorial. The default settings are close to this tutorial, which trains an 11M-parameter ResNet on the CIFAR dataset, whereas we train a 10.6M-parameter ViT on the same dataset. The ViT configuration is not optimal for CIFAR, since the pictures are very small to begin with and information is probably lost given the patches. Nevertheless you should be able to reach about 80% accuracy within about an hour on a single GPU.
 
 ![Example curves](../docs/assets/microViT.png)
+
+
+### MicroMetaformer
+
+This is very close to the MicroViT example above, but this time it illustrates the use of a hierarchical Transformer ([Metaformer](https://arxiv.org/pdf/2111.11418.pdf)), built through a helper function which generates the required configuration given the pooling parameters. The suggested configuration is about 6.6M parameters (half of a ResNet18) and trains to about 86% top-1 accuracy on CIFAR-10 within minutes.
+
+![Example curves](../docs/assets/metaformer.png)
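To double check the parameter count quoted above, a quick hedged sketch, assuming the repository root is on `PYTHONPATH` so that the example module added just below is importable (and that `pytorch_lightning` and `pl_bolts` are installed):

```python
from examples.cifarMetaformer import MetaVisionTransformer

# `steps` is only stored at construction time and assumed to matter solely for the LR schedule
model = MetaVisionTransformer(steps=1)
num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e6:.1f}M parameters")  # expected to land in the ~6-7M range
```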
diff --git a/examples/cifarMetaformer.py b/examples/cifarMetaformer.py
new file mode 100644
index 0000000000..292d00ab12
--- /dev/null
+++ b/examples/cifarMetaformer.py
@@ -0,0 +1,184 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import pytorch_lightning as pl
+import torch
+from pl_bolts.datamodules import CIFAR10DataModule
+from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
+from torch import nn
+from torchmetrics import Accuracy
+from torchvision import transforms
+
+from examples.microViT import Classifier, VisionTransformer
+from xformers.factory import xFormer, xFormerConfig
+from xformers.helpers.hierarchical_configs import (
+    BasicLayerConfig,
+    get_hierarchical_configuration,
+)
+
+
+class MetaVisionTransformer(VisionTransformer):
+    def __init__(
+        self,
+        steps,
+        learning_rate=5e-3,
+        betas=(0.9, 0.99),
+        weight_decay=0.03,
+        image_size=32,
+        num_classes=10,
+        dim=384,
+        attention="scaled_dot_product",
+        layer_norm_style="pre",
+        use_rotary_embeddings=True,
+        linear_warmup_ratio=0.1,
+        classifier=Classifier.GAP,
+    ):
+
+        super(VisionTransformer, self).__init__()
+
+        # all the inputs are saved under self.hparams (hyperparams)
+        self.save_hyperparameters()
+
+        # Generate the skeleton of our hierarchical Transformer
+
+        # This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
+        # Any other related config would work,
+        # and the attention mechanisms don't have to be the same across layers
+        base_hierarchical_configs = [
+            BasicLayerConfig(
+                embedding=64,
+                attention_mechanism=attention,
+                patch_size=3,
+                stride=2,
+                padding=1,
+                seq_len=image_size * image_size // 4,
+            ),
+            BasicLayerConfig(
+                embedding=128,
+                attention_mechanism=attention,
+                patch_size=3,
+                stride=2,
+                padding=1,
+                seq_len=image_size * image_size // 16,
+            ),
+            BasicLayerConfig(
+                embedding=320,
+                attention_mechanism=attention,
+                patch_size=3,
+                stride=2,
+                padding=1,
+                seq_len=image_size * image_size // 64,
+            ),
+            BasicLayerConfig(
+                embedding=512,
+                attention_mechanism=attention,
+                patch_size=3,
+                stride=2,
+                padding=1,
+                seq_len=image_size * image_size // 256,
+            ),
+        ]
+
+        # Fill in the gaps in the config
+        xformer_config = get_hierarchical_configuration(
+            base_hierarchical_configs,
+            layernorm_style=layer_norm_style,
+            use_rotary_embeddings=use_rotary_embeddings,
+            mlp_multiplier=4,
+            dim_head=32,
+        )
+
+        # Now instantiate the metaformer trunk
+        config = xFormerConfig(xformer_config)
+        print(config)
+        self.trunk = xFormer.from_config(config)
+        print(self.trunk)
+
+        # The classifier head
+        dim = base_hierarchical_configs[-1].embedding
+        self.ln = nn.LayerNorm(dim)
+        self.head = nn.Linear(dim, num_classes)
+        self.criterion = torch.nn.CrossEntropyLoss()
+        self.val_accuracy = Accuracy()
+
+    def forward(self, x):
+        x = self.trunk(x)
+        x = self.ln(x)
+
+        if self.hparams.classifier == Classifier.TOKEN:
+            x = x[:, 0]  # only consider the token, we're classifying anyway
+        elif self.hparams.classifier == Classifier.GAP:
+            x = x.mean(dim=1)  # mean over sequence len
+
+        x = self.head(x)
+        return x
+
+
+if __name__ == "__main__":
+    pl.seed_everything(42)
+
+    # Adjust batch depending on the available memory on your machine.
+    # You can also use reversible layers to save memory
+    REF_BATCH = 512
+    BATCH = 512  # lower if not enough GPU memory
+
+    MAX_EPOCHS = 50
+    NUM_WORKERS = 4
+    GPUS = 1
+
+    train_transforms = transforms.Compose(
+        [
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            cifar10_normalization(),
+        ]
+    )
+
+    test_transforms = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            cifar10_normalization(),
+        ]
+    )
+
+    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
+    # See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
+    # for a full tutorial
+    dm = CIFAR10DataModule(
+        data_dir="data",
+        batch_size=BATCH,
+        num_workers=NUM_WORKERS,
+        pin_memory=True,
+    )
+    dm.train_transforms = train_transforms
+    dm.test_transforms = test_transforms
+    dm.val_transforms = test_transforms
+
+    image_size = dm.size(-1)  # 32 for CIFAR
+    num_classes = dm.num_classes  # 10 for CIFAR
+
+    # compute total number of steps
+    batch_size = BATCH * GPUS
+    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
+    lm = MetaVisionTransformer(
+        steps=steps,
+        image_size=image_size,
+        num_classes=num_classes,
+        attention="scaled_dot_product",
+        layer_norm_style="pre",
+        use_rotary_embeddings=True,
+    )
+    trainer = pl.Trainer(
+        gpus=GPUS,
+        max_epochs=MAX_EPOCHS,
+        precision=16,
+        accumulate_grad_batches=REF_BATCH // BATCH,
+    )
+    trainer.fit(lm, dm)
+
+    # check the training
+    trainer.test(lm, datamodule=dm)
diff --git a/examples/microViT.py b/examples/microViT.py
index 41515c59ad..8e12f923c2 100644
--- a/examples/microViT.py
+++ b/examples/microViT.py
@@ -171,7 +171,7 @@ def training_step(self, batch, _):
                 "train_loss": loss.mean(),
                 "learning_rate": self.lr_schedulers().get_last_lr()[0],
             },
-            step=trainer.global_step,
+            step=self.global_step,
         )
 
         return loss
diff --git a/tests/test_attentions.py b/tests/test_attentions.py
index ef4bec2968..4d29002555 100644
--- a/tests/test_attentions.py
+++ b/tests/test_attentions.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 from typing import Tuple
 
 import pytest
@@ -96,6 +97,9 @@ def test_order_invariance(
     device: torch.device,
 ):
+    if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "poolling":
+        pytest.skip(f"{attention_name} requires a square sequence length")
+
     torch.manual_seed(42)
 
     multi_head = _get_multihead(
@@ -282,6 +286,9 @@ def test_broadcast_batch_dimension(
     device: torch.device,
     batch_sizes: Tuple[int, int, int],
 ):
+    if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "poolling":
+        pytest.skip(f"{attention_name} requires a square sequence length")
+
     Q_BATCH, K_BATCH, V_BATCH = batch_sizes
 
     multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)
diff --git a/tests/test_hierarchical_transformer.py b/tests/test_hierarchical_transformer.py
new file mode 100644
index 0000000000..535aaef442
--- /dev/null
+++ b/tests/test_hierarchical_transformer.py
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+
+from xformers.factory import xFormer, xFormerConfig
+from xformers.helpers.hierarchical_configs import (
+    BasicLayerConfig,
+    get_hierarchical_configuration,
+)
+
+BATCH = 20
+SEQ = 512
+MODEL = 384
+
+
+def test_hierarchical_transformer():
+    image_size = 32
+
+    base_hierarchical_configs = [
+        BasicLayerConfig(
+            embedding=64,
+            attention_mechanism="scaled_dot_product",
+            patch_size=7,
+            stride=4,
+            padding=2,
+            seq_len=image_size * image_size // 16,
+        ),
+        BasicLayerConfig(
+            embedding=128,
+            attention_mechanism="scaled_dot_product",
+            patch_size=3,
+            stride=2,
+            padding=1,
+            seq_len=image_size * image_size // 64,
+        ),
+        BasicLayerConfig(
+            embedding=320,
+            attention_mechanism="scaled_dot_product",
+            patch_size=3,
+            stride=2,
+            padding=1,
+            seq_len=image_size * image_size // 256,
+        ),
+    ]
+
+    # Fill in the gaps in the config
+    xformer_config = get_hierarchical_configuration(
+        base_hierarchical_configs,
+        layernorm_style="pre",
+        use_rotary_embeddings=False,
+        mlp_multiplier=4,
+        dim_head=32,
+    )
+    config = xFormerConfig(xformer_config)
+    hierarchical_xformer = xFormer.from_config(config)
+
+    # Forward some dummy data
+    dummy = torch.rand((2, 3, image_size, image_size))
+    _ = hierarchical_xformer(dummy)
diff --git a/xformers/components/multi_head_dispatch.py b/xformers/components/multi_head_dispatch.py
index f904f1c876..f3bde6f67f 100644
--- a/xformers/components/multi_head_dispatch.py
+++ b/xformers/components/multi_head_dispatch.py
@@ -19,10 +19,10 @@
 @dataclass
 class MultiHeadDispatchConfig:
     dim_model: int
-    residual_dropout: float
     num_heads: int
     attention: Attention
     bias: bool
+    residual_dropout: float
     dim_key: Optional[int]
     dim_value: Optional[int]
     in_proj_container: Optional[InProjContainer]
@@ -55,10 +55,10 @@ class MultiHeadDispatch(nn.Module):
     def __init__(
         self,
        dim_model: int,
-        residual_dropout: float,
         num_heads: int,
         attention: Attention,
         bias: bool = True,
+        residual_dropout: float = 0.0,
         dim_key: Optional[int] = None,
         dim_value: Optional[int] = None,
         in_proj_container: Optional[InProjContainer] = None,
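Since `residual_dropout` now comes with a default, `MultiHeadDispatch` can be constructed without specifying it. A small hedged sketch of that usage, assuming the usual component exports (`MultiHeadDispatch`, `ScaledDotProduct`) which are not part of this diff:

```python
import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct

# residual_dropout now defaults to 0.0, bias still defaults to True
mha = MultiHeadDispatch(
    dim_model=384,
    num_heads=6,
    attention=ScaledDotProduct(),
)

x = torch.rand(2, 16, 384)  # (batch, sequence, embedding)
out = mha(x, x, x)  # self-attention: query, key and value are the same tensor
assert out.shape == x.shape
```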
diff --git a/xformers/factory/block_configs.py b/xformers/factory/block_configs.py
index 30ab0c6693..55abf6f8bc 100644
--- a/xformers/factory/block_configs.py
+++ b/xformers/factory/block_configs.py
@@ -130,7 +130,7 @@ def __init__(
         patch_embedding_config: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
-        # Convenience, fill in duplicated field
+        # Convenience, fill in duplicated fields
         try:
             if "dim_model" not in multi_head_config.keys():
                 multi_head_config["dim_model"] = dim_model
@@ -144,6 +144,12 @@ def __init__(
             ):
                 position_encoding_config["dim_model"] = dim_model
 
+            if (
+                patch_embedding_config is not None
+                and "out_channels" not in patch_embedding_config.keys()
+            ):
+                patch_embedding_config["out_channels"] = dim_model
+
         except AttributeError:
             # A config instance was passed in, this is fine
             pass
diff --git a/xformers/helpers/hierarchical_configs.py b/xformers/helpers/hierarchical_configs.py
new file mode 100644
index 0000000000..2de8e2f424
--- /dev/null
+++ b/xformers/helpers/hierarchical_configs.py
@@ -0,0 +1,104 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import copy
+from dataclasses import dataclass
+from typing import Any, Dict, List
+
+from xformers.components.residual import LayerNormStyle
+
+
+@dataclass
+class BasicLayerConfig:
+    embedding: int
+    attention_mechanism: str
+    patch_size: int
+    stride: int
+    padding: int
+    seq_len: int
+
+
+def get_hierarchical_configuration(
+    layer_basic_configs: List[BasicLayerConfig],
+    layernorm_style: LayerNormStyle = LayerNormStyle.Pre,
+    use_rotary_embeddings: bool = True,
+    mlp_multiplier: int = 4,
+    dim_head=32,
+):
+    """
+    A small helper to generate hierarchical xformers configurations,
+    which correspond for instance to poolformer or swin architectures.
+
+    Contrary to more "classical" Transformer architectures, which conserve the sequence/context
+    length across layers, hierarchical Transformers trade the sequence length for the embedding dimension
+    """
+
+    base_config: Dict[str, Any] = {
+        "block_type": "encoder",
+        "dim_model": 0,
+        "use_triton": False,
+        "layer_norm_style": str(layernorm_style),
+        "multi_head_config": {
+            "num_heads": 0,
+            "use_rotary_embeddings": use_rotary_embeddings,
+            "attention": {
+                "name": "TBD",
+            },
+        },
+        "feedforward_config": {
+            "name": "MLP",
+            "activation": "gelu",
+            "hidden_layer_multiplier": mlp_multiplier,
+            "dropout": 0.0,
+        },
+        "position_encoding_config": {
+            "name": "learnable",
+            "seq_len": 0,
+            "add_class_token": False,
+        },
+        "patch_embedding_config": {
+            "in_channels": 3,
+            "kernel_size": 0,
+            "stride": 0,
+            "padding": 0,
+        },
+    }
+
+    xformers_config = []
+    in_channels = 3
+
+    for layer_basic_config in layer_basic_configs:
+        lc = copy.deepcopy(base_config)
+
+        # Fill in the changing model dimensions
+        lc["dim_model"] = layer_basic_config.embedding
+
+        # Update the patches
+        lc["patch_embedding_config"] = {
+            "in_channels": in_channels,
+            "kernel_size": layer_basic_config.patch_size,
+            "stride": layer_basic_config.stride,
+            "padding": layer_basic_config.padding,
+        }
+
+        # Update the number of channels for the next layer
+        in_channels = lc["dim_model"] * 1
+
+        lc["position_encoding_config"]["seq_len"] = layer_basic_config.seq_len
+
+        # Fill in the number of heads
+        lc["multi_head_config"]["num_heads"] = layer_basic_config.embedding // dim_head
+        assert layer_basic_config.embedding % dim_head == 0
+
+        # Fill in the attention mechanism
+        lc["multi_head_config"]["attention"][
+            "name"
+        ] = layer_basic_config.attention_mechanism
+
+        print(lc)
+        xformers_config.append(lc)
+
+    return xformers_config
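To see what the helper actually fills in, here is a small sketch that inspects one generated stage; the assertions follow directly from the `base_config` above (assuming the helper as added in this diff):

```python
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

# A single stage, just to inspect the generated dictionary
configs = get_hierarchical_configuration(
    [
        BasicLayerConfig(
            embedding=64,
            attention_mechanism="scaled_dot_product",
            patch_size=7,
            stride=4,
            padding=2,
            seq_len=8 * 8,
        )
    ],
    use_rotary_embeddings=False,
)

layer = configs[0]
assert layer["dim_model"] == 64
assert layer["multi_head_config"]["num_heads"] == 2  # embedding // dim_head, 64 // 32
assert layer["multi_head_config"]["attention"]["name"] == "scaled_dot_product"
assert layer["patch_embedding_config"]["kernel_size"] == 7
assert layer["position_encoding_config"]["seq_len"] == 64
```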