Merge pull request #21 from AnswerDotAI/ablations
chore: add optional ablation config object to disable Noam arch changes
bclavie authored May 16, 2024
2 parents b240ad1 + ce291d4 commit 0828615
Showing 6 changed files with 77 additions and 27 deletions.
1 change: 1 addition & 0 deletions main.py
@@ -144,6 +144,7 @@ def build_model(cfg: DictConfig):
             model_config=cfg.get("model_config", None),
             tokenizer_name=cfg.get("tokenizer_name", None),
             gradient_checkpointing=cfg.get("gradient_checkpointing", None),
+            ablations=cfg.get("ablations", {}),
         )
     else:
         raise ValueError(f"Not sure how to build model with name={cfg.name}")
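As a brief aside (not part of the diff), a minimal sketch of how the new `ablations` mapping is expected to flow from the run config into the model code; the config values below are hypothetical, and the `.get` defaults mirror the ones introduced in `BertForMaskedLM.__init__` further down:

```python
from omegaconf import OmegaConf

# Hypothetical run config: omitting `ablations` entirely behaves like an empty
# dict, so existing configs are unaffected.
cfg = OmegaConf.create(
    {
        "name": "mosaic_bert",
        "ablations": {"use_rmsnorm": False, "use_silu": False},  # revert to LayerNorm + GELU
    }
)

ablations = cfg.get("ablations", {})               # what build_model() forwards
use_rmsnorm = ablations.get("use_rmsnorm", True)   # False here
use_silu = ablations.get("use_silu", True)         # False here
```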
88 changes: 62 additions & 26 deletions src/bert_layers.py
@@ -149,13 +149,17 @@ class BertEmbeddings(nn.Module):
     This module ignores the `position_ids` input to the `forward` method.
     """

-    def __init__(self, config):
+    def __init__(self, config, use_rmsnorm: bool = True):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         # ALiBi doesn't use position embeddings
         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

-        self.LayerNorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = (
+            RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+            if use_rmsnorm
+            else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.register_buffer(
             "token_type_ids", torch.zeros(config.max_position_embeddings, dtype=torch.long), persistent=False
@@ -343,10 +347,14 @@ class BertSelfOutput(nn.Module):
     BERT modules.
     """

-    def __init__(self, config):
+    def __init__(self, config, use_rmsnorm: bool = True):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.LayerNorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = (
+            RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+            if use_rmsnorm
+            else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
@@ -359,10 +367,13 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
 class BertUnpadAttention(nn.Module):
     """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""

-    def __init__(self, config):
+    def __init__(self, config, use_rmsnorm: bool = True):
         super().__init__()
         self.self = BertUnpadSelfAttention(config)
-        self.output = BertSelfOutput(config)
+        self.output = BertSelfOutput(
+            config,
+            use_rmsnorm=use_rmsnorm,
+        )

     def forward(
         self,
@@ -414,14 +425,23 @@ class BertGatedLinearUnitMLP(nn.Module):
     parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
     """

-    def __init__(self, config):
+    def __init__(
+        self,
+        config,
+        use_rmsnorm: bool = True,
+        use_silu: bool = True,
+    ):
         super().__init__()
         self.config = config
         self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
-        self.act = nn.SiLU()
+        self.act = nn.SiLU() if use_silu else nn.GELU()
         self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layernorm = (
+            RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+            if use_rmsnorm
+            else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Compute new hidden states from current hidden states.
@@ -447,10 +467,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class BertLayer(nn.Module):
     """Composes the Mosaic BERT attention and FFN blocks into a single layer."""

-    def __init__(self, config):
+    def __init__(self, config, use_rmsnorm: bool = True, use_silu: bool = True):
         super(BertLayer, self).__init__()
-        self.attention = BertUnpadAttention(config)
-        self.mlp = BertGatedLinearUnitMLP(config)
+        self.attention = BertUnpadAttention(config, use_rmsnorm=use_rmsnorm)
+        self.mlp = BertGatedLinearUnitMLP(config, use_rmsnorm=use_rmsnorm, use_silu=use_silu)

     def forward(
         self,
@@ -494,9 +514,9 @@ class BertEncoder(nn.Module):
     at padded tokens, and pre-computes attention biases to implement ALiBi.
     """

-    def __init__(self, config):
+    def __init__(self, config, use_rmsnorm: bool = True, use_silu: bool = True):
         super().__init__()
-        layer = BertLayer(config)
+        layer = BertLayer(config, use_rmsnorm=use_rmsnorm, use_silu=use_silu)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

         self.num_attention_heads = config.num_attention_heads
@@ -653,14 +673,16 @@ def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:


 class BertPredictionHeadTransform(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, use_rmsnorm: bool = True):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         if isinstance(config.hidden_act, str):
             self.transform_act_fn = ACT2FN[config.hidden_act]
         else:
             self.transform_act_fn = config.hidden_act
-        self.LayerNorm = RMSNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = (
+            RMSNorm(config.hidden_size, eps=1e-12) if use_rmsnorm else nn.LayerNorm(config.hidden_size, eps=1e-12)
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
@@ -712,10 +734,16 @@ class BertModel(BertPreTrainedModel):
     ```
     """

-    def __init__(self, config, add_pooling_layer=True):
+    def __init__(
+        self,
+        config,
+        add_pooling_layer: bool = True,
+        use_rmsnorm: bool = True,
+        use_silu: bool = True,
+    ):
         super(BertModel, self).__init__(config)
-        self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.embeddings = BertEmbeddings(config, use_rmsnorm=use_rmsnorm)
+        self.encoder = BertEncoder(config, use_rmsnorm=use_rmsnorm, use_silu=use_silu)
         self.pooler = BertPooler(config) if add_pooling_layer else None
         self.post_init()
@@ -786,9 +814,9 @@ def forward(
 # Bert Heads
 ###################
 class BertLMPredictionHead(nn.Module):
-    def __init__(self, config, bert_model_embedding_weights):
+    def __init__(self, config, bert_model_embedding_weights, use_rmsnorm: bool = True):
         super().__init__()
-        self.transform = BertPredictionHeadTransform(config)
+        self.transform = BertPredictionHeadTransform(config, use_rmsnorm=use_rmsnorm)
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
         self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0))
@@ -801,9 +829,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


 class BertOnlyMLMHead(nn.Module):
-    def __init__(self, config, bert_model_embedding_weights):
+    def __init__(self, config, bert_model_embedding_weights, use_rmsnorm: bool = True):
         super().__init__()
-        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
+        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights, use_rmsnorm=use_rmsnorm)

     def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
         prediction_scores = self.predictions(sequence_output)
@@ -836,7 +864,7 @@ class BertLMHeadModel(BertPreTrainedModel):


 class BertForMaskedLM(BertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, ablations: dict = {}):
         super().__init__(config)

         if config.is_decoder:
@@ -845,8 +873,16 @@ def __init__(self, config):
                 "bi-directional self-attention."
             )

-        self.bert = BertModel(config, add_pooling_layer=False)
-        self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
+        use_rmsnorm = ablations.get("use_rmsnorm", True)
+        use_silu = ablations.get("use_silu", True)
+
+        self.bert = BertModel(
+            config,
+            add_pooling_layer=False,
+            use_rmsnorm=use_rmsnorm,
+            use_silu=use_silu,
+        )
+        self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight, use_rmsnorm=use_rmsnorm)

         # Initialize weights and apply final processing
         self.post_init()
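A hedged spot-check (not part of the diff) of what the two flags end up changing, assuming `mlm` is a `BertForMaskedLM` instance built with `ablations={"use_rmsnorm": False, "use_silu": False}`; the attribute names follow the classes defined above:

```python
import torch.nn as nn

# With use_rmsnorm=False every norm falls back to nn.LayerNorm ...
assert isinstance(mlm.bert.embeddings.LayerNorm, nn.LayerNorm)
# ... and with use_silu=False the gated MLP's activation is GELU instead of SiLU.
assert isinstance(mlm.bert.encoder.layer[0].mlp.act, nn.GELU)
```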
3 changes: 2 additions & 1 deletion src/mosaic_bert.py
@@ -31,6 +31,7 @@ def create_mosaic_bert_mlm(
     tokenizer_name: Optional[str] = None,
     gradient_checkpointing: Optional[bool] = False,
     pretrained_checkpoint: Optional[str] = None,
+    ablations: dict = {},
 ):
     """Mosaic BERT masked language model based on |:hugging_face:| Transformers.
@@ -107,7 +108,7 @@ def create_mosaic_bert_mlm(
             pretrained_checkpoint=pretrained_checkpoint, config=config
         )
     else:
-        model = bert_layers_module.BertForMaskedLM(config)
+        model = bert_layers_module.BertForMaskedLM(config, ablations=ablations)

     if gradient_checkpointing:
         model.gradient_checkpointing_enable() # type: ignore
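And a hedged end-to-end usage sketch of the new keyword. The import path and the leading `pretrained_model_name` argument (with its default) are assumed from the rest of the repo and are not part of this hunk; omitting `ablations`, or leaving the YAML block below commented out, keeps both flags at `True`, so existing runs are untouched:

```python
# Assumed import path; the factory is defined in src/mosaic_bert.py.
from src.mosaic_bert import create_mosaic_bert_mlm

# Build the MLM with both Noam-style changes disabled, i.e. plain nn.LayerNorm
# everywhere and GELU gating in the MLP.
model = create_mosaic_bert_mlm(
    pretrained_model_name="bert-base-uncased",  # assumed parameter; not shown in this hunk
    ablations={"use_rmsnorm": False, "use_silu": False},
)
```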
4 changes: 4 additions & 0 deletions yamls/main/mcloud_run_a100_40gb.yaml
@@ -107,6 +107,10 @@ parameters:
     eps: 1.0e-06
     weight_decay: 1.0e-5 # Amount of weight decay regularization

+  # ablations:
+  #   use_rmsnorm: true
+  #   use_silu: true
+
   # algorithms:

   max_duration: 286720000sp # Subsample the training data for ~275M samples
4 changes: 4 additions & 0 deletions yamls/main/mcloud_run_a100_80gb.yaml
@@ -107,6 +107,10 @@ parameters:
     eps: 1.0e-06
     weight_decay: 1.0e-5 # Amount of weight decay regularization

+  # ablations:
+  #   use_rmsnorm: true
+  #   use_silu: true
+
   # algorithms:

   max_duration: 286720000sp # Subsample the training data for ~275M samples
4 changes: 4 additions & 0 deletions yamls/main/mosaic-bert-base-uncased.yaml
@@ -68,6 +68,10 @@ optimizer:
   eps: 1.0e-06
   weight_decay: 1.0e-5 # Amount of weight decay regularization

+# ablations:
+#   use_rmsnorm: true
+#   use_silu: true
+
 # algorithms:

 max_duration: 286720000sp # Subsample the training data for ~275M samples
