
Commit

Merge pull request #13 from AnswerDotAI/rmsnorm
Add RMSNorm and Swish
bclavie authored May 15, 2024
2 parents 60a0d2b + 9531b71 commit b240ad1
Showing 11 changed files with 74 additions and 33 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -163,4 +163,7 @@ cython_debug/
.vscode

# weights and biases
wandb/
wandb/

# OS X
.DS_Store
4 changes: 0 additions & 4 deletions glue.py
@@ -52,12 +52,8 @@ def build_algorithm(name, kwargs):
        return algorithms.GradientClipping(**kwargs)
    elif name == "alibi":
        return algorithms.Alibi(**kwargs)
    elif name == "fused_layernorm":
        return algorithms.FusedLayerNorm(**kwargs)
    elif name == "gated_linear_units":
        return algorithms.GatedLinearUnits(**kwargs)
    elif name == "low_precision_layernorm":
        return algorithms.LowPrecisionLayerNorm(**kwargs)
    else:
        raise ValueError(f"Not sure how to build algorithm: {name}")

4 changes: 0 additions & 4 deletions main.py
@@ -69,12 +69,8 @@ def build_algorithm(name, kwargs):
        return algorithms.GradientClipping(**kwargs)
    elif name == "alibi":
        return algorithms.Alibi(**kwargs)
    elif name == "fused_layernorm":
        return algorithms.FusedLayerNorm(**kwargs)
    elif name == "gated_linear_units":
        return algorithms.GatedLinearUnits(**kwargs)
    elif name == "low_precision_layernorm":
        return algorithms.LowPrecisionLayerNorm(**kwargs)
    else:
        raise ValueError(f"Not sure how to build algorithm: {name}")

4 changes: 0 additions & 4 deletions sequence_classification.py
@@ -74,12 +74,8 @@ def build_algorithm(name, kwargs):
        return algorithms.GradientClipping(**kwargs)
    elif name == "alibi":
        return algorithms.Alibi(**kwargs)
    elif name == "fused_layernorm":
        return algorithms.FusedLayerNorm(**kwargs)
    elif name == "gated_linear_units":
        return algorithms.GatedLinearUnits(**kwargs)
    elif name == "low_precision_layernorm":
        return algorithms.LowPrecisionLayerNorm(**kwargs)
    else:
        raise ValueError(f"Not sure how to build algorithm: {name}")

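The same four lines are removed from the build_algorithm helpers in glue.py, main.py, and sequence_classification.py, and the matching low_precision_layernorm entries are commented out of the YAML configs further down. The likely reason, sketched below: these Composer algorithms perform module surgery on torch.nn.LayerNorm instances, and after this commit every norm in the model is an RMSNorm, so there is nothing left for them to replace. The helper below is illustrative only, not code from the repository.

# Illustrative sketch (not from the repository): Composer's FusedLayerNorm and
# LowPrecisionLayerNorm algorithms swap out torch.nn.LayerNorm submodules at
# init time. Once every norm in the model is an RMSNorm, that surgery has no
# targets, so the corresponding build_algorithm branches can be dropped.
import torch.nn as nn

def count_layernorm_targets(model: nn.Module) -> int:
    """Count the nn.LayerNorm submodules that LayerNorm surgery would replace."""
    return sum(isinstance(m, nn.LayerNorm) for m in model.modules())

# For the model after this commit, count_layernorm_targets(model) returns 0.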
69 changes: 62 additions & 7 deletions src/bert_layers.py
@@ -1,9 +1,18 @@
# Copyright 2024 **AUTHORS_TODO**
# License: Apache-2.0

# RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation)
# License: LLAMA 2 COMMUNITY LICENSE AGREEMENT

# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2023 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2023 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, Tri Dao.
@@ -76,6 +85,54 @@
logger = logging.getLogger(__name__)


class RMSNorm(torch.nn.Module):
    """Llama2 RMSNorm implementation"""

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
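A quick illustration of what the new layer computes (standalone snippet, not part of bert_layers.py): each hidden vector is divided by its root-mean-square, with the statistics computed in float32 and the result cast back to the input dtype, then scaled by a learned per-dimension gain initialised to ones.

import torch

norm = RMSNorm(dim=768, eps=1e-6)          # 768 matches hidden_size in the configs below
x = torch.randn(2, 128, 768)               # (batch, seq_len, hidden_size)
y = norm(x)
print(y.pow(2).mean(-1).sqrt().mean())     # ~1.0: unit RMS with the default weight of ones
print(y.shape)                             # torch.Size([2, 128, 768])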


class BertEmbeddings(nn.Module):
"""Construct the embeddings for words, ignoring position.
Expand All @@ -98,9 +155,7 @@ def __init__(self, config):
# ALiBi doesn't use position embeddings
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

# self.LayerNorm is not snake-cased to stick with TensorFlow model
# variable name and be able to load any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer(
"token_type_ids", torch.zeros(config.max_position_embeddings, dtype=torch.long), persistent=False
@@ -291,7 +346,7 @@ class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.LayerNorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
@@ -363,10 +418,10 @@ def __init__(self, config):
        super().__init__()
        self.config = config
        self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
        self.act = nn.GELU(approximate="none")
        self.act = nn.SiLU()
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.
@@ -605,7 +660,7 @@ def __init__(self, config):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm = RMSNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
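The "Swish" half of the commit is the switch from nn.GELU to nn.SiLU in the gated MLP above; combined with the existing fused gated_layers projection this gives a SwiGLU-style feed-forward block. The module's forward pass is elided in the hunk, so the sketch below only illustrates the general pattern, assuming the fused projection is split in half and the SiLU-gated half multiplies the other (names and sizes are illustrative, not taken from the file).

import torch
import torch.nn as nn

class SwiGLUSketch(nn.Module):
    """Illustrative SwiGLU-style MLP: fused up-projection, SiLU gate, down-projection."""

    def __init__(self, hidden_size: int = 768, intermediate_size: int = 3072):
        super().__init__()
        self.intermediate_size = intermediate_size
        self.gated_layers = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
        self.act = nn.SiLU()                               # Swish: x * sigmoid(x)
        self.wo = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated, non_gated = self.gated_layers(x).split(self.intermediate_size, dim=-1)
        return self.wo(self.act(gated) * non_gated)

The real module also defines dropout and an RMSNorm in its __init__; those steps are omitted from this sketch.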
4 changes: 2 additions & 2 deletions src/mosaic_bert.py
@@ -65,7 +65,7 @@ def create_mosaic_bert_mlm(
"attention_probs_dropout_prob": 0.0,
"classifier_dropout": null,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_act": "silu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
@@ -177,7 +177,7 @@ def create_mosaic_bert_classification(
"attention_probs_dropout_prob": 0.0,
"classifier_dropout": null,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_act": "silu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
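The documented default hidden_act changes from "gelu" to "silu" in both config docstrings; bert_layers.py resolves that string through Hugging Face's ACT2FN table (the transform_act_fn line in the last hunk above). A quick check of what the string maps to:

import torch
from transformers.activations import ACT2FN

act = ACT2FN["silu"]                           # SiLU / Swish: silu(x) = x * sigmoid(x)
print(act(torch.tensor([-1.0, 0.0, 1.0])))     # tensor([-0.2689,  0.0000,  0.7311])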
5 changes: 2 additions & 3 deletions yamls/finetuning/glue/mcloud_run.yaml
@@ -82,9 +82,8 @@ parameters:
alpha_f: 0.0

# Algorithms
algorithms:
low_precision_layernorm: {}

# algorithms:

# Task configuration
tasks:
mnli:
3 changes: 1 addition & 2 deletions yamls/finetuning/glue/mosaic-bert-base-uncased.yaml
@@ -46,8 +46,7 @@ scheduler:
alpha_f: 0.0

# Algorithms
algorithms:
low_precision_layernorm: {}
# algorithms:

# Task configuration
tasks:
3 changes: 1 addition & 2 deletions yamls/main/mcloud_run_a100_40gb.yaml
@@ -107,8 +107,7 @@ parameters:
eps: 1.0e-06
weight_decay: 1.0e-5 # Amount of weight decay regularization

algorithms:
low_precision_layernorm: {}
# algorithms:

max_duration: 286720000sp # Subsample the training data for ~275M samples
eval_interval: 2000ba
3 changes: 1 addition & 2 deletions yamls/main/mcloud_run_a100_80gb.yaml
@@ -107,8 +107,7 @@ parameters:
eps: 1.0e-06
weight_decay: 1.0e-5 # Amount of weight decay regularization

algorithms:
low_precision_layernorm: {}
# algorithms:

max_duration: 286720000sp # Subsample the training data for ~275M samples
eval_interval: 2000ba
3 changes: 1 addition & 2 deletions yamls/main/mosaic-bert-base-uncased.yaml
@@ -68,8 +68,7 @@ optimizer:
eps: 1.0e-06
weight_decay: 1.0e-5 # Amount of weight decay regularization

algorithms:
low_precision_layernorm: {}
# algorithms:

max_duration: 286720000sp # Subsample the training data for ~275M samples
eval_interval: 2000ba
