import time
from typing import List

from omegaconf import DictConfig
from pydantic import BaseModel
from transformers import AutoTokenizer
from typing_extensions import override
from vllm import LLM, SamplingParams  # make sure the vllm package is installed

from redis_util.interface import ModelInterface
from vllm_protocol_pb2 import VLLMProtocol

# Pydantic mirror of the protobuf message, kept commented out for reference:
# class VLLMProtocol(BaseModel):
#     texts: List[str]
#     output_tokens: int
#     input_tokens: int
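
# The VLLMProtocol message imported from vllm_protocol_pb2 is assumed to be
# generated from a protobuf definition roughly like the sketch below
# (hypothetical; the field names mirror the commented-out pydantic model
# above, the actual .proto file may differ):
#
#   syntax = "proto3";
#   message VLLMProtocol {
#     repeated string texts = 1;
#     int64 output_tokens = 2;
#     int64 input_tokens = 3;
#   }
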

class VLLMModel(ModelInterface[str, VLLMProtocol]):
    """
    VLLMModel builds a vllm.LLM model from the given configuration and runs
    inference on incoming prompts with the sampling parameters it specifies.
    """

    def __init__(self, config: DictConfig):
        """
        Initialize a VLLMModel instance.

        :param config: OmegaConf configuration object containing the model settings
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.llm.model,
            trust_remote_code=True,
            use_fast=False,
        )
        self.llm = LLM(
            model=config.llm.model,
            tensor_parallel_size=config.llm.tensor_parallel_size,
            gpu_memory_utilization=config.llm.gpu_memory_utilization,
        )
        # Resolve the stop list: replace the <|EOS_TOKEN|> placeholder with the
        # tokenizer's actual EOS token(s), determined from eos_token_id.
        config.llm.stop = self.parse_stop(config.llm.stop)
        self.sampling_params = SamplingParams(
            n=config.llm.n,
            max_tokens=config.llm.max_tokens,
            temperature=config.llm.temperature,
            top_k=config.llm.top_k,
            top_p=config.llm.top_p,
            stop=config.llm.stop,
        )
        print(f"Initialized VLLMModel with model {config.llm.model}")
        print(f"Initialized VLLMModel with sampling_params {self.sampling_params}")

    @override
    def process(self, batch_tasks: List[str]) -> List[VLLMProtocol]:
        responses = self.llm.generate(batch_tasks, sampling_params=self.sampling_params)
        return [
            VLLMProtocol(
                texts=[o.text for o in rsp.outputs],
                output_tokens=sum(len(o.token_ids) for o in rsp.outputs),
                input_tokens=len(rsp.prompt_token_ids),
            )
            for rsp in responses
        ]

    @override
    @staticmethod
    def encode_outputs(value: List[VLLMProtocol]) -> List[bytes]:
        # Serialize each protobuf message to bytes for transport.
        return [p.SerializeToString() for p in value]

    @override
    @staticmethod
    def decode_outputs(value: List[bytes]) -> List[VLLMProtocol]:
        # Parse each bytes payload back into a VLLMProtocol message.
        ret = [VLLMProtocol() for _ in value]
        for i, v in enumerate(value):
            ret[i].ParseFromString(v)
        return ret

    def parse_stop(self, stop):
        """Replace the <|EOS_TOKEN|> placeholder in the stop list with the
        tokenizer's actual EOS token(s)."""
        if isinstance(stop, str):
            stop = [stop]
        if '<|EOS_TOKEN|>' in stop:
            eos_id = self.tokenizer.eos_token_id
            stop.remove('<|EOS_TOKEN|>')
            if isinstance(eos_id, int):
                stop.append(self.tokenizer.decode(eos_id))
            elif isinstance(eos_id, list):
                stop.extend([self.tokenizer.decode(e) for e in eos_id])
            else:
                raise ValueError(f"Invalid eos_id type: {type(eos_id)}")
        return stop
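

# Round-trip sketch for the (de)serialization helpers above (hypothetical
# values, assuming the VLLMProtocol fields documented near the top of the file):
#
#   msgs = [VLLMProtocol(texts=["hi"], output_tokens=1, input_tokens=3)]
#   blobs = VLLMModel.encode_outputs(msgs)          # -> list of bytes
#   assert VLLMModel.decode_outputs(blobs)[0].texts == ["hi"]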


class DummyModel(ModelInterface):
    """
    DummyModel simulates the model's processing step: it takes a list of
    strings as input and returns a list of dummy response lists as output.
    """

    def __init__(self, config: DictConfig):
        self.n = config.llm.n

    @override
    def process(self, batch_tasks: List[str]) -> List[List[str]]:
        # Simulate super-linear throughput: sleep time grows sub-linearly
        # with the batch size.
        time.sleep(1 * len(batch_tasks) ** 0.8)
        return [["response" + str(i) for i in range(self.n)]] * len(batch_tasks)
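

# Minimal usage sketch (not part of the original module). The config keys
# mirror those read in VLLMModel.__init__; the model id is a placeholder and
# running this requires a working vllm install plus a GPU.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "llm": {
            "model": "facebook/opt-125m",  # hypothetical model id
            "tensor_parallel_size": 1,
            "gpu_memory_utilization": 0.9,
            "n": 1,
            "max_tokens": 128,
            "temperature": 0.8,
            "top_k": 50,
            "top_p": 0.95,
            "stop": ["<|EOS_TOKEN|>"],
        }
    })
    model = VLLMModel(cfg)
    for out in model.process(["Hello, world!"]):
        print(out.texts, out.input_tokens, out.output_tokens)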