autoTP for fused qkv weight (microsoft#3844)
* autoTP for fused qkv weight

* fix format

* clean up

* clean up

* clean up

* update

* make logic flow to util and move to file

* fix formatting

* remove empty line

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
4 people authored and polisettyvarma committed Aug 7, 2023
1 parent aebd03d commit b671585
Showing 3 changed files with 113 additions and 6 deletions.
2 changes: 1 addition & 1 deletion deepspeed/module_inject/auto_tp.py
@@ -32,7 +32,7 @@ def get_module_list(model):
        return mlist

    def supported(model):
-        unsupported = ['codegen', 'deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
+        unsupported = ['deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
        model = str(model)
        key = re.search(r": (.*?)Model", model)
        if key is None:
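For reference (not part of this commit), here is a minimal sketch of how the supported() gate above behaves now that 'codegen' is no longer listed. The lowercasing and membership test after the regex is an assumption based on how the unsupported list is used; the rest of supported() is truncated in this diff, and looks_supported() below is a hypothetical stand-in, not the DeepSpeed API.

import re

# Illustrative sketch only; mirrors the regex gate shown in the hunk above.
unsupported = ['deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']

def looks_supported(model_repr):
    # str(model) of a HF model contains submodule lines such as "(transformer): CodeGenModel("
    key = re.search(r": (.*?)Model", model_repr)
    if key is None:
        return True  # assumption: treat unidentified model families as supported
    return key.group(1).lower() not in unsupported

print(looks_supported("(transformer): CodeGenModel("))  # True -- codegen now passes the autoTP check
print(looks_supported("(transformer): GPT2Model("))     # False -- still on the unsupported list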
92 changes: 92 additions & 0 deletions deepspeed/module_inject/fusedqkv_utils.py
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
import re


def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
    qkv_split_list = [torch.split(mat, split_size, dim=split_dim) for mat in qkv_list]
    tp_fusedqkv_list = [
        torch.cat([qkv_s[i] for qkv_s in qkv_split_list], dim=cat_dim) for i in range(len(qkv_split_list[0]))
    ]
    return tp_fusedqkv_list


def require_tp_fused_qkvw(name, mp_size):
    fused_qkvw_name_list = ['qkv_proj', 'query_key_value']

    if mp_size == 1:
        return False
    for fused_name in fused_qkvw_name_list:
        if fused_name in name:
            return True
    return False


def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
    if src is None:
        return
    fused_type_dict = {
        'CodeGenBlock': 'codegentype',
        'BloomBlock': 'bloomtype',
        'GLMBlock': 'glmtype',
    }

    def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
        # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
        # TODO: assert num_heads % (mp_size*codegen_mp_num) == 0
        # input : [3*hidden_dim, hidden_dim] (weight) or [3*hidden_dim] (bias)

        shape = input.shape
        dst_shape = shape[0] // mp_size
        num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

        # num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :]
        src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
        src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

        split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size, 0, 1)
        tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

        return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

    def _glm_type_transpose(input, mp_size):
        # input : [3*hidden_dim, hidden_dim] (weight) or [3*hidden_dim] (bias)

        shape = input.shape
        dst_shape = shape[0] // mp_size
        src_split = torch.split(input, shape[0] // 3, dim=0)

        split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size)
        tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0)

        return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

    def _bloom_type_transpose(input, mp_size):
        return input

    def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):

        # suppose num_heads=n and q(i)_w is the weight of the i-th q head; the fused weight layouts are:
        # bloomtype:   [q(1)_w, k(1)_w, v(1)_w, q(2)_w, k(2)_w, v(2)_w, ..., q(n)_w, k(n)_w, v(n)_w]
        # glmtype:     [q(1)_w, q(2)_w, ..., q(n)_w, k(1)_w, k(2)_w, ..., k(n)_w, v(1)_w, v(2)_w, ..., v(n)_w]
        # codegentype: [q(1)_w, q(2)_w, ..., q(n/t)_w, k(1)_w, k(2)_w, ..., k(n/t)_w, v(1)_w, v(2)_w, ..., v(n/t)_w, q(n/t+1)_w, ...], where t is a constant defined in the model file.

        if fused_qkv_type == 'bloomtype':
            return _bloom_type_transpose(src, mp_size)
        elif fused_qkv_type == 'codegentype':
            return _codegen_type_transpose(src, mp_size)
        elif fused_qkv_type == 'glmtype':
            return _glm_type_transpose(src, mp_size)

        raise ValueError("unknown fused_qkv_type")

    for module_name, fused_type in fused_type_dict.items():
        if re.search(module_name, module_str):
            return _transpose_fused_qkvw(src, mp_size, fused_type)
    warning_once(f"Unrecognized fusedqkv weight type, defaulting to bloom type; "
                 f"please check prepare_tp_fused_qkvw() to avoid potential calculation errors")
    return src
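To make the glmtype layout above concrete, here is a small torch-only sketch (not part of the commit) of how a fused [3*hidden_dim, hidden_dim] QKV weight is resharded per rank, mirroring what _glm_type_transpose() does via split_by_qkvlist_and_refuse(). The toy sizes (hidden_dim=8, mp_size=2) are arbitrary.

import torch

hidden_dim, mp_size = 8, 2  # toy sizes for illustration
# glmtype fused weight: all q rows, then all k rows, then all v rows, stacked on dim 0
fused = torch.arange(3 * hidden_dim * hidden_dim, dtype=torch.float32).reshape(3 * hidden_dim, hidden_dim)

q, k, v = torch.split(fused, hidden_dim, dim=0)
per_rank = hidden_dim // mp_size

for rank in range(mp_size):
    # each rank re-fuses its own q/k/v row blocks, as split_by_qkvlist_and_refuse() does
    shard = torch.cat([
        q[rank * per_rank:(rank + 1) * per_rank],
        k[rank * per_rank:(rank + 1) * per_rank],
        v[rank * per_rank:(rank + 1) * per_rank],
    ], dim=0)
    assert shard.shape == (3 * hidden_dim // mp_size, hidden_dim)
    # with DeepSpeed at or after this commit installed, the same shard should come from:
    #   prepare_tp_fused_qkvw('GLMBlock', fused, mp_size, rank)
    print(f"rank {rank}: shard shape {tuple(shard.shape)}")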
25 changes: 20 additions & 5 deletions deepspeed/module_inject/replace_module.py
@@ -15,16 +15,15 @@
from deepspeed.accelerator import get_accelerator
from .replace_policy import HFGPT2LayerPolicy
from .replace_policy import replace_policies, generic_policies

from deepspeed import comm as dist
from torch import nn
+from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw

from .layers import LinearAllreduce, LinearLayer
from .load_checkpoint import load_model_with_checkpoint
import time

from .utils import policy_to_ds_container

import gc


@@ -396,6 +395,7 @@ def _replace(child, name, conv_linear_layer):
                                     dtype=child.weight.dtype)
            if conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
+
            data = mp_replace.copy(new_weight, child.weight.data)
            new_bias = torch.empty((weight_shape[0]), device=child.weight.device, dtype=child.weight.dtype)
            if child.bias is not None:
@@ -412,13 +412,28 @@ def _replace(child, name, conv_linear_layer):
                                     dtype=child.weight.dtype)
            if conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
-            data = mp_replace.copy(new_weight, child.weight.data)

            new_bias = torch.empty((weight_shape[0] // mp_size),
                                   device=child.weight.device,
                                   dtype=child.weight.dtype)
-            bias_data = None if child.bias is None else mp_replace.copy(new_bias, child.bias.data).to(
-                get_accelerator().current_device_name())

+            if require_tp_fused_qkvw(name, mp_size):
+                #for detecting fused type
+                module_str = str(module).strip()
+                #The copy is a regular copy; the shapes of dst and src are the same
+                data = mp_replace.copy(
+                    new_weight, prepare_tp_fused_qkvw(module_str, child.weight.data, mp_size,
+                                                      mp_replace.gpu_index))
+
+                bias_data = None if child.bias is None else mp_replace.copy(
+                    new_bias, prepare_tp_fused_qkvw(module_str, child.bias.data, mp_size,
+                                                    mp_replace.gpu_index)).to(
+                                                        get_accelerator().current_device_name())
+            else:
+                data = mp_replace.copy(new_weight, child.weight.data)
+                bias_data = None if child.bias is None else mp_replace.copy(new_bias, child.bias.data).to(
+                    get_accelerator().current_device_name())

            setattr(child, "replaced", True)
            return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)

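As a quick illustration (not part of the commit) of the gating added to _replace() above: only tensor-parallel runs (mp_size > 1) whose layer name contains one of the fused markers from require_tp_fused_qkvw() take the prepare_tp_fused_qkvw() path; everything else keeps the plain mp_replace.copy() branch. takes_fused_qkv_path() below is a hypothetical stand-in that mirrors that check.

def takes_fused_qkv_path(name, mp_size):
    # mirrors require_tp_fused_qkvw(): substring match on the layer name, skipped when mp_size == 1
    if mp_size == 1:
        return False
    return any(tag in name for tag in ('qkv_proj', 'query_key_value'))

print(takes_fused_qkv_path('transformer.h.0.self_attention.query_key_value', 2))  # True  -> fused reshard
print(takes_fused_qkv_path('transformer.h.0.mlp.dense_h_to_4h', 2))               # False -> plain copy
print(takes_fused_qkv_path('transformer.h.0.self_attention.query_key_value', 1))  # False -> no TP split needed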