diff --git a/__init__.py b/__init__.py index d69c1b4..a46c1d1 100644 --- a/__init__.py +++ b/__init__.py @@ -9,57 +9,63 @@ try: import torch + # yapf: disable - from src.bert_layers import (BertEmbeddings, BertEncoder, BertForMaskedLM, + from src.bert_layers import (BertAlibiEmbeddings, BertAlibiEncoder, BertForMaskedLM, BertForSequenceClassification, - BertGatedLinearUnitMLP, BertLayer, + BertResidualGLU, BertAlibiLayer, BertLMPredictionHead, BertModel, BertOnlyMLMHead, BertOnlyNSPHead, BertPooler, BertPredictionHeadTransform, BertSelfOutput, - BertUnpadAttention, BertUnpadSelfAttention) + BertAlibiUnpadAttention, BertAlibiUnpadSelfAttention) # yapf: enable - from src.bert_padding import (IndexFirstAxis, IndexPutFirstAxis, - index_first_axis, index_put_first_axis, - pad_input, unpad_input, unpad_input_only) + from src.bert_padding import ( + IndexFirstAxis, + IndexPutFirstAxis, + index_first_axis, + index_put_first_axis, + pad_input, + unpad_input, + unpad_input_only, + ) from src.hf_bert import create_hf_bert_classification, create_hf_bert_mlm - from src.mosaic_bert import (create_mosaic_bert_classification, - create_mosaic_bert_mlm) + from src.mosaic_bert import create_mosaic_bert_classification, create_mosaic_bert_mlm except ImportError as e: try: is_cuda_available = torch.cuda.is_available() # type: ignore - except: + except Exception: is_cuda_available = False - reqs_file = 'requirements.txt' if is_cuda_available else 'requirements-cpu.txt' + reqs_file = "requirements.txt" if is_cuda_available else "requirements-cpu.txt" raise ImportError( - f'Please make sure to pip install -r {reqs_file} to get the requirements for the BERT benchmark.' + f"Please make sure to pip install -r {reqs_file} to get the requirements for the BERT benchmark." ) from e __all__ = [ - 'BertEmbeddings', - 'BertEncoder', - 'BertForMaskedLM', - 'BertForSequenceClassification', - 'BertGatedLinearUnitMLP', - 'BertLayer', - 'BertLMPredictionHead', - 'BertModel', - 'BertOnlyMLMHead', - 'BertOnlyNSPHead', - 'BertPooler', - 'BertPredictionHeadTransform', - 'BertSelfOutput', - 'BertUnpadAttention', - 'BertUnpadSelfAttention', - 'IndexFirstAxis', - 'IndexPutFirstAxis', - 'index_first_axis', - 'index_put_first_axis', - 'pad_input', - 'unpad_input', - 'unpad_input_only', - 'create_hf_bert_classification', - 'create_hf_bert_mlm', - 'create_mosaic_bert_classification', - 'create_mosaic_bert_mlm', + "BertAlibiEmbeddings", + "BertAlibiEncoder", + "BertForMaskedLM", + "BertForSequenceClassification", + "BertResidualGLU", + "BertAlibiLayer", + "BertLMPredictionHead", + "BertModel", + "BertOnlyMLMHead", + "BertOnlyNSPHead", + "BertPooler", + "BertPredictionHeadTransform", + "BertSelfOutput", + "BertAlibiUnpadAttention", + "BertAlibiUnpadSelfAttention", + "IndexFirstAxis", + "IndexPutFirstAxis", + "index_first_axis", + "index_put_first_axis", + "pad_input", + "unpad_input", + "unpad_input_only", + "create_hf_bert_classification", + "create_hf_bert_mlm", + "create_mosaic_bert_classification", + "create_mosaic_bert_mlm", ] diff --git a/src/__init__.py b/src/__init__.py index ee2b2dc..08facd4 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -8,12 +8,12 @@ sys.path.append(os.path.dirname(os.path.realpath(__file__))) # yapf: disable -from bert_layers import (BertEmbeddings, BertEncoder, BertForMaskedLM, - BertForSequenceClassification, BertGatedLinearUnitMLP, - BertLayer, BertLMPredictionHead, BertModel, +from bert_layers import (BertAlibiEmbeddings, BertAlibiEncoder, BertForMaskedLM, + BertForSequenceClassification, 
BertResidualGLU, + BertAlibiLayer, BertLMPredictionHead, BertModel, BertOnlyMLMHead, BertOnlyNSPHead, BertPooler, BertPredictionHeadTransform, BertSelfOutput, - BertUnpadAttention, BertUnpadSelfAttention) + BertAlibiUnpadAttention, BertAlibiUnpadSelfAttention) # yapf: enable from bert_padding import ( IndexFirstAxis, @@ -24,19 +24,18 @@ unpad_input, unpad_input_only, ) -from configuration_bert import BertConfig +from configuration_bert import MosaicBertConfig from hf_bert import create_hf_bert_classification, create_hf_bert_mlm from mosaic_bert import create_mosaic_bert_classification, create_mosaic_bert_mlm __all__ = [ - "BertConfig", - "BertEmbeddings", - "BertEncoder", + "BertAlibiEmbeddings", + "BertAlibiEncoder", "BertForMaskedLM", "BertForSequenceClassification", - "BertGatedLinearUnitMLP", - "BertLayer", + "BertResidualGLU", + "BertAlibiLayer", "BertLMPredictionHead", "BertModel", "BertOnlyMLMHead", @@ -44,8 +43,9 @@ "BertPooler", "BertPredictionHeadTransform", "BertSelfOutput", - "BertUnpadAttention", - "BertUnpadSelfAttention", + "BertAlibiUnpadAttention", + "BertAlibiUnpadSelfAttention", + "MosaicBertConfig", "IndexFirstAxis", "IndexPutFirstAxis", "index_first_axis", diff --git a/src/bert_layers.py b/src/bert_layers.py deleted file mode 100644 index b6df200..0000000 --- a/src/bert_layers.py +++ /dev/null @@ -1,1155 +0,0 @@ -# Copyright 2024 **AUTHORS_TODO** -# License: Apache-2.0 - -# RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation) -# License: LLAMA 2 COMMUNITY LICENSE AGREEMENT - -# Copyright 2022 MosaicML Examples authors -# SPDX-License-Identifier: Apache-2.0 - -# Copyright 2023 MosaicML Examples authors -# SPDX-License-Identifier: Apache-2.0 - -# Copyright 2023 MosaicML Examples authors -# SPDX-License-Identifier: Apache-2.0 - -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2023, Tri Dao. - -"""Implements Mosaic BERT, with an eye towards the Hugging Face API. - -Mosaic BERT improves performance over Hugging Face BERT through the following: - -1. ALiBi. This architectural change removes positional embeddings and instead encodes positional -information through attention biases based on query-key position distance. It improves the effectiveness -of training with shorter sequence lengths by enabling extrapolation to longer sequences. - -2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer -to improve overall expressiveness, providing better convergence properties. - -3. Flash Attention. The MosaicBERT's self-attention layer makes use of Flash Attention, which dramatically -improves the speed of self-attention. Our implementation utilizes a bleeding edge implementation that -supports attention biases, which allows us to use Flash Attention with ALiBi. - -4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT -implementations waste computation on padded tokens. MosaicBERT internally unpads to reduce unnecessary computation -and improve speed. It does this without changing how the user interfaces with the model, thereby -preserving the simple API of standard implementations. - - -Currently, MosaicBERT is available for masked language modeling :class:`BertForMaskedLM` and sequence -classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases. 
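# A minimal, self-contained sketch of the ALiBi bias described above (improvement 1).
# The head count, sequence length, and helper name are illustrative assumptions, not
# code from this repository; the real tensor is built by `rebuild_alibi_tensor` in the
# encoder below.
import math

import torch


def alibi_bias_sketch(n_heads: int, seq_len: int) -> torch.Tensor:
    # Per-head slopes form a geometric sequence, e.g. 1/2, 1/4, ..., 1/256 for 8 heads.
    start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
    slopes = torch.tensor([start ** (i + 1) for i in range(n_heads)])
    # Symmetric query-key distance |i - j|; BERT is bidirectional, so no causal mask.
    pos = torch.arange(seq_len)
    distance = (pos[None, :] - pos[:, None]).abs()
    # Bias grows more negative with distance, scaled per head: shape (1, heads, seq, seq).
    return (-slopes[:, None, None] * distance).unsqueeze(0)


bias = alibi_bias_sketch(n_heads=8, seq_len=4)
assert bias.shape == (1, 8, 4, 4)
# The encoder adds this bias (together with the padding mask) to q @ k.T / sqrt(head_dim)
# before the softmax, instead of adding position embeddings to the token inputs.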
- -See :file:`./mosaic_bert.py` for utilities to simplify working with MosaicBERT in Composer, and for example usage -of the core Mosaic BERT classes. -""" - -import copy -import logging -import math -import os -import sys -import warnings -from typing import List, Optional, Tuple, Union - -# Add folder root to path to allow us to use relative imports regardless of what directory the script is run from -sys.path.append(os.path.dirname(os.path.realpath(__file__))) - -import importlib - -import bert_padding as bert_padding_module -import torch -import torch.nn as nn -from einops import rearrange -from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present -from transformers.activations import ACT2FN -from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput -from transformers.models.bert.modeling_bert import BertPreTrainedModel - -IMPL_USE_FLASH2 = False -# Import Flash Attention 2, which supports ALiBi https://github.com/Dao-AILab/flash-attention -try: - from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore - - installed_version = importlib.metadata.version("flash_attn") # type: ignore - if installed_version < "2.5.7": - raise ImportError("newer version of flash_attn required (>= 2.5.7)") - IMPL_USE_FLASH2 = True - flash_attn_qkvpacked_func = None -except ImportError as e: - warnings.warn( - f"Failed to import flash_attn. Will use slow and memory hungry PyTorch native implementation: {e}", stacklevel=2 - ) - -logger = logging.getLogger(__name__) - - -class RMSNorm(torch.nn.Module): - """Llama2 RMSNorm implementation""" - - def __init__(self, dim: int, eps: float = 1e-6): - """ - Initialize the RMSNorm normalization layer. - - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. - - Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. - - """ - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - """ - Apply the RMSNorm normalization to the input tensor. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The normalized tensor. - - """ - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - """ - Forward pass through the RMSNorm layer. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The output tensor after applying RMSNorm. - - """ - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -class BertEmbeddings(nn.Module): - """Construct the embeddings for words, ignoring position. - - There are no positional embeddings since we use ALiBi and token_type - embeddings. - - This module is modeled after the Hugging Face BERT's - :class:`~transformers.model.bert.modeling_bert.BertEmbeddings`, but is - modified as part of Mosaic BERT's ALiBi implementation. The key change is - that position embeddings are removed. Position information instead comes - from attention biases that scale linearly with the position distance - between query and key tokens. - - This module ignores the `position_ids` input to the `forward` method. 
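# A small numeric illustration of the `RMSNorm` defined above, which these modules can
# use in place of `nn.LayerNorm` (see the `use_rmsnorm` flag below). Shapes here are
# illustrative assumptions.
import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 8)                       # (batch, seq, hidden)
weight = torch.ones(8)                         # learnable scale, initialized to 1
rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
y = (x * rms) * weight
# Unlike LayerNorm, the mean is not subtracted and there is no bias term: each row keeps
# its direction and only its scale is normalized to (approximately) unit RMS.
assert torch.allclose(y.pow(2).mean(-1), torch.ones(2, 4), atol=1e-3)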
- """ - - def __init__(self, config, use_rmsnorm: bool = True): - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - # ALiBi doesn't use position embeddings - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - - self.LayerNorm = ( - RMSNorm(config.hidden_size, eps=config.layer_norm_eps) - if use_rmsnorm - else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - ) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.register_buffer( - "token_type_ids", torch.zeros(config.max_position_embeddings, dtype=torch.long), persistent=False - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values_length: int = 0, - ) -> torch.Tensor: - if (input_ids is not None) == (inputs_embeds is not None): - raise ValueError("Must specify either input_ids or input_embeds!") - if input_ids is not None: - input_shape = input_ids.size() - else: - assert inputs_embeds is not None # just for type checking - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - - if position_ids is None: - # great! ALiBi - pass - - # Setting the token_type_ids to the registered buffer in constructor - # where it is all zeros, which usually occurs when it's auto-generated; - # registered buffer helps users when tracing the model without passing - # token_type_ids, solves issue #5664 - if token_type_ids is None: - if hasattr(self, "token_type_ids"): - assert isinstance(self.token_type_ids, torch.LongTensor) - buffered_token_type_ids = self.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) - token_type_ids = buffered_token_type_ids_expanded # type: ignore - else: - token_type_ids = torch.zeros( - input_shape, # type: ignore - dtype=torch.long, - device=self.word_embeddings.device, - ) # type: ignore # yapf: disable - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + token_type_embeddings - # no position embeddings! ALiBi - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertUnpadSelfAttention(nn.Module): - """Performs multi-headed self attention on a batch of unpadded sequences. - - If Flash Attention 2 is installed, this module uses Flash Attention to greatly improve throughput. - The Flash Attention implementation used in MosaicBERT supports arbitrary attention biases (which - we use to implement ALiBi), but does not support attention dropout. If either Flash Attention 2 is - not installed or `config.attention_probs_dropout_prob > 0`, the implementation will default to a - math-equivalent pytorch version, which is much slower. - - See `forward` method for additional detail. 
- """ - - def __init__(self, config): - super().__init__() - if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.p_dropout = config.attention_probs_dropout_prob - self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size) - - # Warn if defaulting to pytorch because of import issues - if not IMPL_USE_FLASH2: - warnings.warn( - "Unable to import flash_attn; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model)." - ) - - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen_in_batch: int, - indices: torch.Tensor, - attn_mask: torch.Tensor, - bias: torch.Tensor, - slopes: torch.Tensor, - ) -> torch.Tensor: - """Perform self-attention. - - There are three attention implementations supported: vanilla attention with ALiBi, - Triton Flash Attention with ALibi, and Flash Attention 2 with ALiBi - - In order to use the Triton kernel, dropout must be zero (i.e. attention_probs_dropout_prob = 0) - - The arguments are unpadded, and our implementations of attention require padded arguments, - so we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers. - The pad/unpad operations add overhead, but not sending pad tokens through ffs saves compute. - - Args: - hidden_states: (total_nnz, dim) - cu_seqlens: (batch + 1,) - max_seqlen_in_batch: int - indices: (total_nnz,) - attn_mask: (batch, max_seqlen_in_batch) - bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) - slopes: (heads) or (batch, heads) - - Returns: - attention: (total_nnz, dim) - """ - bs, dim = hidden_states.shape - qkv = self.Wqkv(hidden_states) - - # Option 1: Flash Attention with ALiBi - if IMPL_USE_FLASH2: - qkv = qkv.view(-1, 3, self.num_attention_heads, self.attention_head_size) - assert 1 <= len(slopes.shape) <= 2, f"{slopes=}" - assert slopes.shape[-1] == self.num_attention_heads, f"{slopes=}" - - convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) - if convert_dtype: - # FA2 implementation only supports fp16 and bf16 - # If FA2 is supported, bfloat16 must be supported - # as of FA2 2.4.2. (Turing GPUs not supported) - orig_dtype = qkv.dtype - qkv = qkv.to(torch.bfloat16) - - attention = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen_in_batch, - dropout_p=self.p_dropout, - alibi_slopes=slopes, - ) - attention = attention.to(orig_dtype) # type: ignore - else: - attention = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen_in_batch, - dropout_p=self.p_dropout, - alibi_slopes=slopes, - ) - else: - qkv = bert_padding_module.pad_input( - qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen_in_batch - ) # batch, max_seqlen_in_batch, thd - unpad_bs, *_ = qkv.shape - qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size) - # if we have nonzero attention dropout (e.g. 
during fine-tuning) or no Triton, compute attention in PyTorch - q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d - k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s - v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d - attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size) - attention_scores = attention_scores + bias - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) - attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3) # b s h d - - if not IMPL_USE_FLASH2: - attention = bert_padding_module.unpad_input_only(attention, torch.squeeze(attn_mask) == 1) - return attention.view(bs, dim) - - -# Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules. -class BertSelfOutput(nn.Module): - """Computes the output of the attention layer. - - This module is modeled after the Hugging Face BERT's - :class:`~transformers.model.bert.modeling_bert.BertSelfOutput`. - The implementation is identical. Rather than use the original module - directly, we re-implement it here so that Mosaic BERT's modules will not - be affected by any Composer surgery algorithm that modifies Hugging Face - BERT modules. - """ - - def __init__(self, config, use_rmsnorm: bool = True): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = ( - RMSNorm(config.hidden_size, eps=config.layer_norm_eps) - if use_rmsnorm - else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - ) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertUnpadAttention(nn.Module): - """Chains attention, Dropout, and LayerNorm for Mosaic BERT.""" - - def __init__(self, config, use_rmsnorm: bool = True): - super().__init__() - self.self = BertUnpadSelfAttention(config) - self.output = BertSelfOutput( - config, - use_rmsnorm=use_rmsnorm, - ) - - def forward( - self, - input_tensor: torch.Tensor, - cu_seqlens: torch.Tensor, - max_s: int, - subset_idx: Optional[torch.Tensor] = None, - indices: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - slopes: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass for scaled self-attention without padding. - - Arguments: - input_tensor: (total_nnz, dim) - cu_seqlens: (batch + 1,) - max_s: int - subset_idx: () set of indices whose values we care about at the end of the layer - (e.g., the masked tokens, if this is the final layer). - indices: None or (total_nnz,) - attn_mask: None or (batch, max_seqlen_in_batch) - bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) - slopes: None or (batch, heads) or (heads,) - """ - assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}" - self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes) - if subset_idx is not None: - return self.output( - bert_padding_module.index_first_axis(self_output, subset_idx), - bert_padding_module.index_first_axis(input_tensor, subset_idx), - ) - else: - return self.output(self_output, input_tensor) - - -class BertGatedLinearUnitMLP(nn.Module): - """Applies the FFN at the end of each Mosaic BERT layer. 
- - Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` - and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but - introduces Gated Linear Units. - - Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a - standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with - `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed - with the `config.intermediate_size=3072`. - However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased - parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`. - """ - - def __init__( - self, - config, - use_rmsnorm: bool = True, - use_silu: bool = True, - ): - super().__init__() - self.config = config - self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False) - self.act = nn.SiLU() if use_silu else nn.GELU() - self.wo = nn.Linear(config.intermediate_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.layernorm = ( - RMSNorm(config.hidden_size, eps=config.layer_norm_eps) - if use_rmsnorm - else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Compute new hidden states from current hidden states. - - Args: - hidden_states (torch.Tensor): The (unpadded) hidden states from - the attention layer [nnz, dim]. - """ - residual_connection = hidden_states - # compute the activation - hidden_states = self.gated_layers(hidden_states) - gated = hidden_states[:, : self.config.intermediate_size] - non_gated = hidden_states[:, self.config.intermediate_size :] - hidden_states = self.act(gated) * non_gated - hidden_states = self.dropout(hidden_states) - # multiply by the second matrix - hidden_states = self.wo(hidden_states) - # add the residual connection and post-LN - hidden_states = self.layernorm(hidden_states + residual_connection) - return hidden_states - - -class BertLayer(nn.Module): - """Composes the Mosaic BERT attention and FFN blocks into a single layer.""" - - def __init__(self, config, use_rmsnorm: bool = True, use_silu: bool = True): - super(BertLayer, self).__init__() - self.attention = BertUnpadAttention(config, use_rmsnorm=use_rmsnorm) - self.mlp = BertGatedLinearUnitMLP(config, use_rmsnorm=use_rmsnorm, use_silu=use_silu) - - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - seqlen: int, - subset_idx: Optional[torch.Tensor] = None, - indices: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - slopes: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass for a BERT layer, including both attention and MLP. - - Args: - hidden_states: (total_nnz, dim) - cu_seqlens: (batch + 1,) - seqlen: int - subset_idx: () set of indices whose values we care about at the end of the layer - (e.g., the masked tokens, if this is the final layer). 
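# A compact numeric sketch of the gated MLP defined above (`BertGatedLinearUnitMLP`,
# renamed `BertResidualGLU` elsewhere in this PR): a single projection produces both the
# gate and the value, the activated gate multiplies the value elementwise, and a second
# projection maps back to the hidden size. The dimensions below are illustrative
# assumptions, not the config defaults.
import torch
import torch.nn as nn

hidden, intermediate = 8, 16
x = torch.randn(5, hidden)                           # (total_nnz, hidden) unpadded tokens
gated_layers = nn.Linear(hidden, intermediate * 2, bias=False)
wo = nn.Linear(intermediate, hidden)

projected = gated_layers(x)                          # (total_nnz, 2 * intermediate)
gate = projected[:, :intermediate]                   # first half: gate branch
value = projected[:, intermediate:]                  # second half: value branch
out = wo(nn.SiLU()(gate) * value)                    # (total_nnz, hidden)
# The real module additionally applies dropout, adds the residual connection, and
# post-normalizes with RMSNorm or LayerNorm, exactly as in the forward pass above.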
- indices: None or (total_nnz,) - attn_mask: None or (batch, max_seqlen_in_batch) - bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) - slopes: None or (batch, heads) or (heads,) - """ - assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}" - attention_output = self.attention( - hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias, slopes - ) - layer_output = self.mlp(attention_output) - return layer_output - - -class BertEncoder(nn.Module): - """A stack of BERT layers providing the backbone of Mosaic BERT. - - This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertEncoder`, - but with substantial modifications to implement unpadding and ALiBi. - - Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation - at padded tokens, and pre-computes attention biases to implement ALiBi. - """ - - def __init__(self, config, use_rmsnorm: bool = True, use_silu: bool = True): - super().__init__() - layer = BertLayer(config, use_rmsnorm=use_rmsnorm, use_silu=use_silu) - self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) - - self.num_attention_heads = config.num_attention_heads - - # The alibi mask will be dynamically expanded if it is too small for - # the input the model receives. But it generally helps to initialize it - # to a reasonably large size to help pre-allocate CUDA memory. - # The default `alibi_starting_size` is 512. - self._current_alibi_size = int(config.alibi_starting_size) - self.alibi = torch.zeros((1, self.num_attention_heads, self._current_alibi_size, self._current_alibi_size)) - self.rebuild_alibi_tensor(size=config.alibi_starting_size) - - def rebuild_alibi_tensor(self, size: int, device: Optional[Union[torch.device, str]] = None): - # Alibi - # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1) - # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation - # of the logits, which makes the math work out *after* applying causal masking. If no causal masking - # will be applied, it is necessary to construct the diagonal mask. - n_heads = self.num_attention_heads - - def _get_alibi_head_slopes(n_heads: int) -> List[float]: - def get_slopes_power_of_2(n_heads: int) -> List[float]: - start = 2 ** (-(2 ** -(math.log2(n_heads) - 3))) - ratio = start - return [start * ratio**i for i in range(n_heads)] - - # In the paper, they only train models that have 2^a heads for some a. This function - # has some good properties that only occur when the input is a power of 2. To - # maintain that even when the number of heads is not a power of 2, we use a - # workaround. 
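            # Worked example (commentary only; the numbers follow from the code below):
            # for n_heads = 12, closest_power_of_2 = 8, so we keep the eight slopes
            # 1/2, 1/4, ..., 1/256 and append every other slope of a 16-head model,
            # starting from the first (1/sqrt(2), 1/(2*sqrt(2)), 1/(4*sqrt(2)),
            # 1/(8*sqrt(2))), truncated to the 4 extra heads.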
- if math.log2(n_heads).is_integer(): - return get_slopes_power_of_2(n_heads) - - closest_power_of_2 = 2 ** math.floor(math.log2(n_heads)) - slopes_a = get_slopes_power_of_2(closest_power_of_2) - slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2) - slopes_b = slopes_b[0::2][: n_heads - closest_power_of_2] - return slopes_a + slopes_b - - context_position = torch.arange(size, device=device)[:, None] - memory_position = torch.arange(size, device=device)[None, :] - relative_position = torch.abs(memory_position - context_position) - # [n_heads, max_token_length, max_token_length] - relative_position = relative_position.unsqueeze(0).expand(n_heads, -1, -1) - slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device) - self.slopes = slopes - alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position - # [1, n_heads, max_token_length, max_token_length] - alibi = alibi.unsqueeze(0) - assert alibi.shape == torch.Size([1, n_heads, size, size]) - - self._current_alibi_size = size - self.alibi = alibi - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - output_all_encoded_layers: Optional[bool] = True, - subset_mask: Optional[torch.Tensor] = None, - ) -> List[torch.Tensor]: - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - attention_mask_bool = attention_mask.bool() - batch, seqlen = hidden_states.shape[:2] - # Unpad inputs and mask. It will remove tokens that are padded. - # Assume ntokens is total number of tokens (padded and non-padded) - # and ntokens_unpad is total number of non-padded tokens. - # Then unpadding performs the following compression of the inputs: - # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden] - hidden_states, indices, cu_seqlens, _ = bert_padding_module.unpad_input(hidden_states, attention_mask_bool) - - # Add alibi matrix to extended_attention_mask - if self._current_alibi_size < seqlen: - # Rebuild the alibi tensor when needed - warnings.warn(f"Increasing alibi size from {self._current_alibi_size} to {seqlen}") - self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device) - elif self.alibi.device != hidden_states.device: - # Device catch-up - self.alibi = self.alibi.to(hidden_states.device) - self.slopes = self.slopes.to(hidden_states.device) # type: ignore - alibi_bias = self.alibi[:, :, :seqlen, :seqlen] - attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen] - alibi_attn_mask = attn_bias + alibi_bias - - all_encoder_layers = [] - if subset_mask is None: - for layer_module in self.layer: - hidden_states = layer_module( - hidden_states, - cu_seqlens, - seqlen, - None, - indices, - attn_mask=attention_mask, - bias=alibi_attn_mask, - slopes=self.slopes, - ) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - # Pad inputs and mask. It will insert back zero-padded tokens. - # Assume ntokens is total number of tokens (padded and non-padded) - # and ntokens_unpad is total number of non-padded tokens. 
- # Then padding performs the following de-compression: - # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden] - hidden_states = bert_padding_module.pad_input(hidden_states, indices, batch, seqlen) - else: - for i in range(len(self.layer) - 1): - layer_module = self.layer[i] - hidden_states = layer_module( - hidden_states, - cu_seqlens, - seqlen, - None, - indices, - attn_mask=attention_mask, - bias=alibi_attn_mask, - slopes=self.slopes, - ) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - subset_idx = torch.nonzero(subset_mask[attention_mask_bool], as_tuple=False).flatten() - hidden_states = self.layer[-1]( - hidden_states, - cu_seqlens, - seqlen, - subset_idx=subset_idx, - indices=indices, - attn_mask=attention_mask, - bias=alibi_attn_mask, - slopes=self.slopes, - ) - - if not output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - return all_encoder_layers - - -class BertPooler(nn.Module): - def __init__(self, config): - super(BertPooler, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] if pool else hidden_states - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - def __init__(self, config, use_rmsnorm: bool = True): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = ( - RMSNorm(config.hidden_size, eps=1e-12) if use_rmsnorm else nn.LayerNorm(config.hidden_size, eps=1e-12) - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertModel(BertPreTrainedModel): - """Overall BERT model. - - Args: - config: a BertConfig class instance with the configuration to build a new model - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 
- - Outputs: Tuple of (encoded_layers, pooled_output) - `encoded_layers`: controlled by `output_all_encoded_layers` argument: - - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end - of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each - encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - to the last attention block of shape [batch_size, sequence_length, hidden_size], - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a - classifier pretrained on top of the hidden state associated to the first character of the - input (`CLS`) to train on the Next-Sentence task (see BERT's paper). - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - model = BertModel(config=config) - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - ``` - """ - - def __init__( - self, - config, - add_pooling_layer: bool = True, - use_rmsnorm: bool = True, - use_silu: bool = True, - ): - super(BertModel, self).__init__(config) - self.embeddings = BertEmbeddings(config, use_rmsnorm=use_rmsnorm) - self.encoder = BertEncoder(config, use_rmsnorm=use_rmsnorm, use_silu=use_silu) - self.pooler = BertPooler(config) if add_pooling_layer else None - self.post_init() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def forward( - self, - input_ids: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_all_encoded_layers: Optional[bool] = False, - masked_tokens_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]: - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) - - subset_mask = [] - first_col_mask = [] - - if masked_tokens_mask is None: - subset_mask = None - else: - first_col_mask = torch.zeros_like(masked_tokens_mask) - first_col_mask[:, 0] = True - subset_mask = masked_tokens_mask | first_col_mask - - encoder_outputs = self.encoder( - embedding_output, - attention_mask, - output_all_encoded_layers=output_all_encoded_layers, - subset_mask=subset_mask, - ) - - if masked_tokens_mask is None: - sequence_output = encoder_outputs[-1] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - else: - # TD [2022-03-01]: the indexing here is very tricky. 
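            # In this branch the final encoder layer returned only the rows selected by
            # `subset_mask` (the masked tokens plus the first token of every sequence),
            # in unpadded order. `subset_mask[attention_mask_bool]` is a boolean over the
            # unpadded tokens; indexing the other masks with it yields booleans over the
            # returned rows, so the first gather keeps the masked-token rows for the MLM
            # head and the second keeps the per-sequence first-token rows for the pooler.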
- attention_mask_bool = attention_mask.bool() - subset_idx = subset_mask[attention_mask_bool] # type: ignore - sequence_output = encoder_outputs[-1][masked_tokens_mask[attention_mask_bool][subset_idx]] - if self.pooler is not None: - pool_input = encoder_outputs[-1][first_col_mask[attention_mask_bool][subset_idx]] - pooled_output = self.pooler(pool_input, pool=False) - else: - pooled_output = None - - if not output_all_encoded_layers: - encoder_outputs = sequence_output - - if self.pooler is not None: - return encoder_outputs, pooled_output - - return encoder_outputs, None - - -################### -# Bert Heads -################### -class BertLMPredictionHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights, use_rmsnorm: bool = True): - super().__init__() - self.transform = BertPredictionHeadTransform(config, use_rmsnorm=use_rmsnorm) - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0)) - self.decoder.weight = bert_model_embedding_weights - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights, use_rmsnorm: bool = True): - super().__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights, use_rmsnorm=use_rmsnorm) - - def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertOnlyNSPHead(nn.Module): - def __init__(self, config): - super().__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output: torch.Tensor) -> torch.Tensor: - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -##################### -# Various Bert models -##################### - - -class BertForPreTraining(BertPreTrainedModel): - # TBD: Coming in Future Commit - pass - - -class BertLMHeadModel(BertPreTrainedModel): - # TBD: Coming in Future Commit - pass - - -class BertForMaskedLM(BertPreTrainedModel): - def __init__(self, config, ablations: dict = {}): - super().__init__(config) - - if config.is_decoder: - warnings.warn( - "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." 
- ) - - use_rmsnorm = ablations.get("use_rmsnorm", True) - use_silu = ablations.get("use_silu", True) - - self.bert = BertModel( - config, - add_pooling_layer=False, - use_rmsnorm=use_rmsnorm, - use_silu=use_silu, - ) - self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight, use_rmsnorm=use_rmsnorm) - - # Initialize weights and apply final processing - self.post_init() - - @classmethod - def from_composer( - cls, pretrained_checkpoint, state_dict=None, cache_dir=None, from_tf=False, config=None, *inputs, **kwargs - ): - """Load from pre-trained.""" - model = cls(config, *inputs, **kwargs) - if from_tf: - raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") - - state_dict = torch.load(pretrained_checkpoint) - # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix - consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - - if len(missing_keys) > 0: - logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - - return model - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: - # labels should be a `torch.LongTensor` of shape - # `(batch_size, sequence_length)`. These are used for computing the - # masked language modeling loss. 
- # - # Indices should be in `[-100, 0, ..., config.vocab_size]` (see - # `input_ids` docstring) Tokens with indices set to `-100` are ignored - # (masked), the loss is only computed for the tokens with labels in `[0, - # ..., config.vocab_size]` - # - # Prediction scores are only computed for masked tokens and the (bs, - # seqlen) dimensions are flattened - if (input_ids is not None) == (inputs_embeds is not None): - raise ValueError("Must specify either input_ids or input_embeds!") - - if labels is None: - masked_tokens_mask = None - else: - masked_tokens_mask = labels > 0 - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - masked_tokens_mask=masked_tokens_mask, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - loss = None - if labels is not None: - # Compute loss - loss_fct = nn.CrossEntropyLoss() - masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() - loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx]) - - assert input_ids is not None, "Coding error; please open an issue" - batch, seqlen = input_ids.shape[:2] - prediction_scores = rearrange( - bert_padding_module.index_put_first_axis(prediction_scores, masked_token_idx, batch * seqlen), - "(b s) d -> b s d", - b=batch, - ) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return MaskedLMOutput( - loss=loss, - logits=prediction_scores, - hidden_states=None, - attentions=None, - ) - - def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs): - input_shape = input_ids.shape - effective_batch_size = input_shape[0] - - # add a dummy token - if self.config.pad_token_id is None: - raise ValueError("The PAD token should be defined for generation") - - attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) - dummy_token = torch.full( - (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device - ) - input_ids = torch.cat([input_ids, dummy_token], dim=1) - - return {"input_ids": input_ids, "attention_mask": attention_mask} - - -class BertForNextSentencePrediction(BertPreTrainedModel): - # TBD: Push in future commit - pass - - -class BertForSequenceClassification(BertPreTrainedModel): - """Bert Model transformer with a sequence classification/regression head. - - This head is just a linear layer on top of the pooled output. Used for, - e.g., GLUE tasks. 
- """ - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.config = config - - self.bert = BertModel(config) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - @classmethod - def from_composer( - cls, pretrained_checkpoint, state_dict=None, cache_dir=None, from_tf=False, config=None, *inputs, **kwargs - ): - """Load from pre-trained.""" - model = cls(config, *inputs, **kwargs) - if from_tf: - raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") - - state_dict = torch.load(pretrained_checkpoint) - # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix - consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - - if len(missing_keys) > 0: - logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if len(unexpected_keys) > 0: - logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - - return model - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: - # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - # Labels for computing the sequence classification/regression loss. - # Indices should be in `[0, ..., config.num_labels - 1]`. - # If `config.num_labels == 1` a regression loss is computed - # (mean-square loss). If `config.num_labels > 1` a classification loss - # is computed (cross-entropy). 
- - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - # Compute loss - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = nn.MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = nn.BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=None, - attentions=None, - ) - - -class BertForMultipleChoice(BertPreTrainedModel): - # TBD: Push in future commit - pass - - -class BertForTokenClassification(BertPreTrainedModel): - # TBD: Push in future commit - pass - - -class BertForQuestionAnswering(BertPreTrainedModel): - """Bert Model with a span classification head. - - This is used for extractive question-answering tasks like SQuAD (a linear - layers on top of the hidden states' output to compute `span start logits` - and `span end logits`). 
- """ - - # TBD: Push in future commit diff --git a/src/bert_layers/__init__.py b/src/bert_layers/__init__.py new file mode 100644 index 0000000..c72ba57 --- /dev/null +++ b/src/bert_layers/__init__.py @@ -0,0 +1,32 @@ +from .attention import BertAlibiUnpadAttention, BertAlibiUnpadSelfAttention, BertSelfOutput +from .embeddings import BertAlibiEmbeddings +from .layers import BertAlibiEncoder, BertAlibiLayer, BertResidualGLU +from .model import ( + BertLMPredictionHead, + BertModel, + BertForMaskedLM, + BertForSequenceClassification, + BertOnlyMLMHead, + BertOnlyNSPHead, + BertPooler, + BertPredictionHeadTransform, +) + + +__all__ = [ + "BertAlibiEmbeddings", + "BertAlibiEncoder", + "BertForMaskedLM", + "BertForSequenceClassification", + "BertResidualGLU", + "BertAlibiLayer", + "BertLMPredictionHead", + "BertModel", + "BertOnlyMLMHead", + "BertOnlyNSPHead", + "BertPooler", + "BertPredictionHeadTransform", + "BertSelfOutput", + "BertAlibiUnpadAttention", + "BertAlibiUnpadSelfAttention", +] diff --git a/src/bert_layers/attention.py b/src/bert_layers/attention.py new file mode 100644 index 0000000..63e9e80 --- /dev/null +++ b/src/bert_layers/attention.py @@ -0,0 +1,236 @@ +# Copyright 2024 **AUTHORS_TODO** +# License: Apache-2.0 + +# Copyright 2022 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, Tri Dao. + + +import torch +import torch.nn as nn +import warnings +from typing import Optional +import importlib +import logging +import math + +import bert_padding +from .norm import RMSNorm + +IMPL_USE_FLASH2 = False +# Import Flash Attention 2, which supports ALiBi https://github.com/Dao-AILab/flash-attention +try: + from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore + + installed_version = importlib.metadata.version("flash_attn") # type: ignore + if installed_version < "2.5.7": + raise ImportError("newer version of flash_attn required (>= 2.5.7)") + IMPL_USE_FLASH2 = True + flash_attn_qkvpacked_func = None +except ImportError as e: + warnings.warn( + f"Failed to import flash_attn. Will use slow and memory hungry PyTorch native implementation: {e}", stacklevel=2 + ) + +logger = logging.getLogger(__name__) + + +class BertAlibiUnpadSelfAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences. + + If Flash Attention 2 is installed, this module uses Flash Attention to greatly improve throughput. + The Flash Attention implementation used in MosaicBERT supports arbitrary attention biases (which + we use to implement ALiBi), but does not support attention dropout. If either Flash Attention 2 is + not installed or `config.attention_probs_dropout_prob > 0`, the implementation will default to a + math-equivalent pytorch version, which is much slower. + + See `forward` method for additional detail. 
+ """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.p_dropout = config.attention_probs_dropout_prob + self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size) + + # Warn if defaulting to pytorch because of import issues + if not IMPL_USE_FLASH2: + warnings.warn( + "Unable to import flash_attn; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model)." + ) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen_in_batch: int, + indices: torch.Tensor, + attn_mask: torch.Tensor, + bias: torch.Tensor, + slopes: torch.Tensor, + ) -> torch.Tensor: + """Perform self-attention. + + There are three attention implementations supported: vanilla attention with ALiBi, + Triton Flash Attention with ALibi, and Flash Attention 2 with ALiBi + + In order to use the Triton kernel, dropout must be zero (i.e. attention_probs_dropout_prob = 0) + + The arguments are unpadded, and our implementations of attention require padded arguments, + so we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers. + The pad/unpad operations add overhead, but not sending pad tokens through ffs saves compute. + + Args: + hidden_states: (total_nnz, dim) + cu_seqlens: (batch + 1,) + max_seqlen_in_batch: int + indices: (total_nnz,) + attn_mask: (batch, max_seqlen_in_batch) + bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) + slopes: (heads) or (batch, heads) + + Returns: + attention: (total_nnz, dim) + """ + bs, dim = hidden_states.shape + qkv = self.Wqkv(hidden_states) + + # Option 1: Flash Attention with ALiBi + if IMPL_USE_FLASH2: + qkv = qkv.view(-1, 3, self.num_attention_heads, self.attention_head_size) + assert 1 <= len(slopes.shape) <= 2, f"{slopes=}" + assert slopes.shape[-1] == self.num_attention_heads, f"{slopes=}" + + convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) + if convert_dtype: + # FA2 implementation only supports fp16 and bf16 + # If FA2 is supported, bfloat16 must be supported + # as of FA2 2.4.2. (Turing GPUs not supported) + orig_dtype = qkv.dtype + qkv = qkv.to(torch.bfloat16) + + attention = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen_in_batch, + dropout_p=self.p_dropout, + alibi_slopes=slopes, + ) + attention = attention.to(orig_dtype) # type: ignore + else: + attention = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen_in_batch, + dropout_p=self.p_dropout, + alibi_slopes=slopes, + ) + else: + qkv = bert_padding.pad_input( + qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen_in_batch + ) # batch, max_seqlen_in_batch, thd + unpad_bs, *_ = qkv.shape + qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size) + # if we have nonzero attention dropout (e.g. 
during fine-tuning) or no Triton, compute attention in PyTorch + q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d + k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s + v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d + attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size) + attention_scores = attention_scores + bias + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attention_probs = self.dropout(attention_probs) + attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3) # b s h d + + if not IMPL_USE_FLASH2: + attention = bert_padding.unpad_input_only(attention, torch.squeeze(attn_mask) == 1) + return attention.view(bs, dim) + + +# Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules. +class BertSelfOutput(nn.Module): + """Computes the output of the attention layer. + + This module is modeled after the Hugging Face BERT's + :class:`~transformers.model.bert.modeling_bert.BertSelfOutput`. + The implementation is identical. Rather than use the original module + directly, we re-implement it here so that Mosaic BERT's modules will not + be affected by any Composer surgery algorithm that modifies Hugging Face + BERT modules. + """ + + def __init__(self, config, use_rmsnorm: bool = True): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = ( + RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + if use_rmsnorm + else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAlibiUnpadAttention(nn.Module): + """Chains attention, Dropout, and LayerNorm for Mosaic BERT.""" + + def __init__(self, config, use_rmsnorm: bool = True): + super().__init__() + self.self = BertAlibiUnpadSelfAttention(config) + self.output = BertSelfOutput( + config, + use_rmsnorm=use_rmsnorm, + ) + + def forward( + self, + input_tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + max_s: int, + subset_idx: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + slopes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass for scaled self-attention without padding. + + Arguments: + input_tensor: (total_nnz, dim) + cu_seqlens: (batch + 1,) + max_s: int + subset_idx: () set of indices whose values we care about at the end of the layer + (e.g., the masked tokens, if this is the final layer). 
+ indices: None or (total_nnz,) + attn_mask: None or (batch, max_seqlen_in_batch) + bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) + slopes: None or (batch, heads) or (heads,) + """ + assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}" + self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes) + if subset_idx is not None: + return self.output( + bert_padding.index_first_axis(self_output, subset_idx), + bert_padding.index_first_axis(input_tensor, subset_idx), + ) + else: + return self.output(self_output, input_tensor) diff --git a/src/bert_layers/embeddings.py b/src/bert_layers/embeddings.py new file mode 100644 index 0000000..ecb9f19 --- /dev/null +++ b/src/bert_layers/embeddings.py @@ -0,0 +1,101 @@ +# Copyright 2024 **AUTHORS_TODO** +# License: Apache-2.0 + +# Copyright 2022 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, Tri Dao. + + +import torch +import torch.nn as nn +from typing import Optional + +from .norm import RMSNorm + + +class BertAlibiEmbeddings(nn.Module): + """Construct the embeddings for words, ignoring position. + + There are no positional embeddings since we use ALiBi and token_type + embeddings. + + This module is modeled after the Hugging Face BERT's + :class:`~transformers.model.bert.modeling_bert.BertAlibiEmbeddings`, but is + modified as part of Mosaic BERT's ALiBi implementation. The key change is + that position embeddings are removed. Position information instead comes + from attention biases that scale linearly with the position distance + between query and key tokens. + + This module ignores the `position_ids` input to the `forward` method. + """ + + def __init__(self, config, use_rmsnorm: bool = True): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + # ALiBi doesn't use position embeddings + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + self.LayerNorm = ( + RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + if use_rmsnorm + else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.register_buffer( + "token_type_ids", torch.zeros(config.max_position_embeddings, dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if (input_ids is not None) == (inputs_embeds is not None): + raise ValueError("Must specify either input_ids or input_embeds!") + if input_ids is not None: + input_shape = input_ids.size() + else: + assert inputs_embeds is not None # just for type checking + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + # great! 
ALiBi + pass + + # Setting the token_type_ids to the registered buffer in constructor + # where it is all zeros, which usually occurs when it's auto-generated; + # registered buffer helps users when tracing the model without passing + # token_type_ids, solves issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + assert isinstance(self.token_type_ids, torch.LongTensor) + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded # type: ignore + else: + token_type_ids = torch.zeros( + input_shape, # type: ignore + dtype=torch.long, + device=self.word_embeddings.device, + ) # type: ignore # yapf: disable + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + # no position embeddings! ALiBi + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings diff --git a/src/bert_layers/layers.py b/src/bert_layers/layers.py new file mode 100644 index 0000000..3988a4f --- /dev/null +++ b/src/bert_layers/layers.py @@ -0,0 +1,254 @@ +# Copyright 2024 **AUTHORS_TODO** +# License: Apache-2.0 + +# Copyright 2022 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, Tri Dao. + + +import copy +import math +import warnings +from typing import Optional, Union, List + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN + +import bert_padding +from .attention import BertAlibiUnpadAttention +from .mlp import BertResidualGLU +from .norm import RMSNorm + + +class BertAlibiLayer(nn.Module): + """Composes the Mosaic BERT attention and FFN blocks into a single layer.""" + + def __init__(self, config, use_rmsnorm: bool = True, use_silu: bool = True): + super().__init__() + self.attention = BertAlibiUnpadAttention(config, use_rmsnorm=use_rmsnorm) + self.mlp = BertResidualGLU(config, use_rmsnorm=use_rmsnorm, use_silu=use_silu) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + seqlen: int, + subset_idx: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + slopes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass for a BERT layer, including both attention and MLP. + + Args: + hidden_states: (total_nnz, dim) + cu_seqlens: (batch + 1,) + seqlen: int + subset_idx: () set of indices whose values we care about at the end of the layer + (e.g., the masked tokens, if this is the final layer). 
+ indices: None or (total_nnz,) + attn_mask: None or (batch, max_seqlen_in_batch) + bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) + slopes: None or (batch, heads) or (heads,) + """ + assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}" + attention_output = self.attention( + hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias, slopes + ) + layer_output = self.mlp(attention_output) + return layer_output + + +class BertAlibiEncoder(nn.Module): + """A stack of BERT layers providing the backbone of Mosaic BERT. + + This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertAlibiEncoder`, + but with substantial modifications to implement unpadding and ALiBi. + + Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation + at padded tokens, and pre-computes attention biases to implement ALiBi. + """ + + def __init__(self, config, use_rmsnorm: bool = True, use_silu: bool = True): + super().__init__() + layer = BertAlibiLayer(config, use_rmsnorm=use_rmsnorm, use_silu=use_silu) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + + self.num_attention_heads = config.num_attention_heads + + # The alibi mask will be dynamically expanded if it is too small for + # the input the model receives. But it generally helps to initialize it + # to a reasonably large size to help pre-allocate CUDA memory. + # The default `alibi_starting_size` is 512. + self._current_alibi_size = int(config.alibi_starting_size) + self.alibi = torch.zeros((1, self.num_attention_heads, self._current_alibi_size, self._current_alibi_size)) + self.rebuild_alibi_tensor(size=config.alibi_starting_size) + + def rebuild_alibi_tensor(self, size: int, device: Optional[Union[torch.device, str]] = None): + # Alibi + # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1) + # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation + # of the logits, which makes the math work out *after* applying causal masking. If no causal masking + # will be applied, it is necessary to construct the diagonal mask. + n_heads = self.num_attention_heads + + def _get_alibi_head_slopes(n_heads: int) -> List[float]: + def get_slopes_power_of_2(n_heads: int) -> List[float]: + start = 2 ** (-(2 ** -(math.log2(n_heads) - 3))) + ratio = start + return [start * ratio**i for i in range(n_heads)] + + # In the paper, they only train models that have 2^a heads for some a. This function + # has some good properties that only occur when the input is a power of 2. To + # maintain that even when the number of heads is not a power of 2, we use a + # workaround. 
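+                # Worked illustration of the workaround (added for clarity, not part of the
+                # algorithm): for n_heads=12 the closest power of 2 is 8, so the result is the
+                # 8 geometric slopes for 8 heads followed by every other slope of the 16-head
+                # sequence, truncated to the 4 extras needed, i.e.
+                # get_slopes_power_of_2(8) + _get_alibi_head_slopes(16)[0::2][:4].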
+ if math.log2(n_heads).is_integer(): + return get_slopes_power_of_2(n_heads) + + closest_power_of_2 = 2 ** math.floor(math.log2(n_heads)) + slopes_a = get_slopes_power_of_2(closest_power_of_2) + slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2) + slopes_b = slopes_b[0::2][: n_heads - closest_power_of_2] + return slopes_a + slopes_b + + context_position = torch.arange(size, device=device)[:, None] + memory_position = torch.arange(size, device=device)[None, :] + relative_position = torch.abs(memory_position - context_position) + # [n_heads, max_token_length, max_token_length] + relative_position = relative_position.unsqueeze(0).expand(n_heads, -1, -1) + slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device) + self.slopes = slopes + alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position + # [1, n_heads, max_token_length, max_token_length] + alibi = alibi.unsqueeze(0) + assert alibi.shape == torch.Size([1, n_heads, size, size]) + + self._current_alibi_size = size + self.alibi = alibi + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_all_encoded_layers: Optional[bool] = True, + subset_mask: Optional[torch.Tensor] = None, + ) -> List[torch.Tensor]: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + attention_mask_bool = attention_mask.bool() + batch, seqlen = hidden_states.shape[:2] + # Unpad inputs and mask. It will remove tokens that are padded. + # Assume ntokens is total number of tokens (padded and non-padded) + # and ntokens_unpad is total number of non-padded tokens. + # Then unpadding performs the following compression of the inputs: + # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden] + hidden_states, indices, cu_seqlens, _ = bert_padding.unpad_input(hidden_states, attention_mask_bool) + + # Add alibi matrix to extended_attention_mask + if self._current_alibi_size < seqlen: + # Rebuild the alibi tensor when needed + warnings.warn(f"Increasing alibi size from {self._current_alibi_size} to {seqlen}") + self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device) + elif self.alibi.device != hidden_states.device: + # Device catch-up + self.alibi = self.alibi.to(hidden_states.device) + self.slopes = self.slopes.to(hidden_states.device) # type: ignore + alibi_bias = self.alibi[:, :, :seqlen, :seqlen] + attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen] + alibi_attn_mask = attn_bias + alibi_bias + + all_encoder_layers = [] + if subset_mask is None: + for layer_module in self.layer: + hidden_states = layer_module( + hidden_states, + cu_seqlens, + seqlen, + None, + indices, + attn_mask=attention_mask, + bias=alibi_attn_mask, + slopes=self.slopes, + ) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + # Pad inputs and mask. It will insert back zero-padded tokens. + # Assume ntokens is total number of tokens (padded and non-padded) + # and ntokens_unpad is total number of non-padded tokens. 
+ # Then padding performs the following de-compression: + # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden] + hidden_states = bert_padding.pad_input(hidden_states, indices, batch, seqlen) + else: + for i in range(len(self.layer) - 1): + layer_module = self.layer[i] + hidden_states = layer_module( + hidden_states, + cu_seqlens, + seqlen, + None, + indices, + attn_mask=attention_mask, + bias=alibi_attn_mask, + slopes=self.slopes, + ) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + subset_idx = torch.nonzero(subset_mask[attention_mask_bool], as_tuple=False).flatten() + hidden_states = self.layer[-1]( + hidden_states, + cu_seqlens, + seqlen, + subset_idx=subset_idx, + indices=indices, + attn_mask=attention_mask, + bias=alibi_attn_mask, + slopes=self.slopes, + ) + + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] if pool else hidden_states + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config, use_rmsnorm: bool = True): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = ( + RMSNorm(config.hidden_size, eps=1e-12) if use_rmsnorm else nn.LayerNorm(config.hidden_size, eps=1e-12) + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states diff --git a/src/bert_layers/mlp.py b/src/bert_layers/mlp.py new file mode 100644 index 0000000..32b1065 --- /dev/null +++ b/src/bert_layers/mlp.py @@ -0,0 +1,71 @@ +# Copyright 2024 **AUTHORS_TODO** +# License: Apache-2.0 + +# Copyright 2022 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, Tri Dao. + +import torch +import torch.nn as nn + +from .norm import RMSNorm + + +class BertResidualGLU(nn.Module): + """Applies the FFN at the end of each Mosaic BERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but + introduces Gated Linear Units. + + Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a + standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. 
For example, a Mosaic BERT constructed with + `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed + with the `config.intermediate_size=3072`. + However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased + parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`. + """ + + def __init__( + self, + config, + use_rmsnorm: bool = True, + use_silu: bool = True, + ): + super().__init__() + self.config = config + self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False) + self.act = nn.SiLU() if use_silu else nn.GELU() + self.wo = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.layernorm = ( + RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + if use_rmsnorm + else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Compute new hidden states from current hidden states. + + Args: + hidden_states (torch.Tensor): The (unpadded) hidden states from + the attention layer [nnz, dim]. + """ + residual_connection = hidden_states + # compute the activation + hidden_states = self.gated_layers(hidden_states) + gated = hidden_states[:, : self.config.intermediate_size] + non_gated = hidden_states[:, self.config.intermediate_size :] + hidden_states = self.act(gated) * non_gated + hidden_states = self.dropout(hidden_states) + # multiply by the second matrix + hidden_states = self.wo(hidden_states) + # add the residual connection and post-LN + hidden_states = self.layernorm(hidden_states + residual_connection) + return hidden_states diff --git a/src/bert_layers/model.py b/src/bert_layers/model.py new file mode 100644 index 0000000..13c55cd --- /dev/null +++ b/src/bert_layers/model.py @@ -0,0 +1,529 @@ +# Copyright 2024 **AUTHORS_TODO** +# License: Apache-2.0 + +# RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation) +# License: LLAMA 2 COMMUNITY LICENSE AGREEMENT + +# Copyright 2022 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2023 MosaicML Examples authors +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, Tri Dao. + +"""Implements Mosaic BERT, with an eye towards the Hugging Face API. + +Mosaic BERT improves performance over Hugging Face BERT through the following: + +1. ALiBi. This architectural change removes positional embeddings and instead encodes positional +information through attention biases based on query-key position distance. It improves the effectiveness +of training with shorter sequence lengths by enabling extrapolation to longer sequences. + +2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer +to improve overall expressiveness, providing better convergence properties. + +3. Flash Attention. The MosaicBERT's self-attention layer makes use of Flash Attention, which dramatically +improves the speed of self-attention. Our implementation utilizes a bleeding edge implementation that +supports attention biases, which allows us to use Flash Attention with ALiBi. + +4. Unpadding. 
Padding is often used to simplify batching across sequences of different lengths. Standard BERT +implementations waste computation on padded tokens. MosaicBERT internally unpads to reduce unnecessary computation +and improve speed. It does this without changing how the user interfaces with the model, thereby +preserving the simple API of standard implementations. + + +Currently, MosaicBERT is available for masked language modeling :class:`BertForMaskedLM` and sequence +classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases. + +See :file:`./mosaic_bert.py` for utilities to simplify working with MosaicBERT in Composer, and for example usage +of the core Mosaic BERT classes. +""" + +import logging +import os +import sys +import warnings +from typing import List, Optional, Tuple, Union + +# Add folder root to path to allow us to use relative imports regardless of what directory the script is run from +sys.path.append(os.path.dirname(os.path.realpath(__file__))) + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present +from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput +from transformers.models.bert.modeling_bert import BertPreTrainedModel + +import bert_padding as bert_padding_module +from .layers import BertAlibiEncoder, BertPooler, BertPredictionHeadTransform +from .embeddings import BertAlibiEmbeddings + +logger = logging.getLogger(__name__) + + +class BertModel(BertPreTrainedModel): + """Overall BERT model. + + Args: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controlled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLS`) to train on the Next-Sentence task (see BERT's paper). 
+ + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + model = BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__( + self, + config, + add_pooling_layer: bool = True, + use_rmsnorm: bool = True, + use_silu: bool = True, + ): + super(BertModel, self).__init__(config) + self.embeddings = BertAlibiEmbeddings(config, use_rmsnorm=use_rmsnorm) + self.encoder = BertAlibiEncoder(config, use_rmsnorm=use_rmsnorm, use_silu=use_silu) + self.pooler = BertPooler(config) if add_pooling_layer else None + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_all_encoded_layers: Optional[bool] = False, + masked_tokens_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]: + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) + + subset_mask = [] + first_col_mask = [] + + if masked_tokens_mask is None: + subset_mask = None + else: + first_col_mask = torch.zeros_like(masked_tokens_mask) + first_col_mask[:, 0] = True + subset_mask = masked_tokens_mask | first_col_mask + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_all_encoded_layers=output_all_encoded_layers, + subset_mask=subset_mask, + ) + + if masked_tokens_mask is None: + sequence_output = encoder_outputs[-1] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + else: + # TD [2022-03-01]: the indexing here is very tricky. + attention_mask_bool = attention_mask.bool() + subset_idx = subset_mask[attention_mask_bool] # type: ignore + sequence_output = encoder_outputs[-1][masked_tokens_mask[attention_mask_bool][subset_idx]] + if self.pooler is not None: + pool_input = encoder_outputs[-1][first_col_mask[attention_mask_bool][subset_idx]] + pooled_output = self.pooler(pool_input, pool=False) + else: + pooled_output = None + + if not output_all_encoded_layers: + encoder_outputs = sequence_output + + if self.pooler is not None: + return encoder_outputs, pooled_output + + return encoder_outputs, None + + +################### +# Bert Heads +################### +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights, use_rmsnorm: bool = True): + super().__init__() + self.transform = BertPredictionHeadTransform(config, use_rmsnorm=use_rmsnorm) + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
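+        # Note: the assignment below ties `self.decoder.weight` to the input word-embedding
+        # matrix (shared storage), so MLM logits are computed against the same vectors used
+        # for embedding lookup; only the decoder bias remains an independent per-token parameter.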
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0)) + self.decoder.weight = bert_model_embedding_weights + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights, use_rmsnorm: bool = True): + super().__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights, use_rmsnorm=use_rmsnorm) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output: torch.Tensor) -> torch.Tensor: + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +##################### +# Various Bert models +##################### + + +class BertForPreTraining(BertPreTrainedModel): + # TBD: Coming in Future Commit + pass + + +class BertLMHeadModel(BertPreTrainedModel): + # TBD: Coming in Future Commit + pass + + +class BertForMaskedLM(BertPreTrainedModel): + def __init__(self, config, ablations: dict = {}): + super().__init__(config) + + if config.is_decoder: + warnings.warn( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + use_rmsnorm = ablations.get("use_rmsnorm", True) + use_silu = ablations.get("use_silu", True) + + self.bert = BertModel( + config, + add_pooling_layer=False, + use_rmsnorm=use_rmsnorm, + use_silu=use_silu, + ) + self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight, use_rmsnorm=use_rmsnorm) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_composer( + cls, pretrained_checkpoint, state_dict=None, cache_dir=None, from_tf=False, config=None, *inputs, **kwargs + ): + """Load from pre-trained.""" + model = cls(config, *inputs, **kwargs) + if from_tf: + raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") + + state_dict = torch.load(pretrained_checkpoint) + # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix + consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if len(missing_keys) > 0: + logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") + if len(unexpected_keys) > 0: + logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + + return model + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: 
Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + # labels should be a `torch.LongTensor` of shape + # `(batch_size, sequence_length)`. These are used for computing the + # masked language modeling loss. + # + # Indices should be in `[-100, 0, ..., config.vocab_size]` (see + # `input_ids` docstring) Tokens with indices set to `-100` are ignored + # (masked), the loss is only computed for the tokens with labels in `[0, + # ..., config.vocab_size]` + # + # Prediction scores are only computed for masked tokens and the (bs, + # seqlen) dimensions are flattened + if (input_ids is not None) == (inputs_embeds is not None): + raise ValueError("Must specify either input_ids or input_embeds!") + + if labels is None: + masked_tokens_mask = None + else: + masked_tokens_mask = labels > 0 + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + masked_tokens_mask=masked_tokens_mask, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + loss = None + if labels is not None: + # Compute loss + loss_fct = nn.CrossEntropyLoss() + masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() + loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx]) + + assert input_ids is not None, "Coding error; please open an issue" + batch, seqlen = input_ids.shape[:2] + prediction_scores = rearrange( + bert_padding_module.index_put_first_axis(prediction_scores, masked_token_idx, batch * seqlen), + "(b s) d -> b s d", + b=batch, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=None, + attentions=None, + ) + + def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +class BertForNextSentencePrediction(BertPreTrainedModel): + # TBD: Push in future commit + pass + + +class BertForSequenceClassification(BertPreTrainedModel): + """Bert Model transformer with a sequence classification/regression head. + + This head is just a linear layer on top of the pooled output. Used for, + e.g., GLUE tasks. 
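Since the sequence-classification head follows the familiar pooled-output -> dropout -> linear pattern, a short end-to-end sketch may help orient readers. This is an untested illustration: the hyperparameters are arbitrary, the imports assume this repository's layout (`src/configuration_bert.py`, `src/bert_layers/model.py`), and without `flash_attn` installed the attention layer falls back to the PyTorch path with a warning, as noted earlier in the diff.

```python
# Hedged sketch: exercise BertForSequenceClassification on dummy data.
# Config values are illustrative only; real use would load a pretrained checkpoint.
import torch

from src.configuration_bert import MosaicBertConfig
from src.bert_layers.model import BertForSequenceClassification

config = MosaicBertConfig(
    vocab_size=30522,
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=256,
    num_labels=2,
    alibi_starting_size=128,
)
model = BertForSequenceClassification(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([0, 1])

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

print(out.loss)          # scalar cross-entropy loss
print(out.logits.shape)  # (batch, num_labels) == (2, 2)
```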
+ """ + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_composer( + cls, pretrained_checkpoint, state_dict=None, cache_dir=None, from_tf=False, config=None, *inputs, **kwargs + ): + """Load from pre-trained.""" + model = cls(config, *inputs, **kwargs) + if from_tf: + raise ValueError("Mosaic BERT does not support loading TensorFlow weights.") + + state_dict = torch.load(pretrained_checkpoint) + # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix + consume_prefix_in_state_dict_if_present(state_dict, prefix="model.") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if len(missing_keys) > 0: + logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") + if len(unexpected_keys) > 0: + logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") + + return model + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + # Labels for computing the sequence classification/regression loss. + # Indices should be in `[0, ..., config.num_labels - 1]`. + # If `config.num_labels == 1` a regression loss is computed + # (mean-square loss). If `config.num_labels > 1` a classification loss + # is computed (cross-entropy). 
+ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + # Compute loss + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=None, + attentions=None, + ) + + +class BertForMultipleChoice(BertPreTrainedModel): + # TBD: Push in future commit + pass + + +class BertForTokenClassification(BertPreTrainedModel): + # TBD: Push in future commit + pass + + +class BertForQuestionAnswering(BertPreTrainedModel): + """Bert Model with a span classification head. + + This is used for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden states' output to compute `span start logits` + and `span end logits`). + """ + + # TBD: Push in future commit diff --git a/src/bert_layers/norm.py b/src/bert_layers/norm.py new file mode 100644 index 0000000..597050b --- /dev/null +++ b/src/bert_layers/norm.py @@ -0,0 +1,57 @@ +# Copyright 2024 **AUTHORS_TODO** +# License: Apache-2.0 + +# RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation) +# License: LLAMA 2 COMMUNITY LICENSE AGREEMENT + + +import torch +import torch.nn as nn + + +class RMSNorm(nn.Module): + """Llama2 RMSNorm implementation""" + + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. 
+ + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight diff --git a/src/configuration_bert.py b/src/configuration_bert.py index 9fac4c1..1a9d946 100644 --- a/src/configuration_bert.py +++ b/src/configuration_bert.py @@ -4,7 +4,7 @@ from transformers import BertConfig as TransformersBertConfig -class BertConfig(TransformersBertConfig): +class MosaicBertConfig(TransformersBertConfig): def __init__( self, alibi_starting_size: int = 512, diff --git a/src/mosaic_bert.py b/src/mosaic_bert.py index d4a722d..5c6114c 100644 --- a/src/mosaic_bert.py +++ b/src/mosaic_bert.py @@ -97,7 +97,7 @@ def create_mosaic_bert_mlm( if not pretrained_model_name: pretrained_model_name = "bert-base-uncased" - config = configuration_bert_module.BertConfig.from_pretrained(pretrained_model_name, **model_config) + config = configuration_bert_module.MosaicBertConfig.from_pretrained(pretrained_model_name, **model_config) # Padding for divisibility by 8 if config.vocab_size % 8 != 0:
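For downstream code, the practical effect of the rename is that configs are now built through `MosaicBertConfig`, exactly as `create_mosaic_bert_mlm` does above. A minimal hedged sketch of that pattern (fetching `bert-base-uncased` requires hub access or a local cache; the override value is illustrative):

```python
# Hedged sketch: build a MosaicBERT config the same way create_mosaic_bert_mlm does,
# overriding the ALiBi pre-allocation size via a keyword argument.
from src.configuration_bert import MosaicBertConfig

config = MosaicBertConfig.from_pretrained("bert-base-uncased", alibi_starting_size=1024)

print(config.alibi_starting_size)  # 1024
print(config.vocab_size % 8)       # 30522 % 8 == 2, so the divisibility branch shown above applies
```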