Skip to content

Commit

Permalink
keep_alive classmethod wrappers instead of custom class + only wrap i…
Browse files Browse the repository at this point in the history
…f keep_alive differs from default
  • Loading branch information
Robinsane committed Mar 28, 2024
1 parent a8fd51d commit 437f921
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 37 deletions.
33 changes: 0 additions & 33 deletions private_gpt/components/llm/custom/ollama.py

This file was deleted.

26 changes: 22 additions & 4 deletions private_gpt/components/llm/llm_component.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from collections.abc import Callable

from injector import inject, singleton
from llama_index.core.llms import LLM, MockLLM
Expand Down Expand Up @@ -108,8 +109,8 @@ def __init__(self, settings: Settings) -> None:
)
case "ollama":
try:
from private_gpt.components.llm.custom.ollama import (
CustomOllama, # type: ignore
from llama_index.llms.ollama import (
Ollama, # type: ignore
)
except ImportError as e:
raise ImportError(
Expand All @@ -127,15 +128,32 @@ def __init__(self, settings: Settings) -> None:
"repeat_penalty": ollama_settings.repeat_penalty, # ollama llama-cpp
}

self.llm = CustomOllama(
self.llm = Ollama(
model=ollama_settings.llm_model,
base_url=ollama_settings.api_base,
temperature=settings.llm.temperature,
context_window=settings.llm.context_window,
additional_kwargs=settings_kwargs,
request_timeout=ollama_settings.request_timeout,
keep_alive=ollama_settings.keep_alive,
)

if (
ollama_settings.keep_alive
!= ollama_settings.model_fields["keep_alive"].default
):
# Modify Ollama methods to use the "keep_alive" field.
def add_keep_alive(func: Callable) -> Callable:
def wrapper(*args, **kwargs) -> Callable:
kwargs["keep_alive"] = ollama_settings.keep_alive
return func(*args, **kwargs)

return wrapper

Ollama.chat = add_keep_alive(Ollama.chat)
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
Ollama.complete = add_keep_alive(Ollama.complete)
Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)

case "azopenai":
try:
from llama_index.llms.azure_openai import ( # type: ignore
Expand Down

0 comments on commit 437f921

Please sign in to comment.