autoTP for fused qkv weight #3844
Conversation
We conducted some tests.
@RezaYazdaniAminabadi @molly-smith
src_split = torch.split(input, shape[0] // 3, dim=0)
qkv_split = [torch.split(src_s, shape[0] // 3 // mp_size, dim=0) for src_s in src_split]
split_fusedqkv = [
split_fusedqkv (the cat operations) could be made a separate method, since there may be other variations using the same logic flow. @inkcherry?
Also, the changes look good. @molly-smith @RezaYazdaniAminabadi, could you please review and provide suggestions?
Yes, I agree. Since we use the same logic for TP when the inference kernels are enabled, we can reuse this there as well.
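A minimal sketch of what such a reusable helper could look like, based on the split/cat logic in the diff above; the function name and signature are hypothetical, not the merged DeepSpeed API:

```python
# Hypothetical helper, illustrating the split + cat flow from the diff above.
import torch


def shard_glm_type_fused_qkv(fused_qkv: torch.Tensor, mp_size: int, rank: int) -> torch.Tensor:
    """Shard a GLM-style fused qkv weight (q1..qn, k1..kn, v1..vn along dim 0)
    so that the per-rank layout still reads as q-block, k-block, v-block."""
    rows = fused_qkv.shape[0]
    # Split the fused weight into its q, k and v blocks.
    src_split = torch.split(fused_qkv, rows // 3, dim=0)
    # Split each block across tensor-parallel ranks.
    qkv_split = [torch.split(src, rows // 3 // mp_size, dim=0) for src in src_split]
    # Re-fuse this rank's q, k and v slices (the "cat operations" mentioned above).
    return torch.cat([slices[rank] for slices in qkv_split], dim=0)
```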
setattr(child, "replaced", True)
return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)

def require_tp_fused_qkvw(name):
@inkcherry, can we please move all this logic into a separate file, so that we can then refactor the tp-sharding logic based on it? Thanks.
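For context, a minimal sketch of what a name-based check like require_tp_fused_qkvw (from the diff above) might look like once moved into such a utility file; the pattern list below is an assumption, not the actual DeepSpeed list:

```python
# Assumed sketch; the real check and its recognized module names may differ.
def require_tp_fused_qkvw(name: str) -> bool:
    """Return True if this module/parameter name looks like a fused qkv weight
    that needs layout-aware tensor-parallel sharding."""
    fused_qkvw_name_patterns = ['qkv_proj', 'query_key_value']  # hypothetical patterns
    return any(pattern in name for pattern in fused_qkvw_name_patterns)
```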
Thanks @inkcherry for the great PR!
Thanks for your suggestion :) The file has been modified.
Thanks @inkcherry for this contribution. It would be nice to update the supported/unsupported model lists in docs/_tutorials/automatic-tensor-parallelism.md and maybe add one of these fused qkv weight models to a unit test.
Sure, I placed it in this PR.
* autoTP for fused qkv weight
* fix format
* clean up
* clean up
* clean up
* update
* make logic flow to util and move to file
* fix formatting
* remove empty line

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Bloom won't work with this since it's not taking mp_size into account.
This PR makes autoTP work with some fused-qkv models.
After TP, the per-rank qkv weight layout should be consistent with the no-TP layout, and the heads should be correctly divided and arranged, so that the model can correctly parse the qkv matrix during inference.
However, models fuse qkv in different ways, and we have found three types so far. For now, the types are temporarily named after the models they appear in.
Assuming num_heads = n, q_x represents the weight of the x-th q head, shape = (hidden_dim, hidden_dim).
Bloom (huggingface):
q1, k1, v1, q2, k2, v2,..., qn, kn, vn
GLM or WebGLM (huggingface):
q1, q2, q3,.. qn, k1, k2, k3,...,kn, v1, v2, v3,..., vn
Codegen (huggingface):
q1, q2,..., q(n/4), k1, k2,..., k(n/4), v1, v2,..., v(n/4), q(n/4+1),...,q(n/2), k(n/4+1),...,k(n/2), v(n/4+1),...,v(n/2),...
At first, we attempted autoTP inference with GLM and got incorrect results, with no errors or warnings. We think adding a warning for unrecognized fused qkv types would help users troubleshoot this kind of problem.
Other models that use fused qkv can also work by matching an existing fused type or adding a new one.
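To make the three layouts concrete, here is an illustrative sketch of layout-aware sharding with a warning for unrecognized fused types; the function name, type names, and fallback behavior are assumptions for illustration, not the exact DeepSpeed implementation:

```python
# Illustrative sketch only; names and fallback behavior are hypothetical.
import warnings

import torch


def shard_fused_qkv(weight: torch.Tensor, fused_type: str, mp_size: int, rank: int) -> torch.Tensor:
    rows = weight.shape[0]
    if fused_type == "glmtype":
        # q1..qn | k1..kn | v1..vn: shard each of the three blocks, then re-fuse,
        # as in the helper sketched earlier in this thread.
        q, k, v = torch.split(weight, rows // 3, dim=0)
        shards = [torch.split(t, t.shape[0] // mp_size, dim=0)[rank] for t in (q, k, v)]
        return torch.cat(shards, dim=0)
    if fused_type == "bloomtype":
        # q1,k1,v1,q2,k2,v2,...: heads are already interleaved, so each rank takes a
        # contiguous chunk of heads -- the split must use mp_size, not a fixed 3.
        return torch.split(weight, rows // mp_size, dim=0)[rank]
    if fused_type == "codegentype":
        # q1..q(n/4), k1..k(n/4), v1..v(n/4), q(n/4+1), ...: each quarter block needs
        # its own q/k/v split; omitted here to keep the sketch short.
        raise NotImplementedError("codegen-type sharding needs per-block handling; see the PR")
    # Unrecognized layouts would otherwise fail silently with wrong results, so warn.
    warnings.warn(f"Unrecognized fused qkv type '{fused_type}'; naive sharding may give "
                  "incorrect inference results.")
    return torch.split(weight, rows // mp_size, dim=0)[rank]
```

The common requirement across all branches is the one stated above: the per-rank weight must still parse as a valid fused qkv matrix for the model's own attention code, as if no TP were applied.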