Adding GPTNeoXBackbone #1056

Conversation
@mattdangerw I opened this PR so that it's easier for you to review the code! For now I have referred to the Hugging Face implementation of GPTNeoX and put the parts together (rotary embedding, attention), converting them into TensorFlow. Please note:
Very cool! Excited for this. General question: are we better off calling this GPTNeoX or Pythia? We can ship checkpoints for either, but we may want to go with whatever is better known as the general name. The first thing I would suggest doing is extending the colab to actually do some weight conversion. It looks like the original project hosts weights on Hugging Face, so that seems like the place to start. Essentially we would want a colab or a script that can download the original weights, convert them to our backbone model, run some dummy inputs through both, and confirm the outputs are equivalent. Here is an example colab with DistilBERT that @abheesht17 made.
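To make that concrete, the verification half of such a colab might boil down to something like the sketch below. This is only a sketch: convert_checkpoint() is a hypothetical helper that would build the new backbone with a matching config and copy the HF weights into it, EleutherAI/pythia-70m is just an example checkpoint, and the dict input format mirrors other KerasNLP backbones.

# Sketch of the numerics check a conversion colab/script should end with.
import numpy as np
import torch
from transformers import GPTNeoXModel

hf_model = GPTNeoXModel.from_pretrained("EleutherAI/pythia-70m")
hf_model.eval()

# Hypothetical helper: builds a GPTNeoXBackbone with the same config and
# copies every HF weight tensor into it.
keras_model = convert_checkpoint(hf_model)

# Run identical dummy inputs through both implementations.
token_ids = np.random.randint(0, hf_model.config.vocab_size, size=(1, 16))
padding_mask = np.ones((1, 16), dtype="int64")

with torch.no_grad():
    hf_output = hf_model(
        input_ids=torch.tensor(token_ids),
        attention_mask=torch.tensor(padding_mask),
    ).last_hidden_state.numpy()

keras_output = keras_model(
    {"token_ids": token_ids, "padding_mask": padding_mask}
).numpy()

# The two forward passes should agree to within float32 tolerance.
print(np.max(np.abs(hf_output - keras_output)))
np.testing.assert_allclose(hf_output, keras_output, atol=1e-4)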
Thanks for the PR! I took a brief scan over the code, and it's quite complex to tell if it is correct or not. So we need two things here:
- As Matt commented above, please share a colab with weights conversion so that we know the code works properly (produces the same result as HF given the same checkpoint).
- For methods like _compute_attention, we need some comments to illustrate what they are doing, otherwise we will lose track of them very soon. We will take a deeper look at these; let's focus on the weights conversion for now.
Thanks!
Hello from EleutherAI! I’m excited to see this PR in the works. The short answer is that I would recommend calling it the GPT-NeoX architecture. The GPT-NeoX architecture has been used by a variety of models including GPT-J, GPT-NeoX-20B, Pythia, PaLM, and more. If you don’t consider changing the PE to be a meaningful architecture change, then StableLM and (I believe but can’t find documentation of this fact right now) MPT also use the architecture. (I can give a longer answer & history if that’s desired.) We also release weights in the format that the GPT-NeoX library keeps them natively. This format is more convenient for distributed training but not for inference. Several such weight formats can be found on the HuggingFace Hub or linked to from our README. The HuggingFace weights are official releases though, and were produced by EleutherAI. Our library actually supports an “export to HF” script, and if you develop a conversion script for your library we can add an “export to keras” script as well.
@StellaAthena hello EleutherAI! Thanks so much, this is very helpful context. Let's go with GPT-NeoX then. PE differences we would probably consider to be a separate architecture. MPT is ALiBi not RoPE (I think?), so that is probably something we would sort into a separate architecture in our library.
@mattdangerw Here is the working model + checkpoints loader without cached attention. I'm looking into matching the outputs.
Left some miscellaneous comments, but think you are overall on the right track already. First step will be to confirm we have an equivalent forward pass with "upstream" versions, then we can refine the code here more.
rotary_emb_base=10000,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
use_parallel_residual=True,
is this value both true and false for checkpoints we care about? or does one win out?
if the latter, we could consider ditching this
Yes, it is True for all pythia and gpt-neox-20b checkpoints.
1.0 / (self.rotary_emb_base ** (range / self.dim))
)

@staticmethod
why bother marking this a static method and leaving it public? seems like we could just leave this as a private _apply_rotary_pos_emb regular method for now (and let the fact that this does not access anything on self be incidental)
sounds good
this is still marked as a staticmethod, I think there is no need.
A few more comments.
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, rotary_emb_base=10000):
We should try to name this more descriptively than dim. Is this the "hidden dim" of the model? If so let's call this hidden_dim.
@mattdangerw It's not hidden_dim, it is attn_head_size * rotary_pct. How about we rename it to rotary_ndims itself?
def build(self, input_shape):
    super().build(input_shape)
    self.inverse_freq = self.add_weight(
actually looking at this, this inverse_freq should all be static right? if we don't need this trainable, instead of having this be a weight, let's move this into call somewhere, we can just compute it on the fly
Hey @mattdangerw ! We can definitely do that, but I would like you to take a look at this. https://github.com/huggingface/transformers/blob/17a55534f5e5df10ac4804d4270bf6b8cc24998d/src/transformers/models/esm/modeling_tf_esm.py#L102-L107
I'm not sure I follow the comment totally. The issue looks to be a precision one, but if the goal is to keep these explicitly as float32, why not just compute them on the fly with an explicit float32 dtype? I still don't understand the need for a variable. And the fact that this is a trainable seems incorrect looking at the torch implementation, these are not trainable in torch.
In general, I would be careful attempting to apply what seems like a fairly technical point about esm checkpoints to other models. Ideally we would just check how close our forward pass outputs are for the actual pythia checkpoints under full precision (float32 everywhere) and mixed precision (float32 for variables, float16 for computations), and use that to determine our approach here.
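To illustrate the "compute on the fly" option (a sketch only, reusing the rotary_ndims / max_wavelength names discussed elsewhere in this review, not the PR's final code):

import tensorflow as tf

def compute_inverse_freq(rotary_ndims, max_wavelength=10000):
    # Recomputed on each call, explicitly in float32, so no variable is
    # stored on the layer and a mixed-precision policy cannot downcast it.
    freq_range = tf.range(0, rotary_ndims, 2, dtype="float32")
    return 1.0 / (max_wavelength ** (freq_range / rotary_ndims))

The caller would then take the sin/cos of the position-frequency outer product exactly as before, with no trainable state involved.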
sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]
x1, x2 = tf.split(tensor, 2, axis=-1)
half_rot_tensor = tf.concat((-x2, x1), axis=-1)
# Incompatible shapes: [32,256,8,2] vs. [1,256,1,16] [Op:Mul]
remember to clean up little notes like this
done!
max_sequence_length=512,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
rotary_pct=0.25,
let's consider better names for these two arguments. probably rotary_percentage is more consistent with Keras' style, and rotary_emb_base is a little confusing, are there better names we could consider from the paper or elsewhere?
I guess one option is to document this the same as max_wavelength in our SinePositionEncoding layer. https://keras.io/api/keras_nlp/modeling_layers/sine_position_encoding/
I'm not sure it's the best name, but at least it will be consistent across the library. We could name these arguments rotary_percentage and rotary_max_wavelength here, and just percentage and max_wavelength on the rotary layer itself.
fixed this!
]
value = query_key_value[..., 2 * self.attn_head_size :]

query_rot, query_pass = (
I wonder if we would be better off moving this slice and concat logic into the RotaryEmbedding call. Then our usage here could look a little more like...
query = self.rotary_embedding(query)
key = self.rotary_embedding(key)
And the rotary embedding layer could also hold the percentage argument, which would conceptually be quite clean. Looks like falcon is doing this roughly -> https://huggingface.co/tiiuae/falcon-40b/blob/main/modelling_RW.py
Thanks for this wonderful suggestion!
Putting back old conversion script unless we are done with presets.
Thanks! Will take a look today. What is the deal with deduped vs not deduped by the way? Deduped training data sounds better to me, is there any reason to not just ignore the "non-deduped" checkpoints?
Looking good! Left a few more comments.
restored_output = restored_model(self.input_batch)
self.assertAllClose(model_output, restored_output)

# def test_create_layout_map(self):
just remove this for now, we can add/review the code in a follow up
removed!
It is still here on this diff, maybe you forgot to remove?
rotary_max_wavelength=10000,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
use_parallel_residual=True,
I think we decided this was always true for now right? let's just remove the related code, add back if we need it
sg!
# Infer the dimension of our hidden feature size from the build shape.
hidden_dim = input_shape[-1]

self._input_layernorm = keras.layers.LayerNormalization(
these layernorms are confusingly named. from what i can tell, _input_layernorm is the layernorm for self attention (applied first), and _self_attention_layernorm is the layernorm for the feedforward block (applied first).
I would rename these: _input_layernorm -> _self_attention_layernorm and _self_attention_layernorm -> _feedforward_layernorm.
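For reference, a stripped-down sketch of the parallel-residual block structure this naming maps onto (illustrative callables only, not the PR's actual sub-layers):

def parallel_residual_block(
    x, self_attention, feedforward, self_attention_layernorm, feedforward_layernorm
):
    # The self-attention layernorm normalizes the input to self-attention, and
    # the feedforward layernorm normalizes the input to the MLP. With the
    # parallel residual, both branches read the same x and their outputs are
    # summed into a single residual update.
    attention_output = self_attention(self_attention_layernorm(x))
    feedforward_output = feedforward(feedforward_layernorm(x))
    return x + attention_output + feedforward_output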
sg!
Also, still unresolved.
output = self.backbone.token_embedding(self.input_batch["token_ids"])
self.assertEqual(output.shape, (2, 5, 64))

# def test_name(self):
The _1 would be from creating two backbones with the default name in the same session. I am able to get this test passing simply by uncommenting it and changing "gpt_neox_backbone" to "gpt_neo_x_backbone".
Which actually brings up a good point. We should actually rename the directory and all files from gpt_neox_... to gpt_neo_x_... to match the way Keras automatically converts CamelCase to snake_case.
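For context, the automatic name generation is easy to check with a throwaway subclass (illustration only; the real backbone is of course not an empty layer):

from tensorflow import keras

class GPTNeoXBackbone(keras.layers.Layer):
    pass

# Keras converts the CamelCase class name to snake_case for the default
# name, and appends a counter for repeated instances in the same session.
print(GPTNeoXBackbone().name)  # "gpt_neo_x_backbone"
print(GPTNeoXBackbone().name)  # "gpt_neo_x_backbone_1"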
Hey! Started a review on this, but I think what happened is your changes might be on #1085 instead of here. Can you move your related changes onto this PR?
output = self.backbone.token_embedding(self.input_batch["token_ids"])
self.assertEqual(output.shape, (2, 5, 64))

# def test_name(self):
I think this is still unresolved. We should rename this directory and all files to gpt_neo_x... and then enable this test. It should pass at that point.
Moved tokenizer from #1085 to here. Resolved comments related to that PR.
/gcbrun
Yeah, I rebased this branch with
Looking good! Mostly minor stuff now!
self.rotary_percentage = rotary_percentage
self.dropout = dropout
self.attn_head_size = hidden_dim // num_heads
self.rotary_ndims = int(self.attn_head_size * rotary_percentage)
It seems to me like you can move this line down into the rotary layer itself, you can get at attn_head_size simply by reading the shape of the passed query and value, right?
I would pass percentage and max_wavelength directly as arguments to RotaryEmbedding, and keep all the logic there, that will keep things more compartmentalized.
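Roughly, the refactored layer could look like the sketch below, with the slicing, rotation, and concat handled entirely inside call. This is only a sketch under the naming assumptions above (percentage, max_wavelength), and it assumes inputs shaped [batch, seq_len, num_heads, attn_head_size] with an even rotary size.

import tensorflow as tf
from tensorflow import keras

class RotaryEmbedding(keras.layers.Layer):
    # Sketch: rotates the first `percentage` of each head's dims, passes the rest through.
    def __init__(self, percentage=0.25, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.percentage = percentage
        self.max_wavelength = max_wavelength

    def call(self, tensor):
        # tensor: [batch, seq_len, num_heads, attn_head_size]
        attn_head_size = tensor.shape[-1]
        rotary_ndims = int(attn_head_size * self.percentage)
        seq_len = tf.shape(tensor)[1]

        # Split into the slice that gets rotated and the pass-through slice.
        rotary = tensor[..., :rotary_ndims]
        passthrough = tensor[..., rotary_ndims:]

        # Inverse frequencies, computed on the fly in float32.
        freq_range = tf.range(0, rotary_ndims, 2, dtype="float32")
        inverse_freq = 1.0 / (self.max_wavelength ** (freq_range / rotary_ndims))

        # Outer product of positions and frequencies -> [seq_len, rotary_ndims].
        positions = tf.range(seq_len, dtype="float32")
        freqs = tf.einsum("i,j->ij", positions, inverse_freq)
        embedding = tf.concat([freqs, freqs], axis=-1)[None, :, None, :]
        cos_emb = tf.cast(tf.cos(embedding), tensor.dtype)
        sin_emb = tf.cast(tf.sin(embedding), tensor.dtype)

        # rotate_half trick: (x1, x2) -> (-x2, x1).
        x1, x2 = tf.split(rotary, 2, axis=-1)
        half_rotated = tf.concat([-x2, x1], axis=-1)
        rotated = rotary * cos_emb + half_rotated * sin_emb

        return tf.concat([rotated, passthrough], axis=-1)

The attention layer would then just call query = self.rotary_embedding(query) and key = self.rotary_embedding(key) as suggested above.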
outputs match after this refactor :)
keras_model.get_layer("layer_norm").beta.assign(hf_wts["final_layer_norm.bias"])

hf_tokenizer = AutoTokenizer.from_pretrained(PRESET)
we should update this section to check tokenizer output for some simple input as well, now that we have added the tokenizer here.
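For instance, a parity check along these lines would do (a sketch: keras_tokenizer stands in for the GPTNeoXTokenizer built from the same vocabulary and merges as the HF checkpoint, and the checkpoint name is just an example):

from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")

sample = "The quick brown fox jumped over the lazy dog."
hf_ids = hf_tokenizer(sample)["input_ids"]
# keras_tokenizer is the KerasNLP tokenizer under test (hypothetical name here).
keras_ids = keras_tokenizer(sample).numpy().tolist()

print(hf_ids)
print(keras_ids)
assert hf_ids == keras_ids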
Added this part; not using tokenizer input for the model as we still don't have a preprocessor. Outputs of the tokenizer are the same as the HF tokenizer. By the way, they are using the gpt-neox-20b vocabulary for the Pythia suite.
/gcbrun
This looks great! Pushing some minor style fixes as I land. For future PRs:
- Make sure to follow the 80 character line limit for docstrings, besides links.
- Do document argument types, but do not document default values unless they are complex in some way. Simple defaults will already render in the keras.io signatures, e.g. https://keras.io/api/keras_nlp/models/bert/bert_backbone/.
Once tests are green I will pull this in!
/gcbrun
Sounds good @mattdangerw! All tests pass.
Partially completes #1052. Pythia uses the GPT-NeoX architecture.
This PR adds implementations of:
- RotaryEmbedding (paper)
- GPTNeoXAttention
- GPTNeoXDecoder layer
- GPTNeoXBackbone
Colab