Add support of `kwargs` to `Backbone.from_preset` and fix the `dtype` forwarding in `Task.from_preset` #1742
Conversation
Looks good, I think! Just one small comment.
keras_nlp/src/models/backbone.py (outdated)
```python
backbone = load_serialized_object(preset, CONFIG_FILE)
# Forward `config_overrides` and `dtype`.
config_overrides = {}
if "config_overrides" in kwargs:
```
I believe the `kwargs` were supposed to act as the config overrides directly (though it looks like this broke at some point). So if you wanted to set BERT dropout, for example, you could do:

```python
model = keras_nlp.models.BertBackbone.from_preset("bert_base_en_uncased", dropout=0.5)
```
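For illustration, the override mechanics boil down to merging the extra kwargs over the preset's serialized config before deserialization. A minimal runnable sketch (the helper name and config layout are illustrative assumptions, not the actual keras_nlp internals):

```python
# Sketch of kwargs-as-config-overrides: extra kwargs are merged over the
# serialized config before the object is deserialized. The dict below is a
# stand-in for what a preset's config.json might hold.
def apply_config_overrides(config, **kwargs):
    # Later keys win, so kwargs override the preset's serialized values.
    return {**config, "config": {**config["config"], **kwargs}}


preset_config = {
    "class_name": "BertBackbone",  # illustrative, not a real preset file
    "config": {"vocabulary_size": 30522, "dropout": 0.1},
}
overridden = apply_config_overrides(preset_config, dropout=0.5)
assert overridden["config"]["dropout"] == 0.5
```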
Yeah, it is a bit weird to see `config_overrides`, and I think there is a dangerous default value (`config_overrides={}`) for it in `load_serialized_object`. I can try to fix it.
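For reference, the pitfall being described is Python's shared mutable default argument, and the usual fix is defaulting to `None`. A self-contained sketch (not the actual `load_serialized_object` signature):

```python
# The pitfall: a dict default is created once at function definition time
# and shared across calls, so mutations leak between callers.
def load_bad(preset, config_overrides={}):
    config_overrides["seen"] = preset  # mutates the shared default dict
    return config_overrides


assert load_bad("a") is load_bad("b")  # same dict object both times

# The usual fix: default to None and build a fresh dict per call.
def load_good(preset, config_overrides=None):
    config_overrides = config_overrides or {}
    config_overrides["seen"] = preset
    return config_overrides


assert load_good("a") is not load_good("b")  # fresh dict each call
```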
Force-pushed from 23e761e to 3e56494
Changed the title from "`dtype` argument to `Backbone.from_preset` and fix `dtype` forwarding in `CausalLM.from_preset`" to "`kwargs` to `Backbone.from_preset` and fix the `dtype` forwarding in `Task.from_preset`"
…n Task.from_preset: force-pushed from 3e56494 to 562b9dd
I have updated:

```python
import keras_nlp

llama_lm = keras_nlp.models.CausalLM.from_preset(
    "llama2_instruct_7b_en_int8", load_weights=False, dtype="bfloat16"
)
assert llama_lm.backbone.token_embedding.compute_dtype == "bfloat16"

bert_backbone = keras_nlp.models.BertBackbone.from_preset(
    "bert_base_en_uncased", load_weights=False, dtype="bfloat16", dropout=0.5
)
assert bert_backbone.token_embedding.compute_dtype == "bfloat16"
assert bert_backbone.dropout == 0.5
```
Thanks! This LGTM
Merged: …n Task.from_preset (#1742)
Currently, `CausalLM.from_preset(..., dtype="bfloat16")` has no effect because it doesn't forward `dtype` to the backbone. This PR fixes that issue and also adds `dtype` support to `Backbone.from_preset`. Additionally, I have updated the logic in `load_serialized_object` to support `dtype` when using `DTypePolicyMap`, ensuring that a pre-quantized preset will obey `dtype`.
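To illustrate that last point, here is a minimal sketch of the idea (the helper name and dict shapes are assumptions, not the actual keras_nlp code): when the serialized config already carries a `DTypePolicyMap`, i.e. a pre-quantized preset, the requested `dtype` becomes the map's default policy rather than overwriting the per-layer quantized policies.

```python
# Hypothetical sketch: route a user-supplied dtype into a serialized config.
def set_dtype_in_config(config, dtype=None):
    if dtype is None:
        return config
    inner = config["config"]
    policy = inner.get("dtype")
    if isinstance(policy, dict) and policy.get("class_name") == "DTypePolicyMap":
        # Pre-quantized preset: keep the per-layer policies and only set
        # the map's default policy to the requested dtype.
        policy["config"]["default_policy"] = dtype
    else:
        # Plain preset: forward dtype directly.
        inner["dtype"] = dtype
    return config


plain = {"class_name": "BertBackbone", "config": {}}
set_dtype_in_config(plain, "bfloat16")
assert plain["config"]["dtype"] == "bfloat16"

quantized = {
    "class_name": "BertBackbone",
    "config": {"dtype": {"class_name": "DTypePolicyMap", "config": {"policy_map": {}}}},
}
set_dtype_in_config(quantized, "bfloat16")
assert quantized["config"]["dtype"]["config"]["default_policy"] == "bfloat16"
```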