Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Capture metrics #136

Merged
merged 10 commits into from
Dec 18, 2024
Merged
1,261 changes: 649 additions & 612 deletions pdm.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ dependencies = [
#Unable to import llm adapters : No module named 'mistralai.models.chat_completion'
#Looks like mistralai>0.4.2 is not backward compatible
"mistralai==0.4.2",

"llama-index-llms-anyscale==0.1.4",
"llama-index-llms-anthropic==0.1.16",
"llama-index-llms-azure-openai==0.1.10",
Expand All @@ -58,6 +57,7 @@ dependencies = [
"singleton-decorator~=1.0.0",
"httpx>=0.25.2",
"pdfplumber>=0.11.2",
"redis>=5.2.1",
]
readme = "README.md"
urls = { Homepage = "https://unstract.com", "Release notes" = "https://github.com/Zipstack/unstract-sdk/releases", Source = "https://github.com/Zipstack/unstract-sdk" }
Expand Down Expand Up @@ -120,4 +120,6 @@ path = "src/unstract/sdk/__init__.py"
# Adding the following override to resolve dependency version
# for environs. Otherwise, it stays stuck while resolving pins
[tool.pdm.resolution.overrides]
grpcio = ">=1.62.1"
grpcio = "1.62.3"
grpcio-tools = "1.62.3"
grpcio-health-checking = "1.62.3"
20 changes: 18 additions & 2 deletions src/unstract/sdk/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from unstract.sdk.file_storage import FileStorage, FileStorageProvider
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils import ToolUtils
from unstract.sdk.utils.common_utils import log_elapsed
from unstract.sdk.utils.common_utils import capture_metrics, log_elapsed
from unstract.sdk.vector_db import VectorDB
from unstract.sdk.x2txt import X2Text

Expand All @@ -39,10 +39,19 @@ class Constants:


class Index:
def __init__(self, tool: BaseTool):
def __init__(
self,
tool: BaseTool,
run_id: Optional[str] = None,
capture_metrics: bool = False,
):
# TODO: Inherit from StreamMixin and avoid using BaseTool
self.tool = tool
self._run_id = run_id
self._capture_metrics = capture_metrics
self._metrics = {}

@capture_metrics
def query_index(
self,
embedding_instance_id: str,
Expand Down Expand Up @@ -180,6 +189,7 @@ def extract_text(
return extracted_text

@log_elapsed(operation="CHECK_AND_INDEX(overall)")
@capture_metrics
def index(
self,
tool_id: str,
Expand Down Expand Up @@ -449,6 +459,12 @@ def generate_index_key(
hashed_index_key = ToolUtils.hash_str(json.dumps(index_key, sort_keys=True))
return hashed_index_key

def get_metrics(self):
return self._metrics

def clear_metrics(self):
self._metrics = {}

@deprecated(version="0.45.0", reason="Use generate_index_key() instead")
def generate_file_id(
self,
Expand Down
29 changes: 20 additions & 9 deletions src/unstract/sdk/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unstract.sdk.helper import SdkHelper
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils.callback_manager import CallbackManager
from unstract.sdk.utils.common_utils import capture_metrics

logger = logging.getLogger(__name__)

Expand All @@ -36,6 +37,7 @@ def __init__(
tool: BaseTool,
adapter_instance_id: Optional[str] = None,
usage_kwargs: dict[Any, Any] = {},
capture_metrics: bool = False,
):
"""Creates an instance of this LLM class.

Expand All @@ -50,6 +52,10 @@ def __init__(
self._adapter_instance_id = adapter_instance_id
self._llm_instance: LlamaIndexLLM = None
self._usage_kwargs = usage_kwargs
self._capture_metrics = capture_metrics
gaya3-zipstack marked this conversation as resolved.
Show resolved Hide resolved
self._run_id = usage_kwargs.get("run_id")
self._usage_reason = usage_kwargs.get("llm_usage_reason")
self._metrics = {}
self._initialise()

def _initialise(self):
Expand All @@ -65,14 +71,16 @@ def _initialise(self):
kwargs=self._usage_kwargs,
)

@capture_metrics
def complete(
self,
prompt: str,
extract_json: bool = True,
process_text: Optional[Callable[[str], str]] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Generates a completion response for the given prompt.
"""Generates a completion response for the given prompt and captures
metrics if run_id is provided.

Args:
prompt (str): The input text prompt for generating the completion.
Expand All @@ -85,12 +93,8 @@ def complete(
**kwargs (Any): Additional arguments passed to the completion function.

Returns:
dict[str, Any]: A dictionary containing the result of the completion
and any processed output.

Raises:
LLMError: If an error occurs during the completion process, it will be
raised after being processed by `parse_llm_err`.
dict[str, Any]: A dictionary containing the result of the completion,
any processed output, and the captured metrics (if applicable).
"""
try:
response: CompletionResponse = self._llm_instance.complete(prompt, **kwargs)
Expand All @@ -105,12 +109,19 @@ def complete(
if not isinstance(process_text_output, dict):
process_text_output = {}
except Exception as e:
logger.error(f"Error occured inside function 'process_text': {e}")
logger.error(f"Error occurred inside function 'process_text': {e}")
process_text_output = {}
return {LLM.RESPONSE: response, **process_text_output}
response_data = {LLM.RESPONSE: response, **process_text_output}
return response_data
except Exception as e:
raise parse_llm_err(e, self._llm_instance) from e

def get_metrics(self):
return self._metrics

def get_usage_reason(self):
return self._usage_reason

def stream_complete(
self,
prompt: str,
Expand Down
60 changes: 60 additions & 0 deletions src/unstract/sdk/metrics_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
import time
import uuid
from typing import Any

from redis import StrictRedis


class MetricsMixin:
TIME_TAKEN_KEY = "time_taken(s)"

def __init__(self, run_id):
"""Initialize the MetricsMixin class.

Args:
run_id (str): Unique identifier for the run.
"""
self.run_id = run_id
Deepak-Kesavan marked this conversation as resolved.
Show resolved Hide resolved
self.op_id = str(uuid.uuid4()) # Unique identifier for this instance

# Initialize Redis client
self.redis_client = StrictRedis(
host=os.getenv("REDIS_HOST", "unstract-redis"),
port=int(os.getenv("REDIS_PORT", 6379)),
username=os.getenv("REDIS_USER", "default"),
password=os.getenv("REDIS_PASSWORD", ""),
db=1,
decode_responses=True,
)
self.redis_key = f"metrics:{self.run_id}:{self.op_id}"

# Set the start time immediately upon initialization
self.set_start_time()

def set_start_time(self, ttl=86400):
"""Store the current timestamp in Redis when the instance is
created."""
self.redis_client.set(self.redis_key, time.time(), ex=ttl)

def collect_metrics(self) -> dict[str, Any]:
"""Calculate the time taken since the timestamp was set and delete the
Redis key.

Returns:
dict: The calculated time taken and the associated run_id and op_id.
"""
# Check if the key exists in Redis
if not self.redis_client.exists(self.redis_key):
# If the key is missing, return an empty metrics dictionary
return {self.TIME_TAKEN_KEY: None}

start_time = float(
self.redis_client.get(self.redis_key)
) # Get the stored timestamp
time_taken = round(time.time() - start_time, 3)

# Delete the Redis key after use
self.redis_client.delete(self.redis_key)

return {self.TIME_TAKEN_KEY: time_taken}
49 changes: 49 additions & 0 deletions src/unstract/sdk/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid

from unstract.sdk.constants import LogLevel
from unstract.sdk.metrics_mixin import MetricsMixin

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,3 +55,51 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def capture_metrics(func):
"""Decorator to capture metrics at the start and end of a function."""

@functools.wraps(func)
def wrapper(self, *args, **kwargs):
Deepak-Kesavan marked this conversation as resolved.
Show resolved Hide resolved
# Ensure the required attributes exist; if not,
# execute the function and return its result
if not all(
hasattr(self, attr) for attr in ["_run_id", "_capture_metrics", "_metrics"]
):
return func(self, *args, **kwargs)

# Check if run_id exists and if metrics should be captured
metrics_mixin = None
time_taken_key = MetricsMixin.TIME_TAKEN_KEY
if self._run_id and self._capture_metrics:
metrics_mixin = MetricsMixin(run_id=self._run_id)

try:
result = func(self, *args, **kwargs)
finally:
# If metrics are being captured, collect and assign them at the end
if metrics_mixin:
new_metrics = metrics_mixin.collect_metrics()

# If time_taken(s) exists in both self._metrics and new_metrics, sum it
if (
self._metrics
and time_taken_key in self._metrics
and time_taken_key in new_metrics
):
time_taken_self = self._metrics.get(time_taken_key)
time_taken_new = new_metrics.get(time_taken_key)
chandrasekharan-zipstack marked this conversation as resolved.
Show resolved Hide resolved

# Only sum if both are valid
if time_taken_self and time_taken_new:
self._metrics[time_taken_key] = time_taken_self + time_taken_new
else:
self._metrics[time_taken_key] = None
chandrasekharan-zipstack marked this conversation as resolved.
Show resolved Hide resolved
else:
# If the key isn't in self._metrics, set it to new_metrics
self._metrics = new_metrics

return result

return wrapper