Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add custom Langfuse span handling support #1313

Merged
merged 10 commits into from
Jan 28, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import httpx
from haystack import component, default_from_dict, default_to_dict, logging, tracing
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.base_serialization import deserialize_class_instance, serialize_class_instance

from haystack_integrations.tracing.langfuse import LangfuseTracer
from haystack_integrations.tracing.langfuse import LangfuseTracer, SpanHandler
from langfuse import Langfuse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
public_key: Optional[Secret] = Secret.from_env_var("LANGFUSE_PUBLIC_KEY"), # noqa: B008
secret_key: Optional[Secret] = Secret.from_env_var("LANGFUSE_SECRET_KEY"), # noqa: B008
httpx_client: Optional[httpx.Client] = None,
span_handler: Optional[SpanHandler] = None,
):
"""
Initialize the LangfuseConnector component.
Expand All @@ -117,11 +119,16 @@ def __init__(
:param httpx_client: Optional custom httpx.Client instance to use for Langfuse API calls. Note that when
deserializing a pipeline from YAML, any custom client is discarded and Langfuse will create its own default
client, since HTTPX clients cannot be serialized.
:param span_handler: Optional custom handler for processing spans. If None, uses DefaultSpanHandler.
The span handler controls how spans are created and processed, allowing customization of span types
based on component types and additional processing after spans are yielded. See SpanHandler class for
details on implementing custom handlers.
"""
self.name = name
self.public = public
self.secret_key = secret_key
self.public_key = public_key
self.span_handler = span_handler
self.tracer = LangfuseTracer(
tracer=Langfuse(
secret_key=secret_key.resolve_value() if secret_key else None,
Expand All @@ -130,6 +137,7 @@ def __init__(
),
name=name,
public=public,
span_handler=span_handler,
)
tracing.enable_tracing(self.tracer)

Expand Down Expand Up @@ -158,13 +166,15 @@ def to_dict(self) -> Dict[str, Any]:

:returns: The serialized component as a dictionary.
"""
span_handler = serialize_class_instance(self.span_handler) if self.span_handler else None
return default_to_dict(
self,
name=self.name,
public=self.public,
secret_key=self.secret_key.to_dict() if self.secret_key else None,
public_key=self.public_key.to_dict() if self.public_key else None,
# Note: httpx_client is not serialized as it's not serializable
span_handler=span_handler,
)

@classmethod
Expand All @@ -175,5 +185,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "LangfuseConnector":
:param data: The dictionary representation of this component.
:returns: The deserialized component instance.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["secret_key", "public_key"])
init_params = data["init_parameters"]
deserialize_secrets_inplace(init_params, keys=["secret_key", "public_key"])
init_params["span_handler"] = (
deserialize_class_instance(init_params["span_handler"]) if init_params["span_handler"] else None
)
return default_from_dict(cls, data)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .tracer import LangfuseTracer
from .tracer import DefaultSpanHandler, LangfuseSpan, LangfuseTracer, SpanContext, SpanHandler

__all__ = ["LangfuseTracer"]
__all__ = ["DefaultSpanHandler", "LangfuseSpan", "LangfuseTracer", "SpanContext", "SpanHandler"]
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib
import os
from abc import ABC, abstractmethod
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional

from haystack import logging
from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage
from haystack.tracing import Span, Tracer
from haystack.tracing import tracer as proxy_tracer
Expand Down Expand Up @@ -50,7 +52,7 @@ class LangfuseSpan(Span):
Internal class representing a bridge between the Haystack span tracing API and Langfuse.
"""

def __init__(self, span: "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]") -> None:
def __init__(self, span: "langfuse.client.StatefulClient") -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please explain this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for asking @anakin87 . It is a mistake! StatefulClient is a superclass of all these spans and one can create other spans from StatefulClient superclass API using generation/span methods but we need to keep the previous version because we also call update on these spans which is only declared in subclasses: StatefulSpanClient, StatefulTraceClient, and StatefulGenerationClient. I guess IDE/linters didn't pick these up because of forward references. I'll make corrections. 🙏

"""
Initialize a LangfuseSpan instance.

Expand Down Expand Up @@ -98,7 +100,7 @@ def set_content_tag(self, key: str, value: Any) -> None:

self._data[key] = value

def raw_span(self) -> "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]":
def raw_span(self) -> "langfuse.client.StatefulClient":
"""
Return the underlying span instance.

Expand All @@ -110,21 +112,195 @@ def get_correlation_data_for_logs(self) -> Dict[str, Any]:
return {}


@dataclass(frozen=True)
class SpanContext:
"""
Context for creating spans in Langfuse.

Encapsulates the information needed to create and configure a span in Langfuse tracing.
Used by SpanHandler to determine the span type (trace, generation, or default) and its configuration.

:param name: The name of the span to create. For components, this is typically the component name.
:param operation_name: The operation being traced (e.g. "haystack.pipeline.run"). Used to determine
if a new trace should be created without warning.
:param component_type: The type of component creating the span (e.g. "OpenAIChatGenerator").
Can be used to determine the type of span to create.
:param tags: Additional metadata to attach to the span. Contains component input/output data
and other trace information.
:param parent_span: The parent span if this is a child span. If None, a new trace will be created.
:param trace_name: The name to use for the trace when creating a parent span. Defaults to "Haystack".
:param public: Whether traces should be publicly accessible. Defaults to False.
"""

name: str
operation_name: str
component_type: Optional[str]
tags: Dict[str, Any]
parent_span: Optional[Span]
trace_name: str = "Haystack"
public: bool = False

def __post_init__(self) -> None:
"""
Validate the span context attributes.

:raises ValueError: If name, operation_name or trace_name are empty
:raises TypeError: If tags is not a dictionary
"""
if not self.name:
msg = "Span name cannot be empty"
raise ValueError(msg)
if not self.operation_name:
msg = "Operation name cannot be empty"
raise ValueError(msg)
if not self.trace_name:
msg = "Trace name cannot be empty"
raise ValueError(msg)


class SpanHandler(ABC):
"""
Abstract base class for customizing how Langfuse spans are created and processed.

This class defines two key extension points:
1. create_span: Controls what type of span to create (default or generation)
2. handle: Processes the span after component execution (adding metadata, metrics, etc.)

To implement a custom handler:
- Extend this class or DefaultSpanHandler
- Override create_span and handle methods. It is more common to override handle.
- Pass your handler to LangfuseConnector init method
"""

def __init__(self):
self.tracer: Optional[langfuse.Langfuse] = None

def init_tracer(self, tracer: "langfuse.Langfuse") -> None:
"""
Initialize with Langfuse tracer. Called internally by LangfuseTracer.

:param tracer: The Langfuse client instance to use for creating spans
"""
self.tracer = tracer

@abstractmethod
def create_span(self, context: SpanContext) -> LangfuseSpan:
"""
Create a span of appropriate type based on the context.

This method determines what kind of span to create:
- A new trace if there's no parent span
- A generation span for LLM components
- A default span for other components

:param context: The context containing all information needed to create the span
:returns: A new LangfuseSpan instance configured according to the context
"""
pass

@abstractmethod
def handle(self, span: LangfuseSpan, component_type: Optional[str]) -> None:
"""
Process a span after component execution by attaching metadata and metrics.

This method is called after the component yields its span, allowing you to:
- Extract and attach token usage statistics
- Add model information
- Record timing data (e.g., time-to-first-token)
- Set log levels for quality monitoring
- Add custom metrics and observations

:param span: The span that was yielded by the component
:param component_type: The type of component that created the span, used to determine
what metadata to extract and how to process it
"""
pass

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SpanHandler":
return default_from_dict(cls, data)

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self)


class DefaultSpanHandler(SpanHandler):
"""DefaultSpanHandler provides the default Langfuse tracing behavior for Haystack."""

def create_span(self, context: SpanContext) -> LangfuseSpan:
message = "Tracer is not initialized"
if self.tracer is None:
raise RuntimeError(message)
tracing_ctx = tracing_context_var.get({})
if not context.parent_span:
if context.operation_name != _PIPELINE_RUN_KEY:
logger.warning(
"Creating a new trace without a parent span is not recommended for operation '{operation_name}'.",
operation_name=context.operation_name,
)
# Create a new trace when there's no parent span
return LangfuseSpan(
self.tracer.trace(
name=context.trace_name,
public=context.public,
id=tracing_ctx.get("trace_id"),
user_id=tracing_ctx.get("user_id"),
session_id=tracing_ctx.get("session_id"),
tags=tracing_ctx.get("tags"),
version=tracing_ctx.get("version"),
)
)
elif context.component_type in _ALL_SUPPORTED_GENERATORS:
return LangfuseSpan(context.parent_span.raw_span().generation(name=context.name))
else:
return LangfuseSpan(context.parent_span.raw_span().span(name=context.name))

def handle(self, span: LangfuseSpan, component_type: Optional[str]) -> None:
if component_type in _SUPPORTED_GENERATORS:
meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta")
if meta:
m = meta[0]
span._span.update(usage=m.get("usage") or None, model=m.get("model"))
elif component_type in _SUPPORTED_CHAT_GENERATORS:
replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies")
if replies:
meta = replies[0].meta
completion_start_time = meta.get("completion_start_time")
if completion_start_time:
try:
completion_start_time = datetime.fromisoformat(completion_start_time)
except ValueError:
logger.error(f"Failed to parse completion_start_time: {completion_start_time}")
completion_start_time = None
span._span.update(
usage=meta.get("usage") or None,
model=meta.get("model"),
completion_start_time=completion_start_time,
)


class LangfuseTracer(Tracer):
"""
Internal class representing a bridge between the Haystack tracer and Langfuse.
"""

def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: bool = False) -> None:
def __init__(
self,
tracer: "langfuse.Langfuse",
name: str = "Haystack",
public: bool = False,
span_handler: Optional[SpanHandler] = None,
) -> None:
"""
Initialize a LangfuseTracer instance.

:param tracer: The Langfuse tracer instance.
:param name: The name of the pipeline or component. This name will be used to identify the tracing run on the
Langfuse dashboard.
:param public: Whether the tracing data should be public or private. If set to `True`, the tracing data will
be publicly accessible to anyone with the tracing URL. If set to `False`, the tracing data will be private
and only accessible to the Langfuse account owner.
be publicly accessible to anyone with the tracing URL. If set to `False`, the tracing data will be private
and only accessible to the Langfuse account owner.
:param span_handler: Custom handler for processing spans. If None, uses DefaultSpanHandler.
"""
if not proxy_tracer.is_content_tracing_enabled:
logger.warning(
Expand All @@ -137,68 +313,37 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public:
self._name = name
self._public = public
self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true"
self._span_handler = span_handler or DefaultSpanHandler()
self._span_handler.init_tracer(tracer)

@contextlib.contextmanager
def trace(
self, operation_name: str, tags: Optional[Dict[str, Any]] = None, parent_span: Optional[Span] = None
) -> Iterator[Span]:
tags = tags or {}
span_name = tags.get(_COMPONENT_NAME_KEY, operation_name)

# Create new span depending whether there's a parent span or not
if not parent_span:
if operation_name != _PIPELINE_RUN_KEY:
logger.warning(
"Creating a new trace without a parent span is not recommended for operation '{operation_name}'.",
operation_name=operation_name,
)
# Create a new trace if no parent span is provided
context = tracing_context_var.get({})
span = LangfuseSpan(
self._tracer.trace(
name=self._name,
public=self._public,
id=context.get("trace_id"),
user_id=context.get("user_id"),
session_id=context.get("session_id"),
tags=context.get("tags"),
version=context.get("version"),
)
component_type = tags.get(_COMPONENT_TYPE_KEY)

# Create span using the handler
span = self._span_handler.create_span(
SpanContext(
name=span_name,
operation_name=operation_name,
component_type=component_type,
tags=tags,
parent_span=parent_span,
trace_name=self._name,
public=self._public,
)
elif tags.get(_COMPONENT_TYPE_KEY) in _ALL_SUPPORTED_GENERATORS:
span = LangfuseSpan(parent_span.raw_span().generation(name=span_name))
else:
span = LangfuseSpan(parent_span.raw_span().span(name=span_name))
)

self._context.append(span)
span.set_tags(tags)

yield span

# Update span metadata based on component type
if tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_GENERATORS:
# Haystack returns one meta dict for each message, but the 'usage' value
# is always the same, let's just pick the first item
meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta")
if meta:
m = meta[0]
span._span.update(usage=m.get("usage") or None, model=m.get("model"))
elif tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_CHAT_GENERATORS:
replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies")
if replies:
meta = replies[0].meta
completion_start_time = meta.get("completion_start_time")
if completion_start_time:
try:
completion_start_time = datetime.fromisoformat(completion_start_time)
except ValueError:
logger.error(f"Failed to parse completion_start_time: {completion_start_time}")
completion_start_time = None
span._span.update(
usage=meta.get("usage") or None,
model=meta.get("model"),
completion_start_time=completion_start_time,
)
# Let the span handler process the span
self._span_handler.handle(span, component_type)

raw_span = span.raw_span()

Expand Down
Loading
Loading