Commit 9820481 (1 parent: fa11d81)

Running, could be nicer with some parameter auto-fill

Showing 8 changed files with 317 additions and 4 deletions.
@@ -0,0 +1,196 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from enum import Enum

import pytorch_lightning as pl
import torch
from microViT import VisionTransformer
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torch import nn
from torchmetrics import Accuracy
from torchvision import transforms

from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)


class Classifier(str, Enum):
    GAP = "gap"
    TOKEN = "token"


class MetaVisionTransformer(VisionTransformer):
    def __init__(
        self,
        steps,
        learning_rate=5e-4,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        patch_size=2,
        dim=384,
        n_layer=6,
        n_head=6,
        resid_pdrop=0.0,
        attn_pdrop=0.0,
        mlp_pdrop=0.0,
        attention="scaled_dot_product",
        layer_norm_style="pre",
        hidden_layer_multiplier=4,
        use_rotary_embeddings=True,
        linear_warmup_ratio=0.1,
        classifier: Classifier = Classifier.TOKEN,
    ):

        super(VisionTransformer, self).__init__()
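        # Note: super(VisionTransformer, self).__init__() deliberately skips
        # VisionTransformer.__init__ and runs the next class in the MRO instead
        # (presumably the LightningModule base), since the trunk is rebuilt below.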

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        assert image_size % patch_size == 0

        # Generate the skeleton of our hierarchical Transformer
        base_hierarchical_configs = [
            BasicLayerConfig(
                embedding=64,
                attention_mechanism=attention,
                patch_size=7,
                stride=4,
                padding=2,
                seq_len=image_size * image_size // 16,
            ),
            BasicLayerConfig(
                embedding=128,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 64,
            ),
            BasicLayerConfig(
                embedding=320,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 256,
            ),
            # BasicLayerConfig(
            #     embedding=512,
            #     attention_mechanism=attention,
            #     patch_size=3,
            #     stride=2,
            #     padding=1,
            #     seq_len=image_size * image_size // 1024,
            # ),
        ]
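        # For a 32x32 input, the three stages above work on sequence lengths of
        # 64, 16 and 4 tokens respectively (the strided patch embeddings divide the
        # spatial resolution by 4, then 2, then 2), while the embedding width grows
        # from 64 to 128 to 320.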

        # Fill in the gaps in the config
        xformer_config = get_hierarchical_configuration(
            base_hierarchical_configs,
            layernorm_style=layer_norm_style,
            use_rotary_embeddings=use_rotary_embeddings,
            mlp_multiplier=4,
            dim_head=32,
        )

        # Now instantiate the metaformer trunk
        config = xFormerConfig(xformer_config)
        print(config)
        self.trunk = xFormer.from_config(config)
        print(self.trunk)

        # The classifier head
        dim = base_hierarchical_configs[-1].embedding
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.val_accuracy = Accuracy()

    def forward(self, x):
        x = self.trunk(x)
        x = self.ln(x)

        if self.hparams.classifier == Classifier.TOKEN:
            x = x[:, 0]  # only consider the token, we're classifying anyway
        elif self.hparams.classifier == Classifier.GAP:
            x = x.mean(dim=1)  # mean over sequence len

        x = self.head(x)
        return x


if __name__ == "__main__":
    pl.seed_everything(42)

    # Adjust batch depending on the available memory on your machine.
    # You can also use reversible layers to save memory
    REF_BATCH = 512
    BATCH = 256
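    # With these defaults the Trainer below accumulates gradients over
    # REF_BATCH // BATCH = 2 batches, so the effective batch size stays at 512
    # while only 256 samples are resident in memory at a time.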

    MAX_EPOCHS = 50
    NUM_WORKERS = 4
    GPUS = 1

    train_transforms = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )

    test_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )

    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
    # See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    # for a full tutorial
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    dm.train_transforms = train_transforms
    dm.test_transforms = test_transforms
    dm.val_transforms = test_transforms

    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR

    # compute total number of steps
    batch_size = BATCH * GPUS
    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
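    # i.e. the number of optimizer steps over the whole run, since one optimizer
    # step consumes REF_BATCH samples once gradient accumulation is factored in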
    lm = MetaVisionTransformer(
        steps=steps,
        image_size=image_size,
        num_classes=num_classes,
        attention="scaled_dot_product",
        layer_norm_style="pre",
        use_rotary_embeddings=True,
    )
    trainer = pl.Trainer(
        gpus=GPUS,
        max_epochs=MAX_EPOCHS,
        precision=16,
        accumulate_grad_batches=REF_BATCH // BATCH,
    )
    trainer.fit(lm, dm)

    # check the training
    trainer.test(lm, datamodule=dm)
@@ -0,0 +1,104 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import copy
from dataclasses import dataclass
from typing import Any, Dict, List

from xformers.components.residual import LayerNormStyle


@dataclass
class BasicLayerConfig:
    embedding: int
    attention_mechanism: str
    patch_size: int
    stride: int
    padding: int
    seq_len: int


def get_hierarchical_configuration(
    layer_basic_configs: List[BasicLayerConfig],
    layernorm_style: LayerNormStyle = LayerNormStyle.Pre,
    use_rotary_embeddings: bool = True,
    mlp_multiplier: int = 4,
    dim_head=32,
):
""" | ||
A small helper to generate hierarchical xformers configurations, | ||
which correspond for instance to poolformer or swin architectures. | ||
Contrary to more "classical" Transformer architectures, which conserve the sequence/context | ||
length across layers, hierarchical Transformers trade the sequence length for the embedding dimension | ||
""" | ||

    base_config: Dict[str, Any] = {
        "block_type": "encoder",
        "dim_model": 0,
        "use_triton": False,
        "layer_norm_style": str(layernorm_style),
        "multi_head_config": {
            "num_heads": 0,
            "use_rotary_embeddings": use_rotary_embeddings,
            "attention": {
                "name": "TBD",
            },
        },
        "feedforward_config": {
            "name": "MLP",
            "activation": "gelu",
            "hidden_layer_multiplier": mlp_multiplier,
            "dropout": 0.0,
        },
        "position_encoding_config": {
            "name": "learnable",
            "seq_len": 0,
            "add_class_token": False,
        },
        "patch_embedding_config": {
            "in_channels": 3,
            "kernel_size": 0,
            "stride": 0,
            "padding": 0,
        },
    }
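    # The zeroed fields and the "TBD" attention name above are placeholders:
    # they are overwritten per layer in the loop below from each BasicLayerConfig.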

    xformers_config = []
    in_channels = 3

    for layer_basic_config in layer_basic_configs:
        lc = copy.deepcopy(base_config)

        # Fill in the changing model dimensions
        lc["dim_model"] = layer_basic_config.embedding

        # Update the patches
        lc["patch_embedding_config"] = {
            "in_channels": in_channels,
            "kernel_size": layer_basic_config.patch_size,
            "stride": layer_basic_config.stride,
            "padding": layer_basic_config.padding,
        }

        # Update the number of channels for the next layer
        in_channels = lc["dim_model"] * 1

        lc["position_encoding_config"]["seq_len"] = layer_basic_config.seq_len

        # Fill in the number of heads
        lc["multi_head_config"]["num_heads"] = layer_basic_config.embedding // dim_head
        assert layer_basic_config.embedding % dim_head == 0

        # Fill in the attention mechanism
        lc["multi_head_config"]["attention"][
            "name"
        ] = layer_basic_config.attention_mechanism

        print(lc)
        xformers_config.append(lc)

    return xformers_config
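
As a quick usage reference, the sketch below shows how the helper is meant to be driven, mirroring the CIFAR training example above: describe each stage with a BasicLayerConfig, expand it with get_hierarchical_configuration, then feed the result to the xFormer factory. The stage sizes here are illustrative only, not a recommended configuration.

# Minimal usage sketch (illustrative values; see the CIFAR example above for a full setup)
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

image_size = 32
stages = [
    # 32x32 image, stride-4 patch embedding -> 8x8 = 64 tokens, width 64
    BasicLayerConfig(
        embedding=64,
        attention_mechanism="scaled_dot_product",
        patch_size=7,
        stride=4,
        padding=2,
        seq_len=image_size * image_size // 16,
    ),
    # downsample by 2 -> 4x4 = 16 tokens, width 128
    BasicLayerConfig(
        embedding=128,
        attention_mechanism="scaled_dot_product",
        patch_size=3,
        stride=2,
        padding=1,
        seq_len=image_size * image_size // 64,
    ),
]

layer_configs = get_hierarchical_configuration(
    stages,
    layernorm_style="pre",
    use_rotary_embeddings=True,
    mlp_multiplier=4,
    dim_head=32,
)
trunk = xFormer.from_config(xFormerConfig(layer_configs))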