Add AWS Bedrock Callback #324

Open · wants to merge 6 commits into main
232 changes: 232 additions & 0 deletions libs/aws/langchain_aws/callbacks/bedrock_callback.py
@@ -0,0 +1,232 @@
import threading
from typing import Any, Dict, List, Union

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

# https://aws.amazon.com/bedrock/pricing/
MODEL_COST_PER_1K_INPUT_TOKENS = {
    "anthropic.claude-instant-v1": 0.0008,
    "anthropic.claude-v2": 0.008,
    "anthropic.claude-v2:1": 0.008,
    "anthropic.claude-3-sonnet-20240229-v1:0": 0.003,
    "anthropic.claude-3-5-sonnet-20240620-v1:0": 0.003,
    "anthropic.claude-3-5-sonnet-20241022-v2:0": 0.003,
    "anthropic.claude-3-haiku-20240307-v1:0": 0.00025,
    "anthropic.claude-3-5-haiku-20241022-v1:0": 0.0008,
    "anthropic.claude-3-opus-20240229-v1:0": 0.015,
    "meta.llama3-70b-instruct-v1:0": 0.00265,
    "meta.llama3-8b-instruct-v1:0": 0.0003,
    "meta.llama2-13b-chat-v1": 0.00075,
    "meta.llama2-70b-chat-v1": 0.00195,
    "meta.llama3-1-70b-instruct-v1:0": 0.00072,
    "meta.llama3-1-8b-instruct-v1:0": 0.00022,
    "meta.llama3-2-11b-instruct-v1:0": 0.00016,
    "meta.llama3-2-1b-instruct-v1:0": 0.0001,
    "meta.llama3-2-3b-instruct-v1:0": 0.00015,
    "meta.llama3-2-90b-instruct-v1:0": 0.00072,
    "meta.llama3-3-70b-instruct-v1:0": 0.00072,
    "amazon.nova-pro-v1:0": 0.0008,
    "amazon.nova-lite-v1:0": 0.00006,
    "amazon.nova-micro-v1:0": 0.000035,
}

MODEL_COST_PER_1K_OUTPUT_TOKENS = {
    "anthropic.claude-instant-v1": 0.0024,
    "anthropic.claude-v2": 0.024,
    "anthropic.claude-v2:1": 0.024,
    "anthropic.claude-3-sonnet-20240229-v1:0": 0.015,
    "anthropic.claude-3-5-sonnet-20240620-v1:0": 0.015,
    "anthropic.claude-3-5-sonnet-20241022-v2:0": 0.015,
    "anthropic.claude-3-haiku-20240307-v1:0": 0.00125,
    "anthropic.claude-3-5-haiku-20241022-v1:0": 0.004,
    "anthropic.claude-3-opus-20240229-v1:0": 0.075,
    "meta.llama3-70b-instruct-v1:0": 0.0035,
    "meta.llama3-8b-instruct-v1:0": 0.0006,
    "meta.llama2-13b-chat-v1": 0.00100,
    "meta.llama2-70b-chat-v1": 0.00256,
    "meta.llama3-1-70b-instruct-v1:0": 0.00072,
    "meta.llama3-1-8b-instruct-v1:0": 0.00022,
    "meta.llama3-2-11b-instruct-v1:0": 0.00016,
    "meta.llama3-2-1b-instruct-v1:0": 0.0001,
    "meta.llama3-2-3b-instruct-v1:0": 0.00015,
    "meta.llama3-2-90b-instruct-v1:0": 0.00072,
    "meta.llama3-3-70b-instruct-v1:0": 0.00072,
    "amazon.nova-pro-v1:0": 0.0032,
    "amazon.nova-lite-v1:0": 0.00024,
    "amazon.nova-micro-v1:0": 0.00014,
}

MODEL_COST_PER_1K_INPUT_CACHE_WRITE_TOKENS = {
    "anthropic.claude-instant-v1": 0.0,
    "anthropic.claude-v2": 0.0,
    "anthropic.claude-v2:1": 0.0,
    "anthropic.claude-3-sonnet-20240229-v1:0": 0.0,
    "anthropic.claude-3-5-sonnet-20240620-v1:0": 0.00375,
    "anthropic.claude-3-5-sonnet-20241022-v2:0": 0.00375,
    "anthropic.claude-3-haiku-20240307-v1:0": 0.0,
    "anthropic.claude-3-5-haiku-20241022-v1:0": 0.001,
    "anthropic.claude-3-opus-20240229-v1:0": 0.0,
    "meta.llama3-70b-instruct-v1:0": 0.0,
    "meta.llama3-8b-instruct-v1:0": 0.0,
    "meta.llama2-13b-chat-v1": 0.0,
    "meta.llama2-70b-chat-v1": 0.0,
    "meta.llama3-1-70b-instruct-v1:0": 0.0,
    "meta.llama3-1-8b-instruct-v1:0": 0.0,
    "meta.llama3-2-11b-instruct-v1:0": 0.0,
    "meta.llama3-2-1b-instruct-v1:0": 0.0,
    "meta.llama3-2-3b-instruct-v1:0": 0.0,
    "meta.llama3-2-90b-instruct-v1:0": 0.0,
    "meta.llama3-3-70b-instruct-v1:0": 0.0,
    "amazon.nova-pro-v1:0": 0.0,
    "amazon.nova-lite-v1:0": 0.0,
    "amazon.nova-micro-v1:0": 0.0,
}

MODEL_COST_PER_1K_INPUT_CACHE_READ_TOKENS = {
    "anthropic.claude-instant-v1": 0.0,
    "anthropic.claude-v2": 0.0,
    "anthropic.claude-v2:1": 0.0,
    "anthropic.claude-3-sonnet-20240229-v1:0": 0.0,
    "anthropic.claude-3-5-sonnet-20240620-v1:0": 0.0003,
    "anthropic.claude-3-5-sonnet-20241022-v2:0": 0.0003,
    "anthropic.claude-3-haiku-20240307-v1:0": 0.0,
    "anthropic.claude-3-5-haiku-20241022-v1:0": 0.00008,
    "anthropic.claude-3-opus-20240229-v1:0": 0.0,
    "meta.llama3-70b-instruct-v1:0": 0.0,
    "meta.llama3-8b-instruct-v1:0": 0.0,
    "meta.llama2-13b-chat-v1": 0.0,
    "meta.llama2-70b-chat-v1": 0.0,
    "meta.llama3-1-70b-instruct-v1:0": 0.0,
    "meta.llama3-1-8b-instruct-v1:0": 0.0,
    "meta.llama3-2-11b-instruct-v1:0": 0.0,
    "meta.llama3-2-1b-instruct-v1:0": 0.0,
    "meta.llama3-2-3b-instruct-v1:0": 0.0,
    "meta.llama3-2-90b-instruct-v1:0": 0.0,
    "meta.llama3-3-70b-instruct-v1:0": 0.0,
    "amazon.nova-pro-v1:0": 0.0002,
    "amazon.nova-lite-v1:0": 0.000015,
    "amazon.nova-micro-v1:0": 0.00000875,
}


def _get_token_cost(
    prompt_tokens: int,
    prompt_tokens_cache_write: int,
    prompt_tokens_cache_read: int,
    completion_tokens: int,
    model_id: Union[str, None],
) -> float:
    """Get the cost of tokens for the model."""
    if model_id is not None and model_id not in MODEL_COST_PER_1K_INPUT_TOKENS:
        # The model ID can be a cross-region (system-defined) inference profile ID,
        # which has a prefix indicating the region (e.g., 'us', 'eu') but shares
        # the same token costs as the base model. Taking the last two segments of
        # the model ID recovers the base model ID, so cross-region inference
        # profile IDs map onto their corresponding cost entries.
        base_model_id = ".".join(model_id.split(".")[-2:])
    else:
        base_model_id = model_id
    if base_model_id not in MODEL_COST_PER_1K_INPUT_TOKENS:
        raise ValueError(
            f"Failed to calculate token cost. Unknown model: {model_id}. "
            "Please provide a valid model name. "
            f"Known models are: {', '.join(MODEL_COST_PER_1K_INPUT_TOKENS.keys())}"
        )
    return round(
        ((prompt_tokens - prompt_tokens_cache_read) / 1000)
        * MODEL_COST_PER_1K_INPUT_TOKENS[base_model_id]
        + (prompt_tokens_cache_write / 1000)
        * MODEL_COST_PER_1K_INPUT_CACHE_WRITE_TOKENS[base_model_id]
        + (prompt_tokens_cache_read / 1000)
        * MODEL_COST_PER_1K_INPUT_CACHE_READ_TOKENS[base_model_id]
        + (completion_tokens / 1000) * MODEL_COST_PER_1K_OUTPUT_TOKENS[base_model_id],
        7,
    )


class BedrockTokenUsageCallbackHandler(BaseCallbackHandler):
    """Callback handler that tracks Bedrock token usage and cost."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    prompt_tokens_cache_write: int = 0
    prompt_tokens_cache_read: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()

    def __repr__(self) -> str:
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\t\tCache Write: {self.prompt_tokens_cache_write}\n"
            f"\t\tCache Read: {self.prompt_tokens_cache_read}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost}"
        )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Do nothing when the LLM starts."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing on a new token."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        if response.llm_output is None:
            return None

        if "usage" not in response.llm_output:
            with self._lock:
                self.successful_requests += 1
            return None

        # Compute tokens and cost for this request.
        token_usage = response.llm_output["usage"]
        prompt_tokens = token_usage.get("prompt_tokens", 0)
        prompt_tokens_cache_write = token_usage.get("prompt_tokens_cache_write", 0)
        prompt_tokens_cache_read = token_usage.get("prompt_tokens_cache_read", 0)
        completion_tokens = token_usage.get("completion_tokens", 0)
        total_tokens = token_usage.get("total_tokens", 0)
        model_id = response.llm_output.get("model_id", None)
        total_cost = _get_token_cost(
            prompt_tokens=prompt_tokens,
            prompt_tokens_cache_write=prompt_tokens_cache_write,
            prompt_tokens_cache_read=prompt_tokens_cache_read,
            completion_tokens=completion_tokens,
            model_id=model_id,
        )

        # Update shared state behind the lock.
        with self._lock:
            self.total_cost += total_cost
            self.total_tokens += total_tokens
            self.prompt_tokens += prompt_tokens
            self.prompt_tokens_cache_write += prompt_tokens_cache_write
            self.prompt_tokens_cache_read += prompt_tokens_cache_read
            self.completion_tokens += completion_tokens
            self.successful_requests += 1

    def __copy__(self) -> "BedrockTokenUsageCallbackHandler":
        """Return a copy of the callback handler."""
        return self

    def __deepcopy__(self, memo: Any) -> "BedrockTokenUsageCallbackHandler":
        """Return a deep copy of the callback handler."""
        return self
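
Below is a quick, illustrative sketch (not part of the diff) of how `_get_token_cost` resolves a cross-region inference profile ID and prices cached prompt tokens. The profile ID and token counts are made up, and the import assumes this branch of langchain-aws is installed.

```python
# Illustrative only -- exercises the private helper added above; assumes this
# branch of langchain-aws is installed. The profile ID and token counts are
# invented for the example.
from langchain_aws.callbacks.bedrock_callback import _get_token_cost

# "us.anthropic.claude-3-5-sonnet-20240620-v1:0" is a cross-region inference
# profile ID; the helper keeps only the last two dot-separated segments, i.e.
# "anthropic.claude-3-5-sonnet-20240620-v1:0", before consulting the cost tables.
cost = _get_token_cost(
    prompt_tokens=2_000,             # total prompt tokens
    prompt_tokens_cache_write=0,
    prompt_tokens_cache_read=1_000,  # of which 1,000 were served from cache
    completion_tokens=500,
    model_id="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
)
# (2000 - 1000)/1000 * 0.003  -> 0.0030  uncached input
# + 1000/1000 * 0.0003        -> 0.0003  cache read
# + 500/1000 * 0.015          -> 0.0075  output
print(cost)  # 0.0108
```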
45 changes: 45 additions & 0 deletions libs/aws/langchain_aws/callbacks/manager.py
@@ -0,0 +1,45 @@
from __future__ import annotations

import logging
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
    Generator,
    Optional,
)

from langchain_core.tracers.context import register_configure_hook

from langchain_aws.callbacks.bedrock_callback import (
    BedrockTokenUsageCallbackHandler,
)

logger = logging.getLogger(__name__)


bedrock_callback_var: ContextVar[Optional[BedrockTokenUsageCallbackHandler]] = (
    ContextVar("bedrock_anthropic_callback", default=None)
)

register_configure_hook(bedrock_callback_var, True)


@contextmanager
def get_bedrock_callback() -> (
    Generator[BedrockTokenUsageCallbackHandler, None, None]
):
    """Get the Bedrock callback handler in a context manager,
    which conveniently exposes token and cost information.

    Returns:
        BedrockTokenUsageCallbackHandler:
            The Bedrock callback handler.

    Example:
        >>> with get_bedrock_callback() as cb:
        ...     # Use the Bedrock callback handler
    """
    cb = BedrockTokenUsageCallbackHandler()
    bedrock_callback_var.set(cb)
    yield cb
    bedrock_callback_var.set(None)
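
For context, here is a hedged end-to-end sketch of the intended usage. `ChatBedrock` already exists in this package, and the `get_bedrock_callback` import path follows the file above; the model ID, region, and prompts are placeholders chosen for illustration.

```python
# Sketch under the assumption that this PR is merged/installed; model_id,
# region_name, and the prompts are placeholders.
from langchain_aws import ChatBedrock
from langchain_aws.callbacks.manager import get_bedrock_callback

llm = ChatBedrock(
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
    region_name="us-east-1",
)

with get_bedrock_callback() as cb:
    # register_configure_hook wires the handler into the run configuration,
    # so usage from both calls accumulates on the same instance.
    llm.invoke("What is the capital of France?")
    llm.invoke("And of Germany?")

print(cb)             # multi-line summary from __repr__
print(cb.total_cost)  # accumulated USD cost across both requests
```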
8 changes: 7 additions & 1 deletion libs/aws/langchain_aws/chat_models/bedrock.py
@@ -30,7 +30,7 @@
    HumanMessage,
    SystemMessage,
)
-from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.ai import InputTokenDetails, UsageMetadata
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
@@ -578,11 +578,17 @@ def _generate(
        # usage metadata
        if usage := llm_output.get("usage"):
            input_tokens = usage.get("prompt_tokens", 0)
+            cache_creation = usage.get("prompt_tokens_cache_write", 0)
+            cache_read = usage.get("prompt_tokens_cache_read", 0)
            output_tokens = usage.get("completion_tokens", 0)
            usage_metadata = UsageMetadata(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=usage.get("total_tokens", input_tokens + output_tokens),
+                input_token_details=InputTokenDetails(
+                    cache_creation=cache_creation,
+                    cache_read=cache_read,
+                ),
            )
        else:
            usage_metadata = None
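
A rough sketch of what the enriched `usage_metadata` might look like on a returned message when the model reports prompt-cache statistics; the model ID is a placeholder and every number below is invented.

```python
# Rough sketch; assumes this branch is installed and that the invoked model
# returns prompt-cache statistics. The model ID and all numbers are invented.
from langchain_aws import ChatBedrock

llm = ChatBedrock(model_id="anthropic.claude-3-5-sonnet-20241022-v2:0")
msg = llm.invoke("Summarize the cached system prompt in one line.")
print(msg.usage_metadata)
# {
#     "input_tokens": 1210,
#     "output_tokens": 42,
#     "total_tokens": 1252,
#     "input_token_details": {"cache_creation": 0, "cache_read": 1024},
# }
```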