Commit

Merge pull request #4173 from mMrBun/main
Implemented the tool_formatter and tool_extractor for glm4 and Qwen2 tool_format
hiyouga authored Jun 18, 2024
2 parents 9ab0401 + 950e360 commit c0ca425
Showing 3 changed files with 63 additions and 16 deletions.
13 changes: 8 additions & 5 deletions src/llamafactory/api/chat.py
@@ -164,11 +164,14 @@ async def create_chat_completion_response(
         else:
             result = response.response_text

-        if isinstance(result, tuple):
-            name, arguments = result
-            function = Function(name=name, arguments=arguments)
-            tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
-            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
+        if isinstance(result, list):
+            tool_calls = []
+            for tool in result:
+                name, arguments = tool
+                function = Function(name=name, arguments=arguments)
+                tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
+                tool_calls.append(tool_call)
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
             finish_reason = Finish.TOOL
         else:
             response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
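
In practice this change lets a single model response produce several tool calls instead of just one. A minimal, self-contained sketch of the new branch; the Function and FunctionCall classes below are simplified stand-ins for the protocol models imported in chat.py, not the real ones:

import uuid
from dataclasses import dataclass

@dataclass
class Function:        # stand-in for the protocol's Function model
    name: str
    arguments: str

@dataclass
class FunctionCall:    # stand-in for the protocol's FunctionCall model
    id: str
    function: Function

# The tool extractor now returns a list of (name, arguments) pairs.
result = [
    ("get_weather", '{"city": "Beijing"}'),
    ("get_time", '{"timezone": "Asia/Shanghai"}'),
]

tool_calls = []
for name, arguments in result:
    function = Function(name=name, arguments=arguments)
    tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))

print(len(tool_calls))  # 2 -- one FunctionCall per extracted pair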
63 changes: 53 additions & 10 deletions src/llamafactory/data/formatter.py
@@ -37,6 +37,17 @@
 )


+GLM4_TOOL_SUFFIX_PROMPT = (
+    "在调用上述函数时,请使用 Json 格式表示调用的参数。"
+)
+
+GLM4_TOOL_PROMPT = (
+    "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持,"
+    "{tool_text}"
+)
+
+
 def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
     tool_text = ""
     tool_names = []
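
For reference, the two Chinese constants translate roughly as follows. GLM4_TOOL_PROMPT: "You are an AI assistant named GLM-4, developed on the GLM-4 language model trained by Zhipu AI; your task is to provide appropriate answers and support for the user's questions and requests," followed by the per-tool text. GLM4_TOOL_SUFFIX_PROMPT: "When calling the functions above, please express the call arguments in JSON format."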
@@ -67,31 +78,59 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
     )


-def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
-    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
-    action_match = re.search(regex, content)
+def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    tool_text = ""
+    for tool in tools:
+        tool_name = tool["name"]
+        tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}"
+    return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+
+def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL)
+    action_match = re.findall(regex, content)
     if not action_match:
         return content

-    tool_name = action_match.group(1).strip()
-    tool_input = action_match.group(2).strip().strip('"').strip("```")
-    try:
-        arguments = json.loads(tool_input)
-    except json.JSONDecodeError:
-        return content
-
-    return tool_name, json.dumps(arguments, ensure_ascii=False)
+    results = []
+
+    for match in action_match:
+        tool_name, tool_input = match
+        tool_name = tool_name.strip()
+        tool_input = tool_input.strip().strip('"').strip("```")
+
+        try:
+            arguments = json.loads(tool_input)
+            results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+        except json.JSONDecodeError:
+            return content
+
+    return results
+
+
+def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    lines = content.strip().split("\n")
+    if len(lines) != 2:
+        return content
+    tool_name = lines[0].strip()
+    tool_input = lines[1].strip()
+    try:
+        arguments = json.loads(tool_input)
+    except json.JSONDecodeError:
+        return content
+    return [(tool_name, json.dumps(arguments, ensure_ascii=False))]


 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[Literal["default"]] = None
+    tool_format: Optional[Literal["default", "glm4"]] = None

     @abstractmethod
     def apply(self, **kwargs) -> SLOTS: ...

-    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         raise NotImplementedError
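
To see what glm4_tool_formatter assembles, the sketch below mirrors its loop on one invented tool definition; the two constants are shortened English placeholders so the structure, rather than the Chinese wording, is what is illustrated:

import json

PROMPT = "<GLM-4 system preamble>{tool_text}"          # placeholder for GLM4_TOOL_PROMPT
SUFFIX = "<use JSON format for the call arguments>"    # placeholder for GLM4_TOOL_SUFFIX_PROMPT

tools = [{"name": "get_weather", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}]

tool_text = ""
for tool in tools:
    tool_text += f"\n\n## {tool['name']}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{SUFFIX}"

print(PROMPT.format(tool_text=tool_text))
# <GLM-4 system preamble>
#
# ## get_weather
#
# {
#     "name": "get_weather",
#     "parameters": { ... }
# }
# <use JSON format for the call arguments>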
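
The rewritten default_tool_extractor regex captures every Action / Action Input pair rather than only the first. A quick check of the pattern itself, using made-up tool names and arguments:

import json
import re

regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL)

content = (
    'Action: get_weather\nAction Input: {"city": "Beijing"}\n'
    'Action: get_time\nAction Input: {"timezone": "Asia/Shanghai"}'
)

for name, raw_args in regex.findall(content):
    # Each match is re-serialized through json to normalize the argument string.
    print(name, json.dumps(json.loads(raw_args), ensure_ascii=False))
# get_weather {"city": "Beijing"}
# get_time {"timezone": "Asia/Shanghai"}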
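
glm4_tool_extractor expects GLM-4's two-line tool-call format: the function name on the first line and the JSON arguments on the second. A tiny mimic of that parse, with an invented call:

import json

content = 'get_weather\n{"city": "北京"}'          # name on line 1, JSON arguments on line 2
name, raw_args = (line.strip() for line in content.strip().split("\n"))
print((name, json.dumps(json.loads(raw_args), ensure_ascii=False)))
# ('get_weather', '{"city": "北京"}')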


@@ -189,13 +228,17 @@ def apply(self, **kwargs) -> SLOTS:

             if self.tool_format == "default":
                 return [default_tool_formatter(tools)]
+            elif self.tool_format == "glm4":
+                return [glm4_tool_formatter(tools)]
             else:
                 raise NotImplementedError
         except Exception:
             return [""]

-    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         if self.tool_format == "default":
             return default_tool_extractor(content)
+        elif self.tool_format == "glm4":
+            return glm4_tool_extractor(content)
         else:
             raise NotImplementedError
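
Tying the dispatch together: a template that sets format_tools=ToolFormatter(tool_format="glm4") formats the tool list via glm4_tool_formatter and parses model output via glm4_tool_extractor. A hedged usage sketch; the content= keyword for apply() is an assumption based on how ToolFormatter consumes the JSON tool list and is not shown in this diff:

import json

from llamafactory.data.formatter import ToolFormatter

formatter = ToolFormatter(tool_format="glm4")

# apply() presumably receives the tool list as a JSON string and returns a single
# slot containing the GLM-4 system prompt built by glm4_tool_formatter.
tools = [{"name": "get_weather", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}]
system_slot = formatter.apply(content=json.dumps(tools))

# extract() parses a GLM-4 tool-call response into a list of (name, arguments) pairs.
print(formatter.extract('get_weather\n{"city": "Beijing"}'))
# [('get_weather', '{"city": "Beijing"}')]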
3 changes: 2 additions & 1 deletion src/llamafactory/data/template.py
@@ -676,9 +676,10 @@ def get_template_and_fix_tokenizer(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop><|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
force_system=True,
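
With format_system now emitting <|system|> after [gMASK]<sop>, the system content (including the tool prompt when tools are supplied) is wrapped in GLM-4's system role instead of being inserted bare. Assuming the usual system-then-user slot ordering, a single-turn prompt concatenates roughly like this:

# Illustration of how the glm4 slots concatenate for one user turn
# (assumes system content comes first; not the actual template engine).
system_text = "<GLM-4 system or tool prompt>"
prompt = "[gMASK]<sop><|system|>\n" + system_text + "<|user|>\nWhat's the weather in Beijing?<|assistant|>"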
