Skip to content

Commit

Permalink
overlap communication and computation
Browse files Browse the repository at this point in the history
Signed-off-by: daquexian <[email protected]>
  • Loading branch information
daquexian committed Jul 8, 2023
1 parent f721d3a commit e84dea7
Showing 1 changed file with 114 additions and 22 deletions.
136 changes: 114 additions & 22 deletions rwkv_pip_package/src/rwkv/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types, gc, os, time, re
from collections import ChainMap
import types, gc, os, time, re, contextlib
import torch
from torch.nn import functional as F
torch.backends.cudnn.benchmark = True
Expand Down Expand Up @@ -77,12 +78,13 @@ def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry):
class RWKV(MyModule):
def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None):
super().__init__()
self.prefetch_buffer = {}
if verbose:
prxxx = lambda *args, **kwargs: print(*args, **kwargs)
else:
prxxx = lambda *args, **kwargs: None

STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?(\d+)?)? *)+$"
if not re.match(STRATEGY_REGEX, strategy):
raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/")

Expand All @@ -91,6 +93,7 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit =
args = self.args
args.MODEL_NAME = model
args.strategy_string = strategy
args.prefetch_layers_num = 0

# Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow)
self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0
Expand Down Expand Up @@ -144,9 +147,15 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit =
if len(si) > 2:
ss = si[2]
assert ss.startswith('*')
if ss.endswith('+'):
plan[i] = int(ss[1:-1])
if '+' in ss:
stream_i = i
if ss.endswith('+'):
plan[i] = int(ss[1:-1])
args.prefetch_layers_num = 0
else:
plan[i] = int(ss.split('+')[0][1:])
args.prefetch_layers_num = int(ss.split('+')[-1])
assert args.prefetch_layers_num >= 0
else:
plan[i] = int(ss[1:])
allocated += plan[i]
Expand Down Expand Up @@ -556,9 +565,12 @@ def cuda_att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_

def forward(self, tokens, state, full_output=False):
with torch.no_grad():
w = self.w
args = self.args

if not hasattr(self, 'prefetch_stream'):
self.prefetch_stream = torch.cuda.Stream()
self.prefetch_events = [None] * args.prefetch_layers_num

if state == None:
state = [None] * args.n_layer * 5
for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
Expand All @@ -573,13 +585,57 @@ def forward(self, tokens, state, full_output=False):

seq_mode = len(tokens) > 1

pb = self.prefetch_buffer

# init prefetch_buffer
for i in range(args.prefetch_layers_num):
dd = self.strategy[i]
dev = dd.device
if dd.stream:
att = f'blocks.{i}.att.'
kw_key = f'{att}key.weight'
vw_key = f'{att}value.weight'
rw_key = f'{att}receptance.weight'
ow_key = f'{att}output.weight'

with torch.cuda.stream(self.prefetch_stream):
pb[kw_key] = self.w[kw_key].to(device=dev, non_blocking=True)
pb[vw_key] = self.w[vw_key].to(device=dev, non_blocking=True)
pb[rw_key] = self.w[rw_key].to(device=dev, non_blocking=True)
pb[ow_key] = self.w[ow_key].to(device=dev, non_blocking=True)

ffn = f'blocks.{i}.ffn.'
kw_key = f'{ffn}key.weight'
vw_key = f'{ffn}value.weight'
rw_key = f'{ffn}receptance.weight'

with torch.cuda.stream(self.prefetch_stream):
pb[kw_key] = self.w[kw_key].to(device=dev, non_blocking=True)
pb[vw_key] = self.w[vw_key].to(device=dev, non_blocking=True)
pb[rw_key] = self.w[rw_key].to(device=dev, non_blocking=True)
event = torch.cuda.Event()
event.record(torch.cuda.current_stream())
self.prefetch_events[i % args.prefetch_layers_num] = event

# find in self.prefetch_buffer, and find in self.w if not found
w = ChainMap(pb, self.w)

x = w['emb.weight'][tokens if seq_mode else tokens[0]]

for i in range(args.n_layer):
bbb = f'blocks.{i}.'
att = f'blocks.{i}.att.'
ffn = f'blocks.{i}.ffn.'
dd = self.strategy[i]
if dd.stream and args.prefetch_layers_num > 0:
self.prefetch_events[i % args.prefetch_layers_num].synchronize()
self.prefetch_events[i % args.prefetch_layers_num] = None

pf_layer = i + args.prefetch_layers_num
if pf_layer < args.n_layer:
prefetch = self.strategy[pf_layer].stream
else:
prefetch = False
dev = dd.device
atype = dd.atype
wtype = dd.wtype
Expand All @@ -595,15 +651,30 @@ def forward(self, tokens, state, full_output=False):

x = x.to(dtype=atype, device=dev)

kw = w[f'{att}key.weight']
vw = w[f'{att}value.weight']
rw = w[f'{att}receptance.weight']
ow = w[f'{att}output.weight']
if dd.stream:
kw = kw.to(device=dev, non_blocking=True)
vw = vw.to(device=dev, non_blocking=True)
rw = rw.to(device=dev, non_blocking=True)
ow = ow.to(device=dev, non_blocking=True)
if prefetch:
pf_att = f'blocks.{pf_layer}.att.'
pf_kw_key = f'{pf_att}key.weight'
pf_vw_key = f'{pf_att}value.weight'
pf_rw_key = f'{pf_att}receptance.weight'
pf_ow_key = f'{pf_att}output.weight'

with torch.cuda.stream(self.prefetch_stream) if args.prefetch_layers_num > 0 else contextlib.nullcontext():
pb[pf_kw_key] = self.w[pf_kw_key].to(device=dev, non_blocking=True)
pb[pf_vw_key] = self.w[pf_vw_key].to(device=dev, non_blocking=True)
pb[pf_rw_key] = self.w[pf_rw_key].to(device=dev, non_blocking=True)
pb[pf_ow_key] = self.w[pf_ow_key].to(device=dev, non_blocking=True)
kw_key = f'{att}key.weight'
vw_key = f'{att}value.weight'
rw_key = f'{att}receptance.weight'
ow_key = f'{att}output.weight'
kw = w[kw_key]
vw = w[vw_key]
rw = w[rw_key]
ow = w[ow_key]
assert kw.device.type == dev, f'{kw.device.type} != {dev}, {att=}'
assert vw.device.type == dev, f'{vw.device.type} != {dev}, {att=}'
assert rw.device.type == dev, f'{rw.device.type} != {dev}, {att=}'
assert ow.device.type == dev, f'{ow.device.type} != {dev}, {att=}'
kmx = w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x
krx = w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x
kmy = w[f'{att}key.weight_my'] if wtype == torch.uint8 else x
Expand Down Expand Up @@ -633,14 +704,32 @@ def forward(self, tokens, state, full_output=False):
)
if dd.stream:
del kw, vw, rw, ow

kw = w[f'{ffn}key.weight']
vw = w[f'{ffn}value.weight']
rw = w[f'{ffn}receptance.weight']
if dd.stream:
kw = kw.to(device=dev, non_blocking=True)
vw = vw.to(device=dev, non_blocking=True)
rw = rw.to(device=dev, non_blocking=True)
del pb[kw_key]
del pb[vw_key]
del pb[rw_key]
del pb[ow_key]

if prefetch:
pf_ffn = f'blocks.{pf_layer}.ffn.'
pf_kw_key = f'{pf_ffn}key.weight'
pf_vw_key = f'{pf_ffn}value.weight'
pf_rw_key = f'{pf_ffn}receptance.weight'

with torch.cuda.stream(self.prefetch_stream) if args.prefetch_layers_num > 0 else contextlib.nullcontext():
pb[pf_kw_key] = self.w[pf_kw_key].to(device=dev, non_blocking=True)
pb[pf_vw_key] = self.w[pf_vw_key].to(device=dev, non_blocking=True)
pb[pf_rw_key] = self.w[pf_rw_key].to(device=dev, non_blocking=True)
if args.prefetch_layers_num > 0:
event = torch.cuda.Event()
event.record(torch.cuda.current_stream())
self.prefetch_events[i % args.prefetch_layers_num] = event

kw_key = f'{ffn}key.weight'
vw_key = f'{ffn}value.weight'
rw_key = f'{ffn}receptance.weight'
kw = w[kw_key]
vw = w[vw_key]
rw = w[rw_key]
kmx = w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x
krx = w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x
kmy = w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x
Expand All @@ -664,6 +753,9 @@ def forward(self, tokens, state, full_output=False):
)
if dd.stream:
del kw, vw, rw
del pb[kw_key]
del pb[vw_key]
del pb[rw_key]

if self.RESCALE_LAYER > 0:
if (i+1) % self.RESCALE_LAYER == 0:
Expand Down

0 comments on commit e84dea7

Please sign in to comment.