-
Notifications
You must be signed in to change notification settings - Fork 448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update fusion layer counting logic for Llama 3.2 weight conversion #1722
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1722
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit df8cab3 with merge base 3fddc56 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
num_fusion_layers = ( | ||
max(_layer_num(k) for k in state_dict if "cross_attention_layers" in k) + 1 | ||
num_fusion_layers = len( | ||
set([k.split(".")[2] for k in state_dict if "fusion_layer" in k]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a comment explaining what the FQN looks like here? Why not just count number of "fusion_layer" in k for k in state_dict
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just count number of "fusion_layer" in k for k in state_dict?
There are multiple params for each layer so we need to dedup
related: #1721 |
@@ -148,8 +148,8 @@ def llama3_vision_tune_to_meta( | |||
|
|||
# Calculate fusion_interval: layer interval where cross attention layers are fused | |||
num_layers = max(_layer_num(k) for k in state_dict if "layers" in k) + 1 | |||
num_fusion_layers = ( | |||
max(_layer_num(k) for k in state_dict if "cross_attention_layers" in k) + 1 | |||
num_fusion_layers = len( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the [2] referring to here? Isn't that the layer number?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah exactly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for catching this. Can you just add a comment above the change saying you're getting the unique layer numbers or use the _layer_number function?
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1722 +/- ##
==========================================
- Coverage 70.67% 67.64% -3.03%
==========================================
Files 299 304 +5
Lines 15251 15627 +376
==========================================
- Hits 10778 10571 -207
- Misses 4473 5056 +583
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Checkpoint save errors without this change
Test plan:
Before: https://gist.github.com/ebsmothers/c9ad0175cedeb5ad2719aec4d266090d
After: