autoTP for fused qkv weight #3844
Merged: 12 commits into microsoft:master on Jul 27, 2023

Conversation

inkcherry (Contributor) commented Jun 29, 2023

This PR makes autoTP work with models that use a fused qkv weight.
After TP, the per-rank qkv weight layout should be consistent with the no-TP layout, and the heads should be correctly divided and arranged, so that the model can correctly parse the qkv matrix at inference time.

However, different models fuse qkv in different ways; we have found three layouts so far. For now, the types are temporarily named after the models that use them.

Assume num_heads = n, and let qx denote the weight of the x-th q head; shape = (hidden_dim, hidden_dim).

Bloom (huggingface):
q1, k1, v1, q2, k2, v2,..., qn, kn, vn

GLM or WebGLM (huggingface):

q1, q2, q3, ..., qn, k1, k2, k3, ..., kn, v1, v2, v3, ..., vn

Codegen (huggingface):

q1, q2,..., q(n/4), k1, k2,..., k(n/4), v1, v2,..., v(n/4), q(n/4+1),...,q(n/2), k(n/4+1),...,k(n/2), v(n/4+1),...,v(n/2),...

At first, we tried running GLM inference with autoTP and got incorrect results, but there were no errors or warnings. We think a warning should be added for unrecognized fused qkv layouts so that users can troubleshoot the problem more easily.
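As a rough sketch of the kind of warning meant here (the type labels and helper name are hypothetical, not the PR's actual identifiers):

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical labels for the three layouts described above.
KNOWN_FUSED_TYPES = ("bloomtype", "glmtype", "codegentype")

def warn_if_unknown_fused_type(fused_type):
    # Warn instead of silently applying a generic split that would scramble the heads.
    if fused_type not in KNOWN_FUSED_TYPES:
        logger.warning("Unrecognized fused qkv type '%s'; autoTP may shard this weight incorrectly.", fused_type)
```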

Other models that use fused_qkv can also be made to work by matching one of the existing fused types or adding a new one.
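For illustration, a minimal sketch (assuming PyTorch tensors and a rank index; the function name is hypothetical, not the PR's code) of how a GLM-style fused qkv weight could be sharded so that each rank's layout matches the no-TP layout:

```python
import torch

def split_glm_fused_qkvw(fused_qkvw, mp_size, rank):
    # GLM-style layout: rows are [q1..qn, k1..kn, v1..vn], shape (3 * hidden_dim, hidden_dim).
    q, k, v = torch.split(fused_qkvw, fused_qkvw.shape[0] // 3, dim=0)
    # Take this rank's contiguous slice of heads from each of q, k and v ...
    q_shard = torch.split(q, q.shape[0] // mp_size, dim=0)[rank]
    k_shard = torch.split(k, k.shape[0] // mp_size, dim=0)[rank]
    v_shard = torch.split(v, v.shape[0] // mp_size, dim=0)[rank]
    # ... and re-fuse them so the per-rank layout matches the no-TP layout.
    return torch.cat([q_shard, k_shard, v_shard], dim=0)
```

With mp_size=2, rank 0 would keep q1..q(n/2), k1..k(n/2), v1..v(n/2) stacked in that order, which is exactly the layout a single-rank GLM attention block expects.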

inkcherry (Contributor, Author) commented Jun 30, 2023

We ran some tests:

  • glm 350M, mp=1/2: works
  • codegen 350M, mp=1/2: works
  • codegen 6B, mp=1/2/4: works
  • glm 10B, mp=1/2/4/8: works
  • codegen 16B, mp=1/2/3/6: works

inkcherry (Contributor, Author)

@RezaYazdaniAminabadi @molly-smith
Could you please help review this PR? I would appreciate your feedback.

src_split = torch.split(input, shape[0] // 3, dim=0)

qkv_split = [torch.split(src_s, shape[0] // 3 // mp_size, dim=0) for src_s in src_split]
split_fusedqkv = [
Contributor:

split_fusedqkv (the cat operations) could be made a separate method, since other variations may use the same logic flow. @inkcherry?
Also, the changes look good. @molly-smith @RezaYazdaniAminabadi, could you please review and provide suggestions?
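A possible shape for such a helper, sketched with an assumed name and signature rather than the PR's final code:

```python
import torch

def _cat_fused_qkvw_shards(qkv_split, gpu_index):
    # qkv_split holds the per-rank shards of q, k and v (one tuple of shards per block);
    # pick this rank's shard from each block and fuse them back into one weight.
    return torch.cat([shards[gpu_index] for shards in qkv_split], dim=0)
```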

Contributor:

Yes, I agree. Since we have the same logic for TP when the inference kernels are enabled, we can reuse this.

setattr(child, "replaced", True)
return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)

def require_tp_fused_qkvw(name):
Contributor:

@inkcherry, can we please move all this logic into a separate file, so that we can then refactor the TP-sharding logic based on it? Thanks.
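For context, a minimal sketch of what a name-based check like the one quoted above might look like once moved into its own utility file; the substring list here is illustrative, not the PR's actual contents:

```python
# Illustrative only: treat a parameter as a fused qkv weight if its module name
# contains one of the known fused attention projection names.
FUSED_QKV_NAME_PATTERNS = ["query_key_value", "qkv_proj"]

def require_tp_fused_qkvw(name):
    return any(pattern in name for pattern in FUSED_QKV_NAME_PATTERNS)
```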

RezaYazdaniAminabadi (Contributor)


Thanks @inkcherry for the great PR!
I have gone through the changes and they look good to me. I just have some minor comments.

inkcherry (Contributor, Author)

Thanks for your suggestion :) The file has been modified accordingly.
@RezaYazdaniAminabadi @tjruwase

cc @guoyejun @abhilash1910

molly-smith self-requested a review July 27, 2023 20:12
jeffra enabled auto-merge July 27, 2023 20:16
molly-smith (Contributor) commented Jul 27, 2023

Thanks @inkcherry for this contribution. It would be nice to update the supported/unsupported model lists in docs/_tutorials/automatic-tensor-parallelism.md and maybe add one of these fused qkv weight models to a unit test.

jeffra disabled auto-merge July 27, 2023 20:22
jeffra enabled auto-merge July 27, 2023 20:24
jeffra added this pull request to the merge queue Jul 27, 2023
Merged via the queue into microsoft:master with commit 6b877d2 Jul 27, 2023
inkcherry (Contributor, Author)


Sure, I have addressed this in #4057.

polisettyvarma pushed a commit to polisettyvarma/DeepSpeed that referenced this pull request Aug 7, 2023
* autoTP for fused qkv weight

* fix format

* clean up

* clean up

* clean up

* update

* make logic flow to util and move to file

* fix formatting

* remove empty line

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
andre-bauer:

Bloom won't work with this since it doesn't take mp_size into account.
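For reference, a minimal sketch of how a Bloom-style interleaved layout (q1, k1, v1, ..., qn, kn, vn) could be sharded with mp_size taken into account; num_heads, the shapes and the function name are assumptions here, not DeepSpeed's actual implementation:

```python
import torch

def split_bloom_fused_qkvw(fused_qkvw, num_heads, mp_size, rank):
    # Bloom-style layout: for each head i the rows [q_i, k_i, v_i] are interleaved,
    # so the fused weight has shape (num_heads * 3 * head_dim, hidden_dim).
    hidden_dim = fused_qkvw.shape[1]
    head_dim = hidden_dim // num_heads
    per_head = fused_qkvw.reshape(num_heads, 3 * head_dim, hidden_dim)
    # Give each rank a contiguous block of whole heads, keeping the q/k/v
    # interleaving intact inside the shard.
    heads_per_rank = num_heads // mp_size
    shard = per_head[rank * heads_per_rank:(rank + 1) * heads_per_rank]
    return shard.reshape(-1, hidden_dim)
```

With mp_size=2, rank 0 would keep heads 1..n/2 with their q/k/v rows still interleaved, matching the single-rank layout Bloom's attention expects.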
