Skip to content

Commit

Permalink
split bert_layers into multiple files
Browse files Browse the repository at this point in the history
  • Loading branch information
warner-benjamin committed May 16, 2024
1 parent 0828615 commit e11d2c2
Show file tree
Hide file tree
Showing 12 changed files with 1,337 additions and 1,206 deletions.
80 changes: 43 additions & 37 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,57 +9,63 @@

try:
import torch

# yapf: disable
from src.bert_layers import (BertEmbeddings, BertEncoder, BertForMaskedLM,
from src.bert_layers import (BertAlibiEmbeddings, BertAlibiEncoder, BertForMaskedLM,
BertForSequenceClassification,
BertGatedLinearUnitMLP, BertLayer,
BertResidualGLU, BertAlibiLayer,
BertLMPredictionHead, BertModel,
BertOnlyMLMHead, BertOnlyNSPHead, BertPooler,
BertPredictionHeadTransform, BertSelfOutput,
BertUnpadAttention, BertUnpadSelfAttention)
BertAlibiUnpadAttention, BertAlibiUnpadSelfAttention)
# yapf: enable
from src.bert_padding import (IndexFirstAxis, IndexPutFirstAxis,
index_first_axis, index_put_first_axis,
pad_input, unpad_input, unpad_input_only)
from src.bert_padding import (
IndexFirstAxis,
IndexPutFirstAxis,
index_first_axis,
index_put_first_axis,
pad_input,
unpad_input,
unpad_input_only,
)
from src.hf_bert import create_hf_bert_classification, create_hf_bert_mlm
from src.mosaic_bert import (create_mosaic_bert_classification,
create_mosaic_bert_mlm)
from src.mosaic_bert import create_mosaic_bert_classification, create_mosaic_bert_mlm
except ImportError as e:
try:
is_cuda_available = torch.cuda.is_available() # type: ignore
except:
except Exception:
is_cuda_available = False

reqs_file = 'requirements.txt' if is_cuda_available else 'requirements-cpu.txt'
reqs_file = "requirements.txt" if is_cuda_available else "requirements-cpu.txt"
raise ImportError(
f'Please make sure to pip install -r {reqs_file} to get the requirements for the BERT benchmark.'
f"Please make sure to pip install -r {reqs_file} to get the requirements for the BERT benchmark."
) from e

__all__ = [
'BertEmbeddings',
'BertEncoder',
'BertForMaskedLM',
'BertForSequenceClassification',
'BertGatedLinearUnitMLP',
'BertLayer',
'BertLMPredictionHead',
'BertModel',
'BertOnlyMLMHead',
'BertOnlyNSPHead',
'BertPooler',
'BertPredictionHeadTransform',
'BertSelfOutput',
'BertUnpadAttention',
'BertUnpadSelfAttention',
'IndexFirstAxis',
'IndexPutFirstAxis',
'index_first_axis',
'index_put_first_axis',
'pad_input',
'unpad_input',
'unpad_input_only',
'create_hf_bert_classification',
'create_hf_bert_mlm',
'create_mosaic_bert_classification',
'create_mosaic_bert_mlm',
"BertAlibiEmbeddings",
"BertAlibiEncoder",
"BertForMaskedLM",
"BertForSequenceClassification",
"BertResidualGLU",
"BertAlibiLayer",
"BertLMPredictionHead",
"BertModel",
"BertOnlyMLMHead",
"BertOnlyNSPHead",
"BertPooler",
"BertPredictionHeadTransform",
"BertSelfOutput",
"BertAlibiUnpadAttention",
"BertAlibiUnpadSelfAttention",
"IndexFirstAxis",
"IndexPutFirstAxis",
"index_first_axis",
"index_put_first_axis",
"pad_input",
"unpad_input",
"unpad_input_only",
"create_hf_bert_classification",
"create_hf_bert_mlm",
"create_mosaic_bert_classification",
"create_mosaic_bert_mlm",
]
24 changes: 12 additions & 12 deletions src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

# yapf: disable
from bert_layers import (BertEmbeddings, BertEncoder, BertForMaskedLM,
BertForSequenceClassification, BertGatedLinearUnitMLP,
BertLayer, BertLMPredictionHead, BertModel,
from bert_layers import (BertAlibiEmbeddings, BertAlibiEncoder, BertForMaskedLM,
BertForSequenceClassification, BertResidualGLU,
BertAlibiLayer, BertLMPredictionHead, BertModel,
BertOnlyMLMHead, BertOnlyNSPHead, BertPooler,
BertPredictionHeadTransform, BertSelfOutput,
BertUnpadAttention, BertUnpadSelfAttention)
BertAlibiUnpadAttention, BertAlibiUnpadSelfAttention)
# yapf: enable
from bert_padding import (
IndexFirstAxis,
Expand All @@ -24,28 +24,28 @@
unpad_input,
unpad_input_only,
)
from configuration_bert import BertConfig
from configuration_bert import MosaicBertConfig

from hf_bert import create_hf_bert_classification, create_hf_bert_mlm
from mosaic_bert import create_mosaic_bert_classification, create_mosaic_bert_mlm

__all__ = [
"BertConfig",
"BertEmbeddings",
"BertEncoder",
"BertAlibiEmbeddings",
"BertAlibiEncoder",
"BertForMaskedLM",
"BertForSequenceClassification",
"BertGatedLinearUnitMLP",
"BertLayer",
"BertResidualGLU",
"BertAlibiLayer",
"BertLMPredictionHead",
"BertModel",
"BertOnlyMLMHead",
"BertOnlyNSPHead",
"BertPooler",
"BertPredictionHeadTransform",
"BertSelfOutput",
"BertUnpadAttention",
"BertUnpadSelfAttention",
"BertAlibiUnpadAttention",
"BertAlibiUnpadSelfAttention",
"MosaicBertConfig",
"IndexFirstAxis",
"IndexPutFirstAxis",
"index_first_axis",
Expand Down
Loading

0 comments on commit e11d2c2

Please sign in to comment.