Eagle speculative decoding part 4: Add EAGLE2 worker #2150

Merged on Jan 2, 2025 (94 commits)

Commits
70135d6  temp (Oct 13, 2024)
65fae7b  migrated to new upstream, need implement evict memory (Oct 14, 2024)
064cca6  prove single req (Oct 15, 2024)
cb01c64  fix bug for long generate due to eagle_verify_retrive kernel (Oct 16, 2024)
df3de9d  fix bug of eagle spec verify (Oct 16, 2024)
b7628f2  support cuda graph (Oct 19, 2024)
e2634e9  support batch inference (Oct 22, 2024)
f557a06  temp (Oct 23, 2024)
9987741  fix memeory leak (Oct 24, 2024)
dcbc11c  add sampling score (Oct 24, 2024)
5578b18  support target model cuda graph (Oct 24, 2024)
af2e79a  disable target model cuda graph (Oct 24, 2024)
0fdd0b1  fix batch bug (Oct 24, 2024)
923523b  disable cuda graph pad in eagle (Oct 25, 2024)
0e3fea2  fix server args (Oct 25, 2024)
4faaa31  fix cuda graph and split prefill (Oct 25, 2024)
33d8aef  optimize generate attn arg (Oct 25, 2024)
11d6e86  fix parent list dtype (Oct 25, 2024)
2b3cb22  fix draft worker memory problem (Oct 26, 2024)
7aa0aff  need to fix decode error when request retract happend (Oct 26, 2024)
404c5ab  remove debug info (Oct 26, 2024)
e095ec0  fix bug (Oct 27, 2024)
9f0a0c2  fix some bug and support target model use cuda graph (Oct 29, 2024)
35c5678  fix conflict, should solve scheduler and cuda graph problem (Nov 1, 2024)
b647a70  fix naive cuda graph (Nov 1, 2024)
7226987  fix cuda graph (Nov 2, 2024)
aaf1cae  support split prefill batch (Nov 2, 2024)
dbeaa2c  fix cuda graph padding (Nov 2, 2024)
b6f45d5  remove modification of target model and remove some redundant code (Nov 4, 2024)
7c4a04c  fix cache management (Nov 5, 2024)
8c87835  merge main and fix conflict (Nov 6, 2024)
df50c13  add eagle example (Nov 7, 2024)
9a6916f  recover original example (Nov 7, 2024)
8cf2bf8  fix split prefill (Nov 8, 2024)
12948c5  fix code style (Nov 11, 2024)
5618ebb  update to main and fix conflict (Nov 12, 2024)
8aec12b  add cutex to ci dependency (Nov 12, 2024)
0dc2b08  add default value for speculative argument (Nov 12, 2024)
faf9f50  fix code style (Nov 12, 2024)
a62eb85  fix ci (Nov 14, 2024)
84d454f  update and fix conflict (Nov 18, 2024)
a1c9a5c  update and fix conflict (Nov 22, 2024)
5413e5d  fix run with dp (Nov 22, 2024)
4c93b3e  fix github comment (Nov 22, 2024)
a3cfa17  remove flashinfer_utils (Nov 22, 2024)
8487760  update and fix conflict (Nov 23, 2024)
62d44cb  change forward_mode.target_verify to is_extend (Nov 23, 2024)
1b5b3b7  add eagle test to ci (Nov 23, 2024)
060be3a  test ci (Nov 23, 2024)
de041ab  add eagle test to run_suit (Nov 24, 2024)
e416008  Merge branch 'main' into spec_infer (merrymercy, Nov 24, 2024)
2f928d0  refine eagle ci and make eagle worker's input/output explicit (Nov 24, 2024)
30a0bab  fix ci (Nov 24, 2024)
8db243d  remove performance test (Nov 24, 2024)
a9a99fb  Merge branch 'main' into spec_infer (merrymercy, Nov 25, 2024)
ba27032  Fix lint (merrymercy, Nov 25, 2024)
901d6bb  Merge branch 'main' into spec_infer (merrymercy, Nov 25, 2024)
26a8ae6  update to main (Dec 5, 2024)
07dccdd  fix comment (Dec 6, 2024)
f19cc85  Fix cuda graph for verify (merrymercy, Dec 29, 2024)
1ea994c  Merge branch 'main' into spec_infer (merrymercy, Dec 29, 2024)
f23702a  Simplify llama eagle (merrymercy, Dec 29, 2024)
c5d32be  Fix llama_eagle (merrymercy, Dec 29, 2024)
0842911  Fix llama_eagle (merrymercy, Dec 29, 2024)
61692a8  add llama_eagle (merrymercy, Dec 29, 2024)
b217660  Merge branch 'main' into spec_infer (merrymercy, Dec 29, 2024)
fb2e04e  clean flashinfer backend (merrymercy, Dec 29, 2024)
9ad3491  Fix flashinfer backend (merrymercy, Dec 29, 2024)
2addaa7  clean flashinfer backend (merrymercy, Dec 29, 2024)
287b6a7  Clean up attention backend (merrymercy, Dec 29, 2024)
7dd16a8  update (merrymercy, Dec 29, 2024)
e2de519  rename (merrymercy, Dec 29, 2024)
84b2b72  Fix (merrymercy, Dec 29, 2024)
598eb50  simplify spec info (merrymercy, Dec 29, 2024)
6457654  simplify spec algorithm (merrymercy, Dec 31, 2024)
ece223a  Merge branch 'main' into spec_infer (merrymercy, Dec 31, 2024)
c3742d2  furthur clean up (merrymercy, Dec 31, 2024)
3cc218f  Merge branch 'main' into spec_infer (merrymercy, Dec 31, 2024)
f0c1c4b  update cuda graph runner (merrymercy, Dec 31, 2024)
040c1e5  update (merrymercy, Dec 31, 2024)
e779c76  update (merrymercy, Dec 31, 2024)
69daa98  clean up forward_batch_generation (merrymercy, Dec 31, 2024)
0d1c701  Merge branch 'main' into spec_infer (merrymercy, Dec 31, 2024)
032eaad  Fix arguments (merrymercy, Dec 31, 2024)
31343a2  fix port (merrymercy, Jan 2, 2025)
788c562  simplify speculative_worker (merrymercy, Jan 2, 2025)
a5fedad  simplify spec algo (merrymercy, Jan 2, 2025)
42d08db  simplify server args (merrymercy, Jan 2, 2025)
a238c29  Simplify cuda graph (merrymercy, Jan 2, 2025)
3a6040b  simplify position handling (merrymercy, Jan 2, 2025)
2659da9  update (merrymercy, Jan 2, 2025)
015473c  update (merrymercy, Jan 2, 2025)
96e3a77  Merge branch 'main' into spec_infer (merrymercy, Jan 2, 2025)
39b6d4b  Eagle (merrymercy, Jan 2, 2025)
37 changes: 37 additions & 0 deletions examples/runtime/engine/EAGLE_offline_batch_inference.py
@@ -0,0 +1,37 @@
import sglang as sgl


def main():
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# Create a sampling params object.
sampling_params = {"temperature": 0, "max_new_tokens": 30}

    # Create an engine with EAGLE speculative decoding enabled.
llm = sgl.Engine(
model_path="meta-llama/Llama-2-7b-chat-hf",
speculative_algorithm="EAGLE",
speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B",
speculative_num_steps=3,
speculative_eagle_topk=4,
speculative_num_draft_tokens=16,
)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


# The __main__ guard is necessary here because we use "spawn" to create subprocesses.
# Spawn starts a fresh program each time; without the guard, sgl.Engine would keep
# spawning new processes in an infinite loop.
if __name__ == "__main__":
main()
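
A simple way to sanity-check the EAGLE engine is to compare its greedy output against a plain engine serving the same target model. The sketch below reuses only the APIs shown in the example above; Engine.shutdown() is assumed to be available for releasing the GPU between the two runs, and with temperature=0 the outputs are expected to match, since speculative decoding verifies draft tokens against the target model rather than changing its sampling distribution.

import sglang as sgl


def compare_eagle_to_baseline():
    prompts = ["The capital of France is"]
    sampling_params = {"temperature": 0, "max_new_tokens": 30}

    # Baseline: the same target model without speculative decoding.
    baseline = sgl.Engine(model_path="meta-llama/Llama-2-7b-chat-hf")
    ref = baseline.generate(prompts, sampling_params)
    baseline.shutdown()  # assumed API; frees the GPU before the second engine starts

    # EAGLE engine with the same settings as the example above.
    eagle = sgl.Engine(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        speculative_algorithm="EAGLE",
        speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B",
        speculative_num_steps=3,
        speculative_eagle_topk=4,
        speculative_num_draft_tokens=16,
    )
    out = eagle.generate(prompts, sampling_params)
    eagle.shutdown()

    # Greedy decoding should be unchanged by speculative decoding.
    print(ref[0]["text"] == out[0]["text"])


# Same "spawn" caveat as the example above: keep everything under the __main__ guard.
if __name__ == "__main__":
    compare_eagle_to_baseline()
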
347 changes: 347 additions & 0 deletions python/sglang/srt/speculative/build_eagle_tree.py
@@ -0,0 +1,347 @@
import cutex
import torch

# Tensor shapes:
# parent_list [bs, topk*depth + 1]
# selected_index [bs, draft_token_num - 1]
# verified_seq_len [bs]
# tree_mask, flattened per request:
#   [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ...]
#   = [sum(verified_seq_len)*draft_token + bs*draft_token*draft_token]
# positions [bs*draft_token]
# retrive_index [bs, draft_token, depth+2]
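# Worked example, using the parameters from the self-test at the bottom of this file
# (bs=1, topk=10, depth=5, draft_token=64, verified_seq_len=[47]):
#   parent_list:    [1, 51]                    (topk*depth + 1 = 51)
#   selected_index: [1, 63]                    (draft_token - 1 = 63)
#   tree_mask:      [47*64 + 1*64*64] = [7104] flattened booleans
#   positions:      [64]
#   retrive_index:  [1, 64, 7]                 (depth + 2 = 7)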
kernels = cutex.SourceModule(
"""
//cuda
__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
                           Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
    // One thread block per request (bid), one thread per draft token (tid).
    int bid = blockIdx.x;
    int tid = threadIdx.x;
    if (tid >= draft_token_num) {
        return;
    }
    // Offset of this request's rows in the flattened tree_mask: request i occupies
    // draft_token_num * (verified_seq_len[i] + draft_token_num) entries.
    int seq_tree_idx = draft_token_num * draft_token_num * bid;
    for (int i = 0; i < bid; i++) {
        seq_tree_idx += verified_seq_len[i] * draft_token_num;
    }
    int seq_len = verified_seq_len[bid];
    // Each mask row has seq_len + draft_token_num entries. token_tree_idx points to the
    // draft-token columns of this token's row, just past the verified prefix and the tree
    // root, which always stay visible; the other draft tokens start hidden and this
    // token's ancestors are re-enabled below.
    int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
    for (int i = 0; i < draft_token_num - 1; i++) {
        tree_mask[token_tree_idx + i] = false;
    }
    int position = 0;
    if (tid == 0) {
        // Tree root: it directly follows the verified prefix, so its position is seq_len.
        positions[bid * draft_token_num] = seq_len;
        retrive_index[bid][0][0] = bid * draft_token_num;
        return;
    }
    // Walk up the parent chain, recording each ancestor and making it visible, until a
    // direct child of the root is reached.
    int depends_order[10];
    int cur_position = tid - 1;
    while (true) {
        depends_order[position] = cur_position + 1;
        position += 1;
        tree_mask[token_tree_idx + cur_position] = true;
        int parent_tb_idx = selected_index[bid][cur_position] / topk;
        if (parent_tb_idx == 0) {
            break;
        }
        int token_idx = parent_list[bid][parent_tb_idx];
        for (cur_position = 0; cur_position < draft_token_num; cur_position++) {
            if (selected_index[bid][cur_position] == token_idx) {
                break;
            }
        }
    }
    // Position id = length of the verified prefix + depth of this token in the tree.
    positions[bid * draft_token_num + tid] = position + seq_len;
    // Count how many rows can see this token; only a leaf is visible in exactly one row (its own).
    int is_leaf = 0;
    for (int i = 1; i < draft_token_num; i++) {
        if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
            is_leaf++;
        }
    }
    if (is_leaf == 1) {
        // Store the root-to-leaf index path so the tokens on this branch can be retrieved later.
        for (int i = 0; i < position; i++) {
            retrive_index[bid][tid][position - i] = depends_order[i] + bid * draft_token_num;
        }
        retrive_index[bid][tid][0] = bid * draft_token_num;
    }
}
//!cuda
""",
    float_bits=16,  # map the `float` type in the CUDA source above to half precision.
    boundscheck=True,  # enable bounds checking for debugging; disable for performance.
)
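
# Note: build_tree_kernel below launches one thread block per request with a fixed block
# size of 64 threads, so draft_token is assumed to be at most 64; likewise the
# depends_order[10] buffer inside the kernel bounds the tree depth to 10 (the self-test
# below uses depth=5 and draft_token=64).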


def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
bs = seq_lens.numel()
device = parent_list.device
tree_mask = torch.full(
(torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
True,
device=device,
)
retrive_index = torch.full(
(bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
)
positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)

kernels.build_tree(
parent_list,
top_score_index,
seq_lens.to(torch.int32),
tree_mask,
positions,
retrive_index,
topk,
depth,
draft_token,
grid=(bs, 1, 1),
block=(64, 1, 1),
)
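    # Rows never written by the kernel keep their initial -1 fill across all depth + 2
    # entries (so they sum to -(depth + 2)); keep only rows that received a real
    # root-to-leaf path.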
index = retrive_index.sum(dim=-1) != -depth - 2
cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
retrive_cum_len = torch.zeros(
(cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
)
retrive_cum_len[1:] = cum_len
retrive_index = retrive_index[index]
return tree_mask, positions, retrive_index, retrive_cum_len
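
# Summary of build_tree_kernel outputs (shapes per the comments at the top of this file):
#   tree_mask       - flattened boolean attention mask covering each verified prefix plus all draft tokens
#   positions       - position id for every draft token (verified prefix length + depth in the tree)
#   retrive_index   - for each leaf, the chain of draft-token indices from the tree root to that leaf
#   retrive_cum_len - prefix-sum offsets giving each request's slice of retrive_index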


if __name__ == "__main__":

def findp(p_i, index, parent_list):
pos = index // 10
index_list = index.tolist()
parent_list = parent_list.tolist()
res = [p_i]
while True:
p = pos[p_i]
if p == 0:
break
token_idx = parent_list[p]
p_i = index_list.index(token_idx)
res.append(p_i)
return res

def create_mask(seq_len, draft_token, index, parent_list, max_depth):
mask = []
positions = []
retrive_index = []
for i, lens in enumerate(seq_len.tolist()):
first_mask = torch.full((lens + draft_token,), True)
first_mask[-(draft_token - 1) :] = False
positions.append(lens)
mask.append(first_mask)
seq_order = []
            first_index = torch.Tensor([0] + [-1] * (max_depth + 1)).cuda().to(torch.long)
            r_index = [first_index]
            for j in range(draft_token - 1):
                mask.append(torch.full((lens + 1,), True))
                idx = findp(j, index, parent_list)

                seq_order.append(idx)
                positions.append(len(idx) + lens)
t = torch.full((draft_token - 1,), False)
t[idx] = True
mask.append(t)

for i in range(1, draft_token - 1):
is_leaf = 0
for j in range(draft_token - 1):
if i in seq_order[j]:
is_leaf += 1

if is_leaf == 1:
order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
for _ in range(max_depth + 1 - len(seq_order[i])):
order_list.append(-1)
order = torch.Tensor(order_list).cuda().to(torch.long)
r_index.append(order)
retrive_index.append(torch.stack(r_index))

return (
torch.cat(mask).cuda(),
torch.Tensor(positions).cuda().to(torch.long),
torch.stack(retrive_index),
)

index = (
torch.Tensor(
[
0,
1,
2,
3,
10,
11,
12,
13,
20,
21,
22,
30,
110,
130,
150,
160,
210,
211,
212,
213,
214,
215,
216,
217,
218,
219,
220,
230,
310,
311,
312,
313,
314,
315,
316,
317,
320,
321,
322,
330,
360,
380,
390,
410,
411,
412,
413,
414,
415,
416,
417,
418,
419,
420,
421,
422,
423,
430,
431,
440,
441,
460,
470,
]
)
.to(torch.long)
.cuda()
)

parent_list = (
torch.Tensor(
[
-1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
20,
30,
21,
13,
22,
40,
23,
110,
130,
160,
150,
190,
120,
111,
121,
200,
180,
210,
211,
212,
213,
214,
215,
216,
220,
230,
217,
310,
311,
312,
313,
320,
314,
321,
315,
316,
317,
]
)
.to(torch.long)
.cuda()
)

verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
bs = verified_seq_len.shape[0]
topk = 10
depth = 5 # depth <= 10
draft_token = 64

tree_mask = torch.full(
(
torch.sum(verified_seq_len).item() * draft_token
+ draft_token * draft_token * bs,
),
True,
).cuda()
retrive_index = torch.full(
(bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
)
positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)

kernels.build_tree(
parent_list.unsqueeze(0),
index.unsqueeze(0),
verified_seq_len,
tree_mask,
positions,
retrive_index,
topk,
depth,
draft_token,
grid=(bs, 1, 1),
block=(64, 1, 1),
)
retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]

c_mask, c_positions, c_retive_index = create_mask(
verified_seq_len, draft_token, index, parent_list, depth
)
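    # The three checks below require the kernel outputs to match the pure-Python reference above.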

    assert torch.allclose(tree_mask, c_mask), "tree_mask mismatch."
    assert torch.allclose(positions, c_positions), "positions mismatch."
    assert torch.allclose(retrive_index, c_retive_index), "retrive_index mismatch."