Skip to content

Commit

Permalink
Support kwargs to Backbone.from_preset and fix the dtype forwarding in Task.from_preset
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Aug 8, 2024
1 parent b890ca9 commit 562b9dd
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 13 deletions.
2 changes: 1 addition & 1 deletion keras_nlp/src/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
f"`from_preset` directly on `{preset_cls.__name__}` instead."
)

backbone = load_serialized_object(preset, CONFIG_FILE)
backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
if load_weights:
jax_memory_cleanup(backbone)
backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
Expand Down
18 changes: 17 additions & 1 deletion keras_nlp/src/models/backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from keras_nlp.src.utils.preset_utils import load_config


class TestTask(TestCase):
class TestBackbone(TestCase):
def test_preset_accessors(self):
bert_presets = set(BertBackbone.presets.keys())
gpt2_presets = set(GPT2Backbone.presets.keys())
Expand All @@ -46,6 +46,22 @@ def test_from_preset(self):
GPT2Backbone,
)

@pytest.mark.large
def test_from_preset_with_kwargs(self):
    # `dtype` should be forwarded to the restored backbone.
    restored = Backbone.from_preset(
        "bert_tiny_en_uncased", load_weights=False, dtype="bfloat16"
    )
    self.assertIsInstance(restored, BertBackbone)
    self.assertEqual(restored.dtype_policy.name, "bfloat16")

    # Arbitrary config kwargs (here `dropout`) should be forwarded too.
    restored = Backbone.from_preset(
        "bert_tiny_en_uncased", load_weights=False, dropout=0.5
    )
    self.assertIsInstance(restored, BertBackbone)
    self.assertAllClose(restored.dropout, 0.5)

@pytest.mark.large
def test_from_preset_errors(self):
with self.assertRaises(ValueError):
Expand Down
8 changes: 3 additions & 5 deletions keras_nlp/src/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,11 @@ def from_preset(
)
cls = subclasses[0]
# Forward dtype to the backbone.
config_overrides = {}
backbone_kwargs = {}
if "dtype" in kwargs:
config_overrides["dtype"] = kwargs.pop("dtype")
backbone_kwargs = {"dtype": kwargs.pop("dtype")}
backbone = backbone_preset_cls.from_preset(
preset,
load_weights=load_weights,
config_overrides=config_overrides,
preset, load_weights=load_weights, **backbone_kwargs
)
if "preprocessor" in kwargs:
preprocessor = kwargs.pop("preprocessor")
Expand Down
10 changes: 10 additions & 0 deletions keras_nlp/src/models/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def test_from_preset(self):
# TODO: Add a classifier task loading test when there is a classifier
# with new design available on Kaggle.

@pytest.mark.large
def test_from_preset_with_kwargs(self):
    # `dtype` should be applied to both the task and its backbone.
    task = CausalLM.from_preset(
        "gpt2_base_en", load_weights=False, dtype="bfloat16"
    )
    self.assertIsInstance(task, GPT2CausalLM)
    self.assertEqual(task.dtype_policy.name, "bfloat16")
    self.assertEqual(task.backbone.dtype_policy.name, "bfloat16")

@pytest.mark.large
def test_from_preset_errors(self):
with self.assertRaises(ValueError):
Expand Down
37 changes: 31 additions & 6 deletions keras_nlp/src/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,13 +561,16 @@ def check_format(preset):
return "keras"


def load_serialized_object(
preset,
config_file=CONFIG_FILE,
config_overrides={},
):
def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
    """Load and deserialize a Keras object from a preset config file.

    Args:
        preset: The preset directory or handle to load from.
        config_file: Name of the JSON config file inside the preset.
        **kwargs: Overrides merged into the serialized `"config"` dict
            before deserialization. A `dtype` kwarg is handled specially
            via `set_dtype_in_config`, because the stored `dtype` may be
            a serialized `DTypePolicy`/`DTypePolicyMap` rather than a
            plain string.

    Returns:
        The deserialized Keras object.
    """
    # NOTE: `**kwargs` always binds a dict, so no `kwargs or {}` guard is
    # needed.
    config = load_config(preset, config_file)

    # `dtype` in config might be a serialized `DTypePolicy` or
    # `DTypePolicyMap`, so it cannot simply be overwritten by a string;
    # let `set_dtype_in_config` reconcile it.
    dtype = kwargs.pop("dtype", None)
    config = set_dtype_in_config(config, dtype)

    # Forward any remaining kwargs as plain config overrides.
    config["config"] = {**config["config"], **kwargs}
    return keras.saving.deserialize_keras_object(config)


Expand All @@ -590,3 +593,25 @@ def jax_memory_cleanup(layer):
for weight in layer.weights:
if getattr(weight, "_value", None) is not None:
weight._value.delete()


def set_dtype_in_config(config, dtype=None):
    """Apply a `dtype` override to a serialized Keras object config.

    Args:
        config: A serialized object config, i.e. a dict containing a
            nested `"config"` dict.
        dtype: Optional dtype (e.g. `"bfloat16"`) to forward. If `None`,
            `config` is returned unchanged.

    Returns:
        The config with the dtype override applied.
    """
    if dtype is None:
        return config

    # NOTE(review): this is a shallow copy — the nested `config["config"]`
    # dict below is mutated in place, which also mutates the caller's
    # object. Confirm callers do not reuse the input config.
    config = config.copy()
    inner = config["config"]
    if "dtype" not in inner:
        # No serialized dtype present; forward `dtype` directly.
        inner["dtype"] = dtype
    elif (
        isinstance(inner["dtype"], dict)
        and "DTypePolicyMap" in inner["dtype"]["class_name"]
    ):
        # A serialized `DTypePolicyMap`: forward `dtype` as its default
        # policy and as the source dtype of every per-layer policy.
        policy_map_config = inner["dtype"]["config"]
        policy_map_config["default_policy"] = dtype
        for policy in policy_map_config["policy_map"].values():
            policy["config"]["source_name"] = dtype
    # NOTE(review): when the serialized dtype is a plain string or a single
    # policy, the requested `dtype` is silently ignored — confirm intended.
    return config
17 changes: 17 additions & 0 deletions keras_nlp/src/utils/preset_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
from keras_nlp.src.models import BertBackbone
from keras_nlp.src.models import BertTokenizer
from keras_nlp.src.tests.test_case import TestCase
from keras_nlp.src.utils.keras_utils import has_quantization_support
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import METADATA_FILE
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import load_serialized_object


class PresetUtilsTest(TestCase):
Expand Down Expand Up @@ -113,3 +115,18 @@ def test_incorrect_metadata(self):

with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"):
check_format(preset_dir)

@parameterized.named_parameters(
    ("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False),
    ("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True),
)
@pytest.mark.extra_large
def test_load_serialized_object(self, preset, dtype, is_quantized):
    if is_quantized and not has_quantization_support():
        self.skipTest("This version of Keras doesn't support quantization.")

    restored = load_serialized_object(preset, dtype=dtype)
    # Quantized presets deserialize with a policy map, whose policy name
    # carries a `map_` prefix.
    expected = "map_bfloat16" if is_quantized else "bfloat16"
    self.assertEqual(restored.dtype_policy.name, expected)

0 comments on commit 562b9dd

Please sign in to comment.