feat!: llama.cpp - unified support for tools + refactoring (#1357)
* llama.cpp - support for tools

* fix for tools provided in run

* improvements to warm_up and test serde
anakin87 authored Feb 7, 2025
1 parent fd5def8 commit 909980b
Showing 3 changed files with 323 additions and 159 deletions.
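
The change in a nutshell: `LlamaCppChatGenerator` now accepts Haystack `Tool` objects at init time (and at run time), serializes them in `to_dict`/`from_dict`, and converts the model's tool calls back into `ToolCall` objects. A minimal usage sketch; the weather tool, GGUF path, and parameter values below are illustrative assumptions, not part of this diff:

from haystack.dataclasses import ChatMessage
from haystack.tools import Tool
from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator

def get_weather(city: str) -> str:
    # placeholder implementation for the sketch
    return f"Sunny in {city}"

weather_tool = Tool(
    name="weather",
    description="Get the current weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

generator = LlamaCppChatGenerator(
    model="model.gguf",  # path to a local GGUF file (illustrative)
    n_ctx=8192,
    generation_kwargs={"max_tokens": 256},
    tools=[weather_tool],
)
generator.warm_up()

result = generator.run(messages=[ChatMessage.from_user("What is the weather in Rome?")])
print(result["replies"][0].tool_calls)  # e.g. [ToolCall(tool_name='weather', arguments={'city': 'Rome'}, ...)]
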
1 change: 1 addition & 0 deletions integrations/llama_cpp/pyproject.toml
@@ -52,6 +52,7 @@ dependencies = [
"pytest-rerunfailures",
"haystack-pydoc-tools",
"transformers[sentencepiece]",
"jsonschema", # needed for Tool
]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
@@ -1,31 +1,64 @@
import json
import logging
from typing import Any, Dict, List, Optional

from haystack import component
from haystack.dataclasses import ChatMessage
from llama_cpp import Llama
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from llama_cpp import ChatCompletionResponseChoice, CreateChatCompletionResponse, Llama
from llama_cpp.llama_tokenizer import LlamaHFTokenizer

logger = logging.getLogger(__name__)


def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]:
def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, Any]:
"""
Convert a message to the format expected by Llama.cpp.
:returns: A dictionary with the following keys:
- `role`
- `content`
- `name` (optional)
Convert a ChatMessage to the format expected by the llama.cpp Chat API.
"""
formatted_msg = {"role": message.role.value, "content": message.text}
if message.name:
formatted_msg["name"] = message.name

if formatted_msg["role"] == "tool":
formatted_msg["name"] = message.tool_call_result.origin.tool_name
formatted_msg["content"] = message.tool_call_result.result

return formatted_msg
text_contents = message.texts
tool_calls = message.tool_calls
tool_call_results = message.tool_call_results

if not text_contents and not tool_calls and not tool_call_results:
msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
raise ValueError(msg)
elif len(text_contents) + len(tool_call_results) > 1:
msg = "A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`."
raise ValueError(msg)

role = message._role.value
if role == "tool":
role = "function"

llamacpp_msg: Dict[str, Any] = {"role": role}

if tool_call_results:
if tool_call_results[0].origin.id is None:
msg = "`ToolCall` must have a non-null `id` attribute to be used with llama.cpp."
raise ValueError(msg)
llamacpp_msg["content"] = tool_call_results[0].result
llamacpp_msg["tool_call_id"] = tool_call_results[0].origin.id
# Llama.cpp does not provide a way to communicate errors in tool invocations, so we ignore the error field
return llamacpp_msg

if text_contents:
llamacpp_msg["content"] = text_contents[0]
if tool_calls:
llamacpp_tool_calls = []
for tc in tool_calls:
if tc.id is None:
msg = "`ToolCall` must have a non-null `id` attribute to be used with llama.cpp."
raise ValueError(msg)
llamacpp_tool_calls.append(
{
"id": tc.id,
"type": "function",
# We disable ensure_ascii so special chars like emojis are not converted
"function": {"name": tc.tool_name, "arguments": json.dumps(tc.arguments, ensure_ascii=False)},
}
)
llamacpp_msg["tool_calls"] = llamacpp_tool_calls
return llamacpp_msg
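
For reference, a rough sketch of the shapes the helper above produces; the tool name, id, arguments, and result are invented for illustration:

from haystack.dataclasses import ChatMessage, ToolCall

tc = ToolCall(id="call_0", tool_name="weather", arguments={"city": "Rome"})

# An assistant message carrying a tool call maps to:
# {"role": "assistant",
#  "tool_calls": [{"id": "call_0", "type": "function",
#                  "function": {"name": "weather", "arguments": '{"city": "Rome"}'}}]}
assistant_msg = ChatMessage.from_assistant(tool_calls=[tc])

# A tool result maps to (note the "tool" -> "function" role rename):
# {"role": "function", "content": "Sunny in Rome", "tool_call_id": "call_0"}
tool_msg = ChatMessage.from_tool(tool_result="Sunny in Rome", origin=tc)
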


@component
@@ -54,6 +87,8 @@ def __init__(
n_batch: Optional[int] = 512,
model_kwargs: Optional[Dict[str, Any]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
*,
tools: Optional[List[Tool]] = None,
):
"""
:param model: The path of a quantized model for text generation, for example, "zephyr-7b-beta.Q4_0.gguf".
@@ -68,34 +103,76 @@
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
For more information on the available kwargs, see
[llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
:param tools:
A list of tools for which the model can prepare calls.
"""

model_kwargs = model_kwargs or {}
generation_kwargs = generation_kwargs or {}

if "hf_tokenizer_path" in model_kwargs:
tokenizer = LlamaHFTokenizer.from_pretrained(model_kwargs["hf_tokenizer_path"])
model_kwargs["tokenizer"] = tokenizer

# check if the model_kwargs contain the essential parameters
# otherwise, populate them with values from init parameters
model_kwargs.setdefault("model_path", model)
model_kwargs.setdefault("n_ctx", n_ctx)
model_kwargs.setdefault("n_batch", n_batch)

_check_duplicate_tool_names(tools)

self.model_path = model
self.n_ctx = n_ctx
self.n_batch = n_batch
self.model_kwargs = model_kwargs
self.generation_kwargs = generation_kwargs
self.model = None
self._model = None
self.tools = tools

def warm_up(self):
if self.model is None:
self.model = Llama(**self.model_kwargs)
if "hf_tokenizer_path" in self.model_kwargs and "tokenizer" not in self.model_kwargs:
tokenizer = LlamaHFTokenizer.from_pretrained(self.model_kwargs["hf_tokenizer_path"])
self.model_kwargs["tokenizer"] = tokenizer

if self._model is None:
self._model = Llama(**self.model_kwargs)
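
Note that the optional Hugging Face tokenizer is now resolved lazily in `warm_up()` instead of in `__init__`. A hedged sketch of how that path might be exercised; the repo id below is a placeholder, not taken from this diff:

generator = LlamaCppChatGenerator(
    model="model.gguf",
    model_kwargs={"hf_tokenizer_path": "org/model-with-chat-template"},  # placeholder repo id
)
generator.warm_up()  # LlamaHFTokenizer.from_pretrained(...) runs here, then Llama(**model_kwargs)
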

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
return default_to_dict(
self,
model=self.model_path,
n_ctx=self.n_ctx,
n_batch=self.n_batch,
model_kwargs=self.model_kwargs,
generation_kwargs=self.generation_kwargs,
tools=serialized_tools,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LlamaCppChatGenerator":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_tools_inplace(data["init_parameters"], key="tools")
return default_from_dict(cls, data)
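
A minimal round-trip sketch for the new serialization support, reusing the illustrative weather tool from the sketch above:

generator = LlamaCppChatGenerator(model="model.gguf", tools=[weather_tool])
data = generator.to_dict()                        # tools serialized via Tool.to_dict()
restored = LlamaCppChatGenerator.from_dict(data)  # deserialize_tools_inplace rebuilds Tool objects
assert restored.tools[0].name == weather_tool.name
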

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
*,
tools: Optional[List[Tool]] = None,
):
"""
Run the text generation model on the given list of ChatMessages.
@@ -104,10 +181,13 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
For more information on the available kwargs, see
[llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
:returns: A dictionary with the following keys:
- `replies`: The responses from the model
"""
if self.model is None:
if self._model is None:
error_msg = "The model has not been loaded. Please call warm_up() before running."
raise RuntimeError(error_msg)

@@ -117,28 +197,55 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]

response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs)
tools = tools or self.tools
llamacpp_tools = {}
if tools:
tool_definitions = [{"type": "function", "function": {**t.tool_spec}} for t in tools]
llamacpp_tools = {"tools": tool_definitions}

replies = []
response = self._model.create_chat_completion(
messages=formatted_messages, **updated_generation_kwargs, **llamacpp_tools
)

replies = []
for choice in response["choices"]:
meta = {
"response_id": response["id"],
"model": response["model"],
"created": response["created"],
"index": choice["index"],
"finish_reason": choice["finish_reason"],
"usage": response["usage"],
}

name = None
tool_calls = choice.get("message", {}).get("tool_calls", [])
if tool_calls:
meta["tool_calls"] = tool_calls
name = tool_calls[0]["function"]["name"]

reply = ChatMessage.from_assistant(choice["message"]["content"], meta=meta)
reply._name = name or None
replies.append(reply)
chat_message = self._convert_chat_completion_choice_to_chat_message(choice, response)
replies.append(chat_message)

return {"replies": replies}

@staticmethod
def _convert_chat_completion_choice_to_chat_message(
choice: ChatCompletionResponseChoice, response: CreateChatCompletionResponse
) -> ChatMessage:
llamacpp_message = choice["message"]
text_content = llamacpp_message.get("content", "") or None
tool_calls = []

if llamacpp_tool_calls := llamacpp_message.get("tool_calls", []):
for llamacpp_tc in llamacpp_tool_calls:
arguments_str = llamacpp_tc["function"]["arguments"]
try:
arguments = json.loads(arguments_str)
tool_calls.append(
ToolCall(id=llamacpp_tc["id"], tool_name=llamacpp_tc["function"]["name"], arguments=arguments)
)
except json.JSONDecodeError:
logger.warning(
"Llama.cpp returned a malformed JSON string for tool call arguments. This tool call "
"will be skipped. Tool call ID: %s, Tool name: %s, Arguments: %s",
llamacpp_tc["id"],
llamacpp_tc["function"]["name"],
arguments_str,
)

meta = {
"response_id": response["id"],
"model": response["model"],
"created": response["created"],
"index": choice["index"],
"finish_reason": choice["finish_reason"],
"usage": response["usage"],
}

return ChatMessage.from_assistant(text=text_content, tool_calls=tool_calls, meta=meta)
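
Putting it together, a sketch of how a caller might consume the converted reply; tools passed to `run()` override the ones given at init, and the weather tool and generator are the illustrative ones from the earlier sketch:

result = generator.run(
    messages=[ChatMessage.from_user("What is the weather in Rome?")],
    tools=[weather_tool],  # overrides the init-time tools for this call
)
reply = result["replies"][0]
for tc in reply.tool_calls:               # ToolCall objects rebuilt from the llama.cpp response
    print(tc.id, tc.tool_name, tc.arguments)
print(reply.meta["finish_reason"], reply.meta["usage"])
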
