Add support for kwargs to Backbone.from_preset and fix dtype forwarding in Task.from_preset #1742

Merged: 1 commit, merged on Aug 9, 2024
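
For context, a minimal sketch of the usage this change enables, with preset names and arguments taken from the tests in this diff (the public `keras_nlp.models` import path is assumed):

    from keras_nlp.models import Backbone, CausalLM

    # Extra constructor kwargs (e.g. `dtype`, `dropout`) passed to
    # `Backbone.from_preset` are now forwarded into the backbone config.
    backbone = Backbone.from_preset(
        "bert_tiny_en_uncased", load_weights=False, dtype="bfloat16", dropout=0.5
    )

    # A `dtype` passed to `Task.from_preset` is now forwarded to the
    # underlying backbone as well.
    causal_lm = CausalLM.from_preset("gpt2_base_en", load_weights=False, dtype="bfloat16")
    assert causal_lm.backbone.dtype_policy.name == "bfloat16"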
2 changes: 1 addition & 1 deletion keras_nlp/src/models/backbone.py
@@ -213,7 +213,7 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
                 f"`from_preset` directly on `{preset_cls.__name__}` instead."
             )

-        backbone = load_serialized_object(preset, CONFIG_FILE)
+        backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
             backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
18 changes: 17 additions & 1 deletion keras_nlp/src/models/backbone_test.py
@@ -27,7 +27,7 @@
 from keras_nlp.src.utils.preset_utils import load_config


-class TestTask(TestCase):
+class TestBackbone(TestCase):
     def test_preset_accessors(self):
         bert_presets = set(BertBackbone.presets.keys())
         gpt2_presets = set(GPT2Backbone.presets.keys())
@@ -46,6 +46,22 @@ def test_from_preset(self):
             GPT2Backbone,
         )

+    @pytest.mark.large
+    def test_from_preset_with_kwargs(self):
+        # Test `dtype`
+        backbone = Backbone.from_preset(
+            "bert_tiny_en_uncased", load_weights=False, dtype="bfloat16"
+        )
+        self.assertIsInstance(backbone, BertBackbone)
+        self.assertEqual(backbone.dtype_policy.name, "bfloat16")
+
+        # Test kwargs forwarding
+        backbone = Backbone.from_preset(
+            "bert_tiny_en_uncased", load_weights=False, dropout=0.5
+        )
+        self.assertIsInstance(backbone, BertBackbone)
+        self.assertAllClose(backbone.dropout, 0.5)
+
     @pytest.mark.large
     def test_from_preset_errors(self):
         with self.assertRaises(ValueError):
8 changes: 3 additions & 5 deletions keras_nlp/src/models/task.py
@@ -258,13 +258,11 @@ def from_preset(
         )
         cls = subclasses[0]
         # Forward dtype to the backbone.
-        config_overrides = {}
+        backbone_kwargs = {}
         if "dtype" in kwargs:
-            config_overrides["dtype"] = kwargs.pop("dtype")
+            backbone_kwargs = {"dtype": kwargs.pop("dtype")}
         backbone = backbone_preset_cls.from_preset(
-            preset,
-            load_weights=load_weights,
-            config_overrides=config_overrides,
+            preset, load_weights=load_weights, **backbone_kwargs
         )
         if "preprocessor" in kwargs:
             preprocessor = kwargs.pop("preprocessor")
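
Note that `Task.from_preset` forwards only `dtype` to the backbone; other backbone constructor overrides still go through `Backbone.from_preset` directly. A hedged sketch of that workflow (building the task from an explicit backbone is an assumed pattern, not part of this diff):

    from keras_nlp.models import Backbone, GPT2CausalLM

    # Override backbone kwargs explicitly, then hand the backbone to the task.
    backbone = Backbone.from_preset(
        "gpt2_base_en", load_weights=False, dtype="bfloat16", dropout=0.5
    )
    causal_lm = GPT2CausalLM(backbone=backbone)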
10 changes: 10 additions & 0 deletions keras_nlp/src/models/task_test.py
@@ -71,6 +71,16 @@ def test_from_preset(self):
         # TODO: Add a classifier task loading test when there is a classifier
         # with new design available on Kaggle.

+    @pytest.mark.large
+    def test_from_preset_with_kwargs(self):
+        # Test `dtype`
+        model = CausalLM.from_preset(
+            "gpt2_base_en", load_weights=False, dtype="bfloat16"
+        )
+        self.assertIsInstance(model, GPT2CausalLM)
+        self.assertEqual(model.dtype_policy.name, "bfloat16")
+        self.assertEqual(model.backbone.dtype_policy.name, "bfloat16")
+
     @pytest.mark.large
     def test_from_preset_errors(self):
         with self.assertRaises(ValueError):
37 changes: 31 additions & 6 deletions keras_nlp/src/utils/preset_utils.py
@@ -561,13 +561,16 @@ def check_format(preset):
     return "keras"


-def load_serialized_object(
-    preset,
-    config_file=CONFIG_FILE,
-    config_overrides={},
-):
+def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
+    kwargs = kwargs or {}
     config = load_config(preset, config_file)
-    config["config"] = {**config["config"], **config_overrides}
+
+    # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
+    # Ensure that `dtype` is properly configured.
+    dtype = kwargs.pop("dtype", None)
+    config = set_dtype_in_config(config, dtype)
+
+    config["config"] = {**config["config"], **kwargs}
     return keras.saving.deserialize_keras_object(config)


@@ -590,3 +593,25 @@ def jax_memory_cleanup(layer):
     for weight in layer.weights:
         if getattr(weight, "_value", None) is not None:
             weight._value.delete()
+
+
+def set_dtype_in_config(config, dtype=None):
+    if dtype is None:
+        return config
+
+    config = config.copy()
+    if "dtype" not in config["config"]:
+        # Forward `dtype` to the config.
+        config["config"]["dtype"] = dtype
+    elif (
+        "dtype" in config["config"]
+        and isinstance(config["config"]["dtype"], dict)
+        and "DTypePolicyMap" in config["config"]["dtype"]["class_name"]
+    ):
+        # If it is `DTypePolicyMap` in `config`, forward `dtype` as its default
+        # policy.
+        policy_map_config = config["config"]["dtype"]["config"]
+        policy_map_config["default_policy"] = dtype
+        for k in policy_map_config["policy_map"].keys():
+            policy_map_config["policy_map"][k]["config"]["source_name"] = dtype
+    return config
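
A minimal sketch of how `set_dtype_in_config` behaves, using toy config dicts shaped like the serialized objects the function above expects. The key names `class_name`, `config`, `default_policy`, `policy_map`, and `source_name` come from the diff; the `QuantizedDTypePolicy` entry, its `mode` field, and the layer name `dense_0` are illustrative assumptions, and real serialized configs carry additional fields.

    from keras_nlp.src.utils.preset_utils import set_dtype_in_config

    # Case 1: no `dtype` in the config, so the override is simply inserted.
    plain = {"class_name": "BertBackbone", "config": {}}
    plain = set_dtype_in_config(plain, "bfloat16")
    assert plain["config"]["dtype"] == "bfloat16"

    # Case 2: a serialized `DTypePolicyMap` (quantized preset), so the override
    # becomes the default policy and the `source_name` of each sub-policy.
    quantized = {
        "class_name": "BertBackbone",
        "config": {
            "dtype": {
                "class_name": "DTypePolicyMap",
                "config": {
                    "default_policy": "float32",
                    "policy_map": {
                        "dense_0": {
                            "class_name": "QuantizedDTypePolicy",
                            "config": {"mode": "int8", "source_name": "float32"},
                        },
                    },
                },
            },
        },
    }
    quantized = set_dtype_in_config(quantized, "bfloat16")
    policy_map_config = quantized["config"]["dtype"]["config"]
    assert policy_map_config["default_policy"] == "bfloat16"
    assert policy_map_config["policy_map"]["dense_0"]["config"]["source_name"] == "bfloat16"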
17 changes: 17 additions & 0 deletions keras_nlp/src/utils/preset_utils_test.py
@@ -23,10 +23,12 @@
 from keras_nlp.src.models import BertBackbone
 from keras_nlp.src.models import BertTokenizer
 from keras_nlp.src.tests.test_case import TestCase
+from keras_nlp.src.utils.keras_utils import has_quantization_support
 from keras_nlp.src.utils.preset_utils import CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import METADATA_FILE
 from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import check_format
+from keras_nlp.src.utils.preset_utils import load_serialized_object


 class PresetUtilsTest(TestCase):
@@ -113,3 +115,18 @@ def test_incorrect_metadata(self):

         with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"):
             check_format(preset_dir)
+
+    @parameterized.named_parameters(
+        ("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False),
+        ("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True),
+    )
+    @pytest.mark.extra_large
+    def test_load_serialized_object(self, preset, dtype, is_quantized):
+        if is_quantized and not has_quantization_support():
+            self.skipTest("This version of Keras doesn't support quantization.")
+
+        model = load_serialized_object(preset, dtype=dtype)
+        if is_quantized:
+            self.assertEqual(model.dtype_policy.name, "map_bfloat16")
+        else:
+            self.assertEqual(model.dtype_policy.name, "bfloat16")