Add generic litellm support
Marti2203 committed Jul 18, 2024
1 parent 2c003db commit bed6e77
Showing 2 changed files with 137 additions and 5 deletions.
13 changes: 11 additions & 2 deletions app/main.py
@@ -186,11 +186,20 @@ def add_task_related_args(parser: ArgumentParser) -> None:
        default=False,
        help="Do not print most messages to stdout.",
    )

    def model_parser(name: str):
        if not isinstance(name, str):
            raise TypeError(f"Invalid model name: {name}")
        if name in common.MODEL_HUB:
            return name
        if name.startswith("litellm-generic-"):
            return name
        raise TypeError(f"Invalid model name: {name}")

    parser.add_argument(
        "--model",
        type=str,
        type=model_parser,
        default="gpt-3.5-turbo-0125",
        choices=list(common.MODEL_HUB.keys()),
        help="The model to use. Currently only OpenAI models are supported.",
    )
    parser.add_argument(
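The new model_parser lets --model accept both registered names and arbitrary "litellm-generic-<model>" strings. A minimal, self-contained sketch of that behavior, with an illustrative registry and model string (ollama/llama3 is a placeholder, not something the commit pins down):

from argparse import ArgumentParser

MODEL_HUB = {"gpt-3.5-turbo-0125": object()}  # stand-in for common.MODEL_HUB

def model_parser(name: str):
    if name in MODEL_HUB:
        return name  # a registered, first-class model
    if name.startswith("litellm-generic-"):
        return name  # everything after the prefix is routed through litellm
    raise TypeError(f"Invalid model name: {name}")

parser = ArgumentParser()
parser.add_argument("--model", type=model_parser, default="gpt-3.5-turbo-0125")

args = parser.parse_args(["--model", "litellm-generic-ollama/llama3"])
print(args.model)  # litellm-generic-ollama/llama3

argparse treats the TypeError raised by the type callable as a parse failure, so an unrecognized name produces a normal usage error rather than a traceback.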
129 changes: 126 additions & 3 deletions app/model/common.py
@@ -1,8 +1,16 @@
import os
import sys
import threading
from abc import ABC, abstractmethod
from typing import Literal

from app.log import log_and_cprint
import litellm
from litellm import cost_per_token
from litellm.utils import Choices, Message, ModelResponse
from openai import BadRequestError
from tenacity import retry, stop_after_attempt, wait_random_exponential

from app.log import log_and_cprint, log_and_print

# Variables for each process. Since models are singleton objects, their references are copied
# to each process, but they all point to the same objects. For safe updating costs per process,
@@ -68,6 +76,104 @@ def get_overall_exec_stats(self):
        }


class LiteLLMGeneric(Model):
    """
    Base class for creating instances of LiteLLM-supported models.
    """

    _instances = {}

    def __new__(cls, model_name: str, cost_per_input: float, cost_per_output: float):
        if model_name not in cls._instances:
            cls._instances[model_name] = super().__new__(cls)
            cls._instances[model_name]._initialized = False
        return cls._instances[model_name]

    def __init__(
        self,
        name: str,
        cost_per_input: float,
        cost_per_output: float,
        parallel_tool_call: bool = False,
    ):
        if self._initialized:
            return
        super().__init__(name, cost_per_input, cost_per_output, parallel_tool_call)
        self._initialized = True

    def setup(self) -> None:
        """
        No setup needed; LiteLLM reads provider credentials from the environment.
        """
        pass

    def check_api_key(self) -> str:
        return ""

    def extract_resp_content(self, chat_message: Message) -> str:
        """
        Given a chat completion message, extract the content from it.
        """
        content = chat_message.content
        if content is None:
            return ""
        else:
            return content

    # retried up to 3 times with randomized exponential backoff (30s to 10min)
    @retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(3))
    def call(
        self,
        messages: list[dict],
        top_p=1,
        tools=None,
        response_format: Literal["text", "json_object"] = "text",
        **kwargs,
    ):
        # FIXME: ignore tools field since we don't use tools now
        try:
            prefill_content = "{"
            if response_format == "json_object":  # prefill
                messages.append({"role": "assistant", "content": prefill_content})

            response = litellm.completion(
                model=self.name,
                messages=messages,
                temperature=MODEL_TEMP,
                # env vars are strings, so coerce the token limit to int
                max_tokens=int(os.getenv("ACR_TOKEN_LIMIT", "1024")),
                # only OpenAI "gpt" models get the native JSON response format
                response_format=(
                    {"type": response_format} if "gpt" in self.name else None
                ),
                top_p=top_p,
                stream=False,
            )
            assert isinstance(response, ModelResponse)
            resp_usage = response.usage
            assert resp_usage is not None
            input_tokens = int(resp_usage.prompt_tokens)
            output_tokens = int(resp_usage.completion_tokens)
            cost = self.calc_cost(input_tokens, output_tokens)

            thread_cost.process_cost += cost
            thread_cost.process_input_tokens += input_tokens
            thread_cost.process_output_tokens += output_tokens

            first_resp_choice = response.choices[0]
            assert isinstance(first_resp_choice, Choices)
            resp_msg: Message = first_resp_choice.message
            content = self.extract_resp_content(resp_msg)
            if response_format == "json_object":
                # prepend the prefilled character
                if not content.startswith(prefill_content):
                    content = prefill_content + content

            return content, cost, input_tokens, output_tokens

        except BadRequestError as e:
            if e.code == "context_length_exceeded":
                log_and_print("Context length exceeded")
            raise e
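For json_object requests, call() "prefills" the assistant turn with an opening brace to steer models without native JSON mode toward emitting JSON, then re-attaches the brace to the reply. A standalone sketch of that trick (the model string is a placeholder, and not every provider accepts a trailing assistant message):

import litellm

messages = [
    {"role": "user", "content": 'Answer with a JSON object like {"ok": true}.'},
    {"role": "assistant", "content": "{"},  # prefill: start the JSON for the model
]
response = litellm.completion(model="claude-3-haiku-20240307", messages=messages)
content = response.choices[0].message.content or ""
if not content.startswith("{"):
    content = "{" + content  # the reply continues the prefill, so restore the brace
print(content)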


MODEL_HUB = {}


@@ -86,10 +192,27 @@ def get_all_model_names():

def set_model(model_name: str):
    global SELECTED_MODEL
    if model_name not in MODEL_HUB:
    if model_name not in MODEL_HUB and not model_name.startswith("litellm-generic-"):
        print(f"Invalid model name: {model_name}")
        sys.exit(1)
    SELECTED_MODEL = MODEL_HUB[model_name]
    if model_name.startswith("litellm-generic-"):
        # price generic models at gpt-3.5-turbo rates as a placeholder
        prompt_tokens = 5
        completion_tokens = 10
        prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = (
            cost_per_token(
                model="gpt-3.5-turbo",
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
        )
        litellm.set_verbose = True
        SELECTED_MODEL = LiteLLMGeneric(
            model_name.removeprefix("litellm-generic-"),
            prompt_tokens_cost_usd_dollar,
            completion_tokens_cost_usd_dollar,
        )
    else:
        SELECTED_MODEL = MODEL_HUB[model_name]
    SELECTED_MODEL.setup()
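set_model strips the litellm-generic- prefix and builds a LiteLLMGeneric priced via cost_per_token, which returns the dollar cost of the given token counts for a known model (gpt-3.5-turbo acts as a stand-in rate card). A sketch of the flow, assuming it runs where SELECTED_MODEL is visible and using an illustrative model name; note that LiteLLMGeneric.__new__ caches one instance per name:

from litellm import cost_per_token

# Dollar cost of 5 prompt and 10 completion tokens at gpt-3.5-turbo rates.
prompt_usd, completion_usd = cost_per_token(
    model="gpt-3.5-turbo", prompt_tokens=5, completion_tokens=10
)

set_model("litellm-generic-ollama/llama3")  # -> LiteLLMGeneric("ollama/llama3", ...)
first = SELECTED_MODEL
set_model("litellm-generic-ollama/llama3")
assert SELECTED_MODEL is first  # per-name singleton: __init__ ran only once

set_model("gpt-3.5-turbo-0125")  # registered names still resolve via MODEL_HUB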


