Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add custom params to VertexAIGeminiGenerator and VertexAIGeminiChatGenerator #1100

Merged
merged 7 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
HarmCategory,
Part,
Tool,
ToolConfig,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
system_instruction: Optional[Union[str, ByteStream, Part]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Expand All @@ -76,8 +79,11 @@ def __init__(
:param tools: List of tools to use when generating content. See the documentation for
[Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool)
the list of supported arguments.
:param tool_config: The tool config to use. See the documentation for [ToolConfig]
(https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig)
:param system_instruction: Default system instruction to use for generating content.
:param streaming_callback: A callback function that is called when a new token is received from
the stream. The callback function accepts StreamingChunk as an argument.
the stream. The callback function accepts StreamingChunk as an argument.

"""

Expand All @@ -87,13 +93,25 @@ def __init__(
self._model_name = model
self._project_id = project_id
self._location = location
self._model = GenerativeModel(self._model_name)

# model parameters
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._tool_config = tool_config
self._system_instruction = system_instruction
self._streaming_callback = streaming_callback

# except streaming_callback, all other model parameters can be passed during initialization
self._model = GenerativeModel(
self._model_name,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
Comment on lines +106 to +112
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially, we were passing tools and other parameters directly in the generate_content method. However, these can be passed during model initialization instead, which simplifies parameter handling. See the reference here. This approach simplifies the API's requirement that system_instruction can only be passed during model initialization.

)

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
return config
Expand All @@ -106,6 +124,17 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A
"stop_sequences": config._raw_generation_config.stop_sequences,
}

def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]:
    """Serializes the ToolConfig object into a dictionary."""
    # ToolConfig exposes no public accessors, so read the underlying gapic proto.
    calling_config = tool_config._gapic_tool_config.function_calling_config
    inner: Dict[str, Any] = {"mode": calling_config.mode}
    # Only include the allow-list when it is non-empty, mirroring the proto's optional field.
    if calling_config.allowed_function_names:
        inner["allowed_function_names"] = calling_config.allowed_function_names
    return {"function_calling_config": inner}

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand All @@ -123,10 +152,14 @@ def to_dict(self) -> Dict[str, Any]:
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
streaming_callback=callback_name,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config)
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
return data
Expand All @@ -141,10 +174,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiChatGenerator":
:returns:
Deserialized component.
"""

def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig:
    """Deserializes the ToolConfig object from a dictionary."""
    fcc = config_dict["function_calling_config"]
    # allowed_function_names is optional in the serialized form; .get() yields None when absent.
    calling_config = ToolConfig.FunctionCallingConfig(
        mode=fcc["mode"],
        allowed_function_names=fcc.get("allowed_function_names"),
    )
    return ToolConfig(function_calling_config=calling_config)

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config)
if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)
Expand Down Expand Up @@ -212,9 +258,6 @@ def run(
new_message = self._message_to_part(messages[-1])
res = session.send_message(
content=new_message,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
stream=streaming_callback is not None,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
HarmCategory,
Part,
Tool,
ToolConfig,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -58,6 +59,8 @@ def __init__(
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
system_instruction: Optional[Union[str, ByteStream, Part]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Expand Down Expand Up @@ -86,6 +89,8 @@ def __init__(
:param tools: List of tools to use when generating content. See the documentation for
[Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool)
the list of supported arguments.
:param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig)
:param system_instruction: Default system instruction to use for generating content.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
"""
Expand All @@ -96,13 +101,25 @@ def __init__(
self._model_name = model
self._project_id = project_id
self._location = location
self._model = GenerativeModel(self._model_name)

# model parameters
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._tool_config = tool_config
self._system_instruction = system_instruction
self._streaming_callback = streaming_callback

# except streaming_callback, all other model parameters can be passed during initialization
self._model = GenerativeModel(
self._model_name,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
)

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
return config
Expand All @@ -115,6 +132,18 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A
"stop_sequences": config._raw_generation_config.stop_sequences,
}

def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]:
    """Serializes the ToolConfig object into a dictionary."""
    # No public getters on ToolConfig, so go through the private gapic proto.
    fn_config = tool_config._gapic_tool_config.function_calling_config
    serialized: Dict[str, Any] = {"function_calling_config": {"mode": fn_config.mode}}
    # The allow-list is emitted only when present/non-empty.
    if fn_config.allowed_function_names:
        serialized["function_calling_config"]["allowed_function_names"] = fn_config.allowed_function_names
    return serialized

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand All @@ -132,10 +161,14 @@ def to_dict(self) -> Dict[str, Any]:
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
streaming_callback=callback_name,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config)
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
return data
Expand All @@ -150,10 +183,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator":
:returns:
Deserialized component.
"""

def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig:
    """Deserializes the ToolConfig object from a dictionary."""
    inner = config_dict["function_calling_config"]
    # Rebuild the nested FunctionCallingConfig; the allow-list key may be missing.
    return ToolConfig(
        function_calling_config=ToolConfig.FunctionCallingConfig(
            mode=inner["mode"],
            allowed_function_names=inner.get("allowed_function_names"),
        )
    )

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config)
if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)
Expand Down Expand Up @@ -188,11 +234,9 @@ def run(
converted_parts = [self._convert_part(p) for p in parts]

contents = [Content(parts=converted_parts, role="user")]

res = self._model.generate_content(
contents=contents,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
stream=streaming_callback is not None,
)
self._model.start_chat()
Expand Down
42 changes: 42 additions & 0 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
HarmCategory,
Part,
Tool,
ToolConfig,
)

from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator
Expand Down Expand Up @@ -60,19 +61,29 @@ def test_init(mock_vertexai_init, _mock_generative_model):
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather_func"],
)
)

gemini = VertexAIGeminiChatGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
tool_config=tool_config,
system_instruction="Please provide brief answers.",
)
mock_vertexai_init.assert_called()
assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
assert gemini._tool_config == tool_config
assert gemini._system_instruction == "Please provide brief answers."


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init")
Expand All @@ -92,6 +103,8 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model):
"safety_settings": None,
"streaming_callback": None,
"tools": None,
"tool_config": None,
"system_instruction": None,
},
}

Expand All @@ -110,12 +123,20 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather_func"],
)
)

gemini = VertexAIGeminiChatGenerator(
project_id="TestID123",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
tool_config=tool_config,
system_instruction="Please provide brief answers.",
)

assert gemini.to_dict() == {
Expand Down Expand Up @@ -155,6 +176,13 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
]
}
],
"tool_config": {
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": ["get_current_weather_func"],
}
},
"system_instruction": "Please provide brief answers.",
},
}

Expand All @@ -180,6 +208,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
assert gemini._project_id == "TestID123"
assert gemini._safety_settings is None
assert gemini._tools is None
assert gemini._tool_config is None
assert gemini._system_instruction is None
assert gemini._generation_config is None


Expand Down Expand Up @@ -222,6 +252,13 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
]
}
],
"tool_config": {
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": ["get_current_weather_func"],
}
},
"system_instruction": "Please provide brief answers.",
"streaming_callback": None,
},
}
Expand All @@ -231,7 +268,12 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
assert gemini._project_id == "TestID123"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
assert isinstance(gemini._tool_config, ToolConfig)
assert isinstance(gemini._generation_config, GenerationConfig)
assert gemini._system_instruction == "Please provide brief answers."
assert (
gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY
)


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
Expand Down
Loading