forked from microsoft/DeepSpeed
autoTP for fused qkv weight (microsoft#3844)
* autoTP for fused qkv weight
* fix format
* clean up
* clean up
* clean up
* update
* make logic flow to util and move to file
* fix formatting
* remove empty line

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
1 parent aebd03d · commit b671585
Showing 3 changed files with 113 additions and 6 deletions.
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
import re

def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
    # Split each fused chunk into per-rank pieces, then re-fuse piece i from
    # every chunk to form the fused qkv shard for tensor-parallel rank i.
    qkv_split_list = [torch.split(mat, split_size, dim=split_dim) for mat in qkv_list]
    tp_fusedqkv_list = [
        torch.cat([qkv_s[i] for qkv_s in qkv_split_list], dim=cat_dim) for i in range(len(qkv_split_list[0]))
    ]
    return tp_fusedqkv_list

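# Illustrative example (shapes are hypothetical, not part of the commit): with
# qkv_list = [q, k, v], each of shape [4, H], and split_size=2, the call
# returns two [6, H] tensors; tensor i is [q_i; k_i; v_i], i.e. the re-fused
# qkv shard for rank i.
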
def require_tp_fused_qkvw(name, mp_size):
    fused_qkvw_name_list = ['qkv_proj', 'query_key_value']

    if mp_size == 1:
        return False
    for fused_name in fused_qkvw_name_list:
        if fused_name in name:
            return True
    return False

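# Illustrative (the layer name below is hypothetical):
# require_tp_fused_qkvw('h.0.attn.qkv_proj', 2) -> True, the fused weight
# must be re-ordered before slicing across 2 ranks;
# require_tp_fused_qkvw('h.0.attn.qkv_proj', 1) -> False, no TP split needed.
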
def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
    if src is None:
        return
    fused_type_dict = {
        'CodeGenBlock': 'codegentype',
        'BloomBlock': 'bloomtype',
        'GLMBlock': 'glmtype',
    }

    def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
        # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
        # TODO: assert num_heads % (mp_size*codegen_mp_num) == 0
        # input: [3*hidden_dim, hidden_dim] (weight) or [3*hidden_dim] (bias)

        shape = input.shape
        dst_shape = shape[0] // mp_size
        num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

        # num_mp_blocks: [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :]
        src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
        src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

        split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size, 0, 1)
        tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

        return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

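    # Illustrative shape walk-through (hidden_dim=8, mp_size=2; the values are
    # hypothetical): input [24, 8] -> num_mp_blocks [4, 6, 8] -> src_split is
    # three [4, 2, 8] tensors, each reshaped to [8, 1, 8]; re-fusing with
    # split_size=4 along dim 0 and cat along dim 1 yields two [4, 3, 8] blocks,
    # which are concatenated and flattened back to [24, 8] before the slice.
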
    def _glm_type_transpose(input, mp_size):
        # input: [3*hidden_dim, hidden_dim] (weight) or [3*hidden_dim] (bias)

        shape = input.shape
        dst_shape = shape[0] // mp_size
        src_split = torch.split(input, shape[0] // 3, dim=0)

        split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size)
        tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0)

        return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

    def _bloom_type_transpose(input, mp_size):
        # bloom's fused qkv keeps each head's q, k, v adjacent, so a
        # contiguous per-rank slice already keeps whole heads together and
        # no reordering is needed here.
        return input

    def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):

        # suppose num_heads=n and q(n)_w means the n-th q head linear weight; the weight formats are as follows
        # bloomtype: [q(1)_w,k(1)_w,v(1)_w,q(2)_w,k(2)_w,v(2)_w,...,q(n)_w,k(n)_w,v(n)_w]
        # glmtype: [q(1)_w,q(2)_w,...,q(n)_w,k(1)_w,k(2)_w,...,k(n)_w,v(1)_w,v(2)_w,...,v(n)_w]
        # codegentype: [q(1)_w,q(2)_w,...,q(n/t)_w,k(1)_w,k(2)_w,...,k(n/t)_w,v(1)_w,v(2)_w,...,v(n/t)_w,q(n/t+1)_w,...], where t is a constant defined in the model file.

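        # Illustrative (n=2 heads, mp_size=2; values are hypothetical): a
        # glmtype weight [q1,q2,k1,k2,v1,v2] becomes [q1,k1,v1] on rank 0 and
        # [q2,k2,v2] on rank 1 after _glm_type_transpose runs.
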
        if fused_qkv_type == 'bloomtype':
            return _bloom_type_transpose(src, mp_size)
        elif fused_qkv_type == 'codegentype':
            return _codegen_type_transpose(src, mp_size)
        elif fused_qkv_type == 'glmtype':
            return _glm_type_transpose(src, mp_size)

        raise ValueError("unknown fused_qkv_type")

    for module_name, fused_type in fused_type_dict.items():
        if re.search(module_name, module_str):
            return _transpose_fused_qkvw(src, mp_size, fused_type)
    warning_once(f"Unrecognized fused qkv weight type, defaulting to bloom type; "
                 f"please check prepare_tp_fused_qkvw() to avoid potential calculation errors")
    return src
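
A minimal usage sketch of the glm-type path (the toy hidden_dim, mp_size, and direct call below are illustrative, not part of the commit): each rank ends up with its contiguous rows of q, k, and v, re-fused in qkv order.

import torch

hidden_dim, mp_size = 4, 2
q = torch.arange(hidden_dim * hidden_dim, dtype=torch.float32).reshape(hidden_dim, hidden_dim)
k, v = q + 100, q + 200
fused = torch.cat([q, k, v], dim=0)  # glm-style layout: [q; k; v] along dim 0

rows = hidden_dim // mp_size
for gpu_index in range(mp_size):
    shard = prepare_tp_fused_qkvw('GLMBlock', fused, mp_size, gpu_index)
    # Rank gpu_index should hold its rows of q, k and v, re-fused in qkv order.
    expected = torch.cat([t[gpu_index * rows:(gpu_index + 1) * rows] for t in (q, k, v)], dim=0)
    assert torch.equal(shard, expected)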