Fixed the llama model #769
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/769
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 64480e7 with merge base 05224a9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is relevant only to the training path that @gau-nernst added recently. He would be the one to ask, I'll add him as a reviewer.
Sorry for the delay! From what I see, this PR fixes the situation when the model is used for inference:

```python
model = Transformer.from_name(...)
model.setup_caches(training=False)
model(input_ids)
```

However, it seems like the model is not used this way anywhere in torchao? Or do you have something else planned in mind? Otherwise, the change looks fine to me.

Note: from what I know, here are the 2 places where this Llama model is used
Hi @gau-nernst, thanks for providing such detailed background. I'm using this Llama model for auto-round. The scenario is similar to the training case you mentioned above (not using KV cache and no `input_pos`). If the `input_pos` is `None`, the `freq_cis` gets overridden here:
ao/torchao/_models/llama/model.py Lines 198 to 200 in e2dad4a
Given your use case, which also does not use KV cache, I don't think having … I agree that having tests is good to prevent regression. Since you are already working on this PR, can you also add a short docstring/comment in Llama's forward about how to use this model? i.e. for inference (w/ KV-cache + input_pos) and training (no KV-cache + no input_pos).
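For illustration only: a minimal sketch of the kind of usage note being requested here, based on the two modes described in this thread. The wording and the stub class are mine, not the PR's; they only restate the inference (KV cache + `input_pos`) and training (no KV cache, no `input_pos`) conventions discussed above.

```python
from typing import Optional

from torch import Tensor


class TransformerStub:  # illustrative stub, not torchao's actual class
    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """Sketch of the requested usage note (illustrative wording).

        Inference (KV cache):
            model.setup_caches(max_batch_size, max_seq_length)  # training=False
            logits = model(input_ids, input_pos)  # input_pos indexes the KV cache

        Training / no KV cache:
            model.setup_caches(max_batch_size, max_seq_length, training=True)
            logits = model(input_ids)  # no input_pos
        """
        raise NotImplementedError
```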
test/test_ao_models.py
Outdated
```python
random_model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len)
out = random_model(input_ids, input_pos)
```
This will test the case "KV cache + input_pos". If this is not required, perhaps we can test for inference and training separately? i.e.
```python
# inference case
random_model.setup_caches(training=False)
random_model(input_ids, input_pos)  # input_pos is not None

# training case
random_model.setup_caches(training=True)
random_model(input_ids)  # no input_pos
```
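A self-contained sketch of what such a split test could look like in `test/test_ao_models.py`, assuming `Transformer.from_name` and the `setup_caches(..., training=...)` signature used elsewhere in this thread. The model name `stories15M`, the vocabulary bound, and the tensor shapes are illustrative assumptions, not taken from the PR:

```python
import torch
from torchao._models.llama.model import Transformer


def _random_llama(name="stories15M"):
    # "stories15M" is an assumed small config name; any registered name works
    return Transformer.from_name(name).eval()


def test_llama_inference_path():
    batch_size, seq_len = 2, 16
    model = _random_llama()
    model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=False)
    input_ids = torch.randint(0, 1024, (batch_size, seq_len))
    input_pos = torch.arange(seq_len)
    out = model(input_ids, input_pos)  # KV cache + input_pos
    assert out is not None


def test_llama_training_path():
    batch_size, seq_len = 2, 16
    model = _random_llama()
    model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=True)
    input_ids = torch.randint(0, 1024, (batch_size, seq_len))
    out = model(input_ids)  # no input_pos, no KV cache
    assert out is not None
```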
Yeah, but L208 updates the `freq_cis`:

ao/torchao/_models/llama/model.py Lines 198 to 208 in e2dad4a
torchao/_models/llama/model.py
Outdated
```
@@ -197,15 +197,17 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:

        if input_pos is None:
            mask = None
            input_pos = torch.arange(0, idx.shape[1], device=idx.device)
```
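For context, a rough sketch of how the `forward()` preamble reads around this hunk after the change. The `else` branch and the `self.causal_mask` / `self.freqs_cis` buffer names follow the gpt-fast-style model this file derives from; they are assumptions here, not the verbatim diff:

```python
# Inside Transformer.forward(idx, input_pos): sketch, not the exact PR diff
if input_pos is None:
    # training / no-KV-cache path: no sliced attention mask;
    # positions default to the full sequence so the later indexing still works
    mask = None
    input_pos = torch.arange(0, idx.shape[1], device=idx.device)
else:
    # inference path: slice the cached causal mask at the requested positions
    mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]  # assumed form of the later line (L208) discussed in this thread
```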
Thank you for your reply. It makes sense to separate the logic for creating `mask`, `input_pos`, `freq_cis` like you did here. However, creating `input_pos` here (for the training case) is now not needed?
I've also tested the inference case where `input_pos` is `None`. Is `input_pos` a required argument for `model.forward` in inference mode?
> inference case where input_pos is None

This is what I mentioned earlier. I think right now there is no code that uses inference w/o `input_pos`, and your auto-round PR doesn't seem to need it either. I think it's fine to support this case (though we will have a tiny inefficiency: creating an unneeded `input_pos` during training, which is probably insignificant). Perhaps others can have other comments.

Can you update the PR description to describe the problem this PR fixes more clearly? i.e. when `input_pos` is None (during training), `freq_cis` is overridden by line xxx. I will approve the PR once you add tests for training mode (`model.setup_caches(training=True)`) and add a short docstring/comment in `model.forward()`.
Got it, thanks for the clarification! I updated the PR description, UTs and docstring. Please review it again.
* fixed input_pos is None
* add test
* update the test
* update the docstring
* update the docstring

Signed-off-by: yiliu30 <[email protected]>
In training mode (`model.setup_caches(..., training=True)`), when `input_pos` is `None`, the `freq_cis` is overridden by L208.

ao/torchao/_models/llama/model.py Lines 198 to 208 in e2dad4a

This PR attempts to fix this issue and adds some tests for the Llama model.
Test