Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Opt presets #707

Merged
merged 2 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions keras_nlp/models/opt/opt_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""OPT backbone model."""

import copy

import tensorflow as tf
from tensorflow import keras

Expand All @@ -22,6 +24,8 @@
)
from keras_nlp.layers.transformer_decoder import TransformerDecoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.opt.opt_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty


def opt_kernel_initializer(stddev=0.02):
Expand Down Expand Up @@ -66,6 +70,11 @@ class OPTBackbone(Backbone):
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
),
}

# Pretrained OPT decoder
model = keras_nlp.models.OPTBackbone.from_preset("opt_125m_en")
output = model(input_data)

# Randomly initialized OPT decoder model with a custom config
model = keras_nlp.models.OPTBackbone(
vocabulary_size=50265,
Expand Down Expand Up @@ -159,3 +168,7 @@ def get_config(self):
@property
def token_embedding(self):
return self.get_layer("embeddings").token_embedding

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
108 changes: 108 additions & 0 deletions keras_nlp/models/opt/opt_presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OPT model preset configurations."""

# Metadata for loading pretrained model weights.
backbone_presets = {
"opt_125m_en": {
"config": {
"vocabulary_size": 50272,
"num_layers": 12,
"num_heads": 12,
"hidden_dim": 768,
"intermediate_dim": 3072,
"dropout": 0.1,
"max_sequence_length": 2048,
},
"preprocessor_config": {},
"description": (
"12-layer OPT model where case in maintained. Trained on "
"BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/opt_125m_en/v1/model.h5",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we starting with v1 or have you already augmented the count?

Copy link
Member Author

@mattdangerw mattdangerw Feb 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

v1 is the start, for all presets

"weights_hash": "63e444998982e48da4a1a3970f4c6203",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/opt_125m_en/v1/vocab.json",
"vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
"merges_url": "https://storage.googleapis.com/keras-nlp/models/opt_125m_en/v1/merges.txt",
"merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
},
# We skip the 350m checkpoint because it does not match the structure of
# other checkpoints.
"opt_1.3b_en": {
"config": {
"vocabulary_size": 50272,
"num_layers": 24,
"num_heads": 32,
"hidden_dim": 2048,
"intermediate_dim": 8192,
"dropout": 0.1,
"max_sequence_length": 2048,
},
"preprocessor_config": {},
"description": (
"24-layer OPT model where case in maintained. Trained on "
"BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/opt_1.3b_en/v1/model.h5",
"weights_hash": "0365ac8483e99a912c9770521909ecce",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/opt_1.3b_en/v1/vocab.json",
"vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
"merges_url": "https://storage.googleapis.com/keras-nlp/models/opt_1.3b_en/v1/merges.txt",
"merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
},
"opt_2.7b_en": {
"config": {
"vocabulary_size": 50272,
"num_layers": 32,
"num_heads": 32,
"hidden_dim": 2560,
"intermediate_dim": 10240,
"dropout": 0.1,
"max_sequence_length": 2048,
},
"preprocessor_config": {},
"description": (
"32-layer OPT model where case in maintained. Trained on "
"BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/opt_2.7b_en/v1/model.h5",
"weights_hash": "af56da9206a95b9287356955c5bc14e7",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/opt_2.7b_en/v1/vocab.json",
"vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
"merges_url": "https://storage.googleapis.com/keras-nlp/models/opt_2.7b_en/v1/merges.txt",
"merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
},
"opt_6.7b_en": {
"config": {
"vocabulary_size": 50272,
"num_layers": 32,
"num_heads": 32,
"hidden_dim": 4096,
"intermediate_dim": 16384,
"dropout": 0.1,
"max_sequence_length": 2048,
},
"preprocessor_config": {},
"description": (
"32-layer OPT model where case in maintained. Trained on "
"BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/opt_6.7b_en/v1/model.h5",
"weights_hash": "543120fbe601b70e6ec04cc909781e21",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/opt_6.7b_en/v1/vocab.json",
"vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
"merges_url": "https://storage.googleapis.com/keras-nlp/models/opt_6.7b_en/v1/merges.txt",
"merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
},
}
109 changes: 109 additions & 0 deletions keras_nlp/models/opt/opt_presets_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for loading pretrained model presets."""

import pytest
import tensorflow as tf
from absl.testing import parameterized

from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_tokenizer import OPTTokenizer


@pytest.mark.large
class OPTPresetSmokeTest(tf.test.TestCase, parameterized.TestCase):
"""
A smoke test for GPT-2 presets we run continuously.

This only tests the smallest weights we have available. Run with:
`pytest keras_nlp/models/opt/opt_presets_test.py --run_large`
"""

def test_tokenizer_output(self):
tokenizer = OPTTokenizer.from_preset("opt_125m_en")
outputs = tokenizer("The quick brown fox.")
expected_outputs = [133, 2119, 6219, 23602, 4]
self.assertAllEqual(outputs, expected_outputs)

@parameterized.named_parameters(
("preset_weights", True), ("random_weights", False)
)
def test_backbone_output(self, load_weights):
input_data = {
"token_ids": tf.constant([[133, 2119, 6219, 23602, 4]]),
"padding_mask": tf.constant([[1, 1, 1, 1, 1]]),
}
model = OPTBackbone.from_preset(
"opt_125m_en", load_weights=load_weights
)
outputs = model(input_data)[0, 0, :5]
if load_weights:
# The forward pass from a preset should be stable!
# This test should catch cases where we unintentionally change our
# network code in a way that would invalidate our preset weights.
# We should only update these numbers if we are updating a weights
# file, or have found a discrepancy with the upstream source.
expected_outputs = [-0.246, -1.004, -0.072, 0.097, 0.533]
# Keep a high tolerance, so we are robust to different hardware.
self.assertAllClose(outputs, expected_outputs, atol=0.01, rtol=0.01)

@parameterized.named_parameters(
("opt_tokenizer", OPTTokenizer),
("opt_backbone", OPTBackbone),
)
def test_preset_docstring(self, cls):
"""Check we did our docstring formatting correctly."""
for name in cls.presets:
self.assertRegex(cls.from_preset.__doc__, name)

@parameterized.named_parameters(
("opt_tokenizer", OPTTokenizer),
("opt_backbone", OPTBackbone),
)
def test_unknown_preset_error(self, cls):
# Not a preset name
with self.assertRaises(ValueError):
cls.from_preset("opt_clowntown")


@pytest.mark.extra_large
class OPTPresetFullTest(tf.test.TestCase, parameterized.TestCase):
"""
Test the full enumeration of our preset.

This tests every GPT-2 preset and is only run manually.
Run with:
`pytest keras_nlp/models/opt/opt_presets_test.py --run_extra_large`
"""

@parameterized.named_parameters(
("preset_weights", True), ("random_weights", False)
)
def test_load_opt(self, load_weights):
for preset in OPTBackbone.presets:
model = OPTBackbone.from_preset(preset, load_weights=load_weights)
input_data = {
"token_ids": tf.random.uniform(
shape=(1, 1024),
dtype=tf.int64,
maxval=model.vocabulary_size,
),
"padding_mask": tf.constant([1] * 1024, shape=(1, 1024)),
}
model(input_data)

def test_load_tokenizers(self):
for preset in OPTTokenizer.presets:
tokenizer = OPTTokenizer.from_preset(preset)
tokenizer("The quick brown fox.")
8 changes: 8 additions & 0 deletions keras_nlp/models/opt/opt_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@

"""OPT tokenizer."""

import copy

from tensorflow import keras

from keras_nlp.models.opt.opt_presets import backbone_presets
from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
Expand Down Expand Up @@ -108,3 +112,7 @@ def __init__(
self.start_token_id = self.token_to_id(start_token)
self.pad_token_id = self.token_to_id(pad_token)
self.end_token_id = self.token_to_id(end_token)

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
Loading