Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BFCL Chore] Add @final and @overrides Decorators to Class Methods in Model Handler #790

Merged
merged 5 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from bfcl.model_handler.model_style import ModelStyle
from bfcl.utils import load_file, make_json_serializable, sort_key
from overrides import final


class BaseHandler:
Expand All @@ -31,6 +32,7 @@ def __init__(self, model_name, temperature) -> None:
self.temperature = temperature
self.is_fc_model = False # Whether the model is a function calling model

@final
def inference(self, test_entry: dict, include_input_log: bool, include_state_log: bool):
# This method is used to retrive model response for each model.

Expand All @@ -52,6 +54,7 @@ def inference(self, test_entry: dict, include_input_log: bool, include_state_log
test_entry, include_input_log
)

@final
def inference_multi_turn_FC(
self, test_entry: dict, include_input_log: bool, include_state_log: bool
) -> tuple[list[list], dict]:
Expand Down Expand Up @@ -296,6 +299,7 @@ def inference_multi_turn_FC(

return all_model_response, metadata

@final
def inference_multi_turn_prompting(
self, test_entry: dict, include_input_log: bool, include_state_log: bool
) -> tuple[list[list], dict]:
Expand Down Expand Up @@ -537,6 +541,7 @@ def inference_multi_turn_prompting(

return all_model_response, metadata

@final
def inference_single_turn_FC(
self, test_entry: dict, include_input_log: bool
) -> tuple[any, dict]:
Expand Down Expand Up @@ -569,6 +574,7 @@ def inference_single_turn_FC(

return model_response_data["model_responses"], metadata

@final
def inference_single_turn_prompting(
self, test_entry: dict, include_input_log: bool
) -> tuple[any, dict]:
Expand Down Expand Up @@ -607,6 +613,7 @@ def decode_execute(self, result):
# This method takes raw model output and convert it to standard execute checker input.
raise NotImplementedError

@final
def write(self, result, result_dir, update_mode=False):
model_name_dir = self.model_name.replace("/", "_")
model_result_dir = result_dir / model_name_dir
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import subprocess
import threading
import time
import json
from concurrent.futures import ThreadPoolExecutor

import requests
Expand All @@ -16,17 +15,19 @@
system_prompt_pre_processing_chat_model,
)
from openai import OpenAI
from overrides import EnforceOverrides, final
from tqdm import tqdm


class OSSHandler(BaseHandler):
class OSSHandler(BaseHandler, EnforceOverrides):
def __init__(self, model_name, temperature, dtype="bfloat16") -> None:
super().__init__(model_name, temperature)
self.model_name_huggingface = model_name
self.model_style = ModelStyle.OSSMODEL
self.dtype = dtype
self.client = OpenAI(base_url=f"http://localhost:{VLLM_PORT}/v1", api_key="EMPTY")

@final
def inference(self, test_entry: dict, include_input_log: bool, include_state_log: bool):
"""
OSS models have a different inference method.
Expand All @@ -44,6 +45,7 @@ def decode_ast(self, result, language="Python"):
def decode_execute(self, result):
return default_decode_execute_prompting(result)

@final
def batch_inference(
self,
test_entries: list[dict],
Expand Down Expand Up @@ -218,6 +220,7 @@ def log_subprocess_output(pipe, stop_event):
stdout_thread.join()
stderr_thread.join()

@final
def _multi_threaded_inference(self, test_case, include_input_log: bool, include_state_log: bool):
"""
This is a wrapper function to make sure that, if an error occurs during inference, the process does not stop.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from bfcl.model_handler.oss_model.base_oss_handler import OSSHandler
from overrides import overrides


class DeepseekHandler(OSSHandler):
Expand All @@ -9,6 +10,7 @@ class DeepseekHandler(OSSHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)

@overrides
def decode_ast(self, result, language="Python"):
result = result.strip()
if result.startswith("```json"):
Expand All @@ -17,13 +19,15 @@ def decode_ast(self, result, language="Python"):
result = result[len("```python"):]
return super().decode_ast(result, language)

@overrides
def decode_execute(self, result):
if result.startswith("```json"):
result = result[len("```json"):]
if result.startswith("```python"):
result = result[len("```python"):]
return super().decode_execute(result)

@overrides
def _format_prompt(self, messages, function):
"""
"bos_token": {
Expand Down Expand Up @@ -58,6 +62,7 @@ def _format_prompt(self, messages, function):

return formatted_prompt

@overrides
def _add_execution_results_prompting(
self, inference_data: dict, execution_results: list[str], model_response_data: dict
) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
convert_to_function_call,
func_doc_language_specific_pre_processing,
)
from overrides import overrides


class DeepseekCoderHandler(OSSHandler):
Expand All @@ -18,12 +19,15 @@ class DeepseekCoderHandler(OSSHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)

@overrides
def decode_ast(self, result, language="Python"):
return result

@overrides
def decode_execute(self, result):
return convert_to_function_call(result)

@overrides
def _format_prompt(self, messages, function):
"""
"bos_token": {
Expand Down Expand Up @@ -105,6 +109,7 @@ def _format_prompt(self, messages, function):

return formatted_prompt

@overrides
def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
functions: list = test_entry["function"]
test_category: str = test_entry["id"].rsplit("_", 1)[0]
Expand All @@ -131,6 +136,7 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:

return {"message": [], "function": functions}

@overrides
def _parse_query_response_prompting(self, api_response: any) -> dict:
model_responses = api_response.choices[0].text
extracted_tool_calls = self.extract_tool_calls(model_responses)
Expand Down Expand Up @@ -158,6 +164,7 @@ def _parse_query_response_prompting(self, api_response: any) -> dict:
"output_token": api_response.usage.completion_tokens,
}

@overrides
def _add_assistant_message_prompting(
self, inference_data: dict, model_response_data: dict
) -> dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from bfcl.model_handler.oss_model.base_oss_handler import OSSHandler
from bfcl.model_handler.utils import (
combine_consecutive_user_prompts,
convert_system_prompt_into_user_prompt,
func_doc_language_specific_pre_processing,
system_prompt_pre_processing_chat_model,
convert_system_prompt_into_user_prompt,
combine_consecutive_user_prompts,
)
from overrides import overrides


class GemmaHandler(OSSHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)

@overrides
def _format_prompt(self, messages, function):
"""
"bos_token": "<bos>",
Expand All @@ -25,6 +27,7 @@ def _format_prompt(self, messages, function):

return formatted_prompt

@overrides
def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
functions: list = test_entry["function"]
test_category: str = test_entry["id"].rsplit("_", 1)[0]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json

from bfcl.model_handler.oss_model.base_oss_handler import OSSHandler
from bfcl.model_handler.utils import convert_to_function_call
import json
from overrides import overrides


class GlaiveHandler(OSSHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)

@overrides
def decode_ast(self, result, language="Python"):
function_call = result.split("<functioncall>")[-1]
function_call = function_call.replace("'", "")
Expand All @@ -22,6 +25,7 @@ def decode_ast(self, result, language="Python"):
decoded_result = [{decoded_function["name"]: decoded_function["arguments"]}]
return decoded_result

@overrides
def decode_execute(self, result):
function_call = result.split("<functioncall>")[-1]
function_call = function_call.replace("'", "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
convert_to_tool,
func_doc_language_specific_pre_processing,
)
from overrides import overrides


class GLMHandler(OSSHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)
self.stop_token_ids = [151329, 151336, 151338]

@overrides
def _format_prompt(self, messages, function):
"""
"chat_template": "[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n在调用上述函数时,请使用 Json 格式表示调用的参数。{% elif tool['type'] == 'python' %}\n\n## python\n\n当你向 `python` 发送包含 Python 代码的消息时,该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出,或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中,`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用,这些在线内容的访问将不会成功。{% elif tool['type'] == 'simple_browser' %}\n\n## simple_browser\n\n你可以使用 `simple_browser` 工具。该工具支持以下函数:\n`search(query: str, recency_days: int)`:使用搜索引擎进行查询并显示结果,可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`:获取一系列指定 id 的页面内容。每次调用时,须选择3-10个页面。选择多个角度的页面,同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的,你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`:打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤:1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL,也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息,也可以再次使用 `search` 进行搜索。{% elif tool['type'] == 'cogview' %}\n\n## cogview\n\n如果用户的请求中包含了对图像的描述,你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述,规则:\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求,需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明,所在地为中国,持有中国立场并遵循中国社会主义价值观。{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
Expand All @@ -34,6 +36,7 @@ def _format_prompt(self, messages, function):

return formatted_prompt

@overrides
def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
functions: list = test_entry["function"]
test_category: str = test_entry["id"].rsplit("_", 1)[0]
Expand All @@ -44,6 +47,7 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:

return {"message": [], "function": functions}

@overrides
def decode_ast(self, result, language="Python"):
args = result.split("\n")
if len(args) == 1:
Expand All @@ -52,6 +56,7 @@ def decode_ast(self, result, language="Python"):
func = [{args[0]: json.loads(args[1])}]
return func

@overrides
def decode_execute(self, result):
args = result.split("\n")
if len(args) == 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
convert_to_tool,
func_doc_language_specific_pre_processing,
)
from overrides import overrides


class GraniteHandler(OSSHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)

@overrides
def _format_prompt(self, messages, function):
"""
"chat_template": "{% set function_str = messages.get('functions_str', {}) %}\n{% set query = messages['query'] %}\n{% set sys_prompt = 'You are a helpful assistant with access to the following function calls. Your task is to produce a sequence of function calls necessary to generate response to the user utterance. Use the following function calls as required. ' %}\n{% set funcstr = function_str|join('\n') %}\n{{ 'SYSTEM: ' + sys_prompt + '\n<|function_call_library|>\n' + funcstr + '\n\nIf none of the functions are relevant or the given question lacks the parameters required by the function, please output \"<function_call> {\"name\": \"no_function\", \"arguments\": {}}\".\n\nUSER: ' + query}}\n{% if add_generation_prompt %}\n{{ 'ASSISTANT:' }}{% endif %}",
Expand Down Expand Up @@ -40,6 +42,7 @@ def _format_prompt(self, messages, function):

return prompt_str

@overrides
def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
functions: list = test_entry["function"]
test_category: str = test_entry["id"].rsplit("_", 1)[0]
Expand All @@ -50,6 +53,7 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:

return {"message": [], "function": functions}

@overrides
def decode_ast(self, result, language="Python"):
decoded_outputs = []
result = [
Expand All @@ -75,6 +79,7 @@ def decode_ast(self, result, language="Python"):

return decoded_outputs

@overrides
def decode_execute(self, result):
decoded_outputs = []
result = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
convert_system_prompt_into_user_prompt,
func_doc_language_specific_pre_processing,
)
from overrides import overrides

TASK_INSTRUCTION = """You are a tool calling assistant. In order to complete the user's request, you need to select one or more appropriate tools from the following tools and fill in the correct values for the tool parameters. Your specific tasks are:
1. Make one or more function/tool calls to meet the request based on the question.
Expand All @@ -28,6 +29,7 @@ class HammerHandler(OSSHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)

@overrides
def _format_prompt(self, messages, function):
"""
"chat_template": "{% set system_message = 'You are a helpful assistant.' %}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ '<|im_start|>system\n' + system_message + '<|im_end|>\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\n' + content + '<|im_end|>\n<|im_start|>assistant\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\n' }}{% endif %}{% endfor %}",
Expand Down Expand Up @@ -86,6 +88,7 @@ def convert_to_format_tool(tools):

return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"

@overrides
def decode_ast(self, result, language="Python"):
result = result.replace("```", "")
try:
Expand Down Expand Up @@ -126,6 +129,7 @@ def xlam_json_to_python_tool_calls(tool_calls):

return python_format

@overrides
def decode_execute(self, result):
result = result.replace("```", "")
try:
Expand All @@ -142,6 +146,7 @@ def decode_execute(self, result):
function_call = self.xlam_json_to_python_tool_calls(tool_calls)
return function_call

@overrides
def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
functions: list = test_entry["function"]
test_category: str = test_entry["id"].rsplit("_", 1)[0]
Expand Down
Loading