diff --git a/libs/aws/langchain_aws/callbacks/bedrock_callback.py b/libs/aws/langchain_aws/callbacks/bedrock_callback.py new file mode 100644 index 00000000..cf5e7dd2 --- /dev/null +++ b/libs/aws/langchain_aws/callbacks/bedrock_callback.py @@ -0,0 +1,232 @@ +import threading +from typing import Any, Dict, List, Union + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + +# https://aws.amazon.com/bedrock/pricing/ +MODEL_COST_PER_1K_INPUT_TOKENS = { + "anthropic.claude-instant-v1": 0.0008, + "anthropic.claude-v2": 0.008, + "anthropic.claude-v2:1": 0.008, + "anthropic.claude-3-sonnet-20240229-v1:0": 0.003, + "anthropic.claude-3-5-sonnet-20240620-v1:0": 0.003, + "anthropic.claude-3-5-sonnet-20241022-v2:0": 0.003, + "anthropic.claude-3-haiku-20240307-v1:0": 0.00025, + "anthropic.claude-3-5-haiku-20241022-v1:0": 0.0008, + "anthropic.claude-3-opus-20240229-v1:0": 0.015, + "meta.llama3-70b-instruct-v1:0": 0.00265, + "meta.llama3-8b-instruct-v1:0": 0.0003, + "meta.llama2-13b-chat-v1": 0.00075, + "meta.llama2-70b-chat-v1": 0.00195, + "meta.llama3-1-70b-instruct-v1:0": 0.00072, + "meta.llama3-1-8b-instruct-v1:0": 0.00022, + "meta.llama3-2-11b-instruct-v1:0": 0.00016, + "meta.llama3-2-1b-instruct-v1:0": 0.0001, + "meta.llama3-2-3b-instruct-v1:0": 0.00015, + "meta.llama3-2-90b-instruct-v1:0": 0.00072, + "meta.llama3-3-70b-instruct-v1:0": 0.00072, + "amazon.nova-pro-v1:0": 0.0008, + "amazon.nova-lite-v1:0": 0.00006, + "amazon.nova-micro-v1:0": 0.000035, +} + +MODEL_COST_PER_1K_OUTPUT_TOKENS = { + "anthropic.claude-instant-v1": 0.0024, + "anthropic.claude-v2": 0.024, + "anthropic.claude-v2:1": 0.024, + "anthropic.claude-3-sonnet-20240229-v1:0": 0.015, + "anthropic.claude-3-5-sonnet-20240620-v1:0": 0.015, + "anthropic.claude-3-5-sonnet-20241022-v2:0": 0.015, + "anthropic.claude-3-haiku-20240307-v1:0": 0.00125, + "anthropic.claude-3-5-haiku-20241022-v1:0": 0.004, + "anthropic.claude-3-opus-20240229-v1:0": 0.075, + "meta.llama3-70b-instruct-v1:0": 0.0035, + "meta.llama3-8b-instruct-v1:0": 0.0006, + "meta.llama2-13b-chat-v1": 0.00100, + "meta.llama2-70b-chat-v1": 0.00256, + "meta.llama3-1-70b-instruct-v1:0": 0.00072, + "meta.llama3-1-8b-instruct-v1:0": 0.00022, + "meta.llama3-2-11b-instruct-v1:0": 0.00016, + "meta.llama3-2-1b-instruct-v1:0": 0.0001, + "meta.llama3-2-3b-instruct-v1:0": 0.00015, + "meta.llama3-2-90b-instruct-v1:0": 0.00072, + "meta.llama3-3-70b-instruct-v1:0": 0.00072, + "amazon.nova-pro-v1:0": 0.0032, + "amazon.nova-lite-v1:0": 0.00024, + "amazon.nova-micro-v1:0": 0.00014, +} + +MODEL_COST_PER_1K_INPUT_CACHE_WRITE_TOKENS = { + "anthropic.claude-instant-v1": 0.0, + "anthropic.claude-v2": 0.0, + "anthropic.claude-v2:1": 0.0, + "anthropic.claude-3-sonnet-20240229-v1:0": 0.0, + "anthropic.claude-3-5-sonnet-20240620-v1:0": 0.00375, + "anthropic.claude-3-5-sonnet-20241022-v2:0": 0.00375, + "anthropic.claude-3-haiku-20240307-v1:0": 0.0, + "anthropic.claude-3-5-haiku-20241022-v1:0": 0.001, + "anthropic.claude-3-opus-20240229-v1:0": 0.0, + "meta.llama3-70b-instruct-v1:0": 0.0, + "meta.llama3-8b-instruct-v1:0": 0.0, + "meta.llama2-13b-chat-v1": 0.0, + "meta.llama2-70b-chat-v1": 0.0, + "meta.llama3-1-70b-instruct-v1:0": 0.0, + "meta.llama3-1-8b-instruct-v1:0": 0.0, + "meta.llama3-2-11b-instruct-v1:0": 0.0, + "meta.llama3-2-1b-instruct-v1:0": 0.0, + "meta.llama3-2-3b-instruct-v1:0": 0.0, + "meta.llama3-2-90b-instruct-v1:0": 0.0, + "meta.llama3-3-70b-instruct-v1:0": 0.0, + "amazon.nova-pro-v1:0": 0.0, + "amazon.nova-lite-v1:0": 0.0, + "amazon.nova-micro-v1:0": 0.0, +} + +MODEL_COST_PER_1K_INPUT_CACHE_READ_TOKENS = { + "anthropic.claude-instant-v1": 0.0, + "anthropic.claude-v2": 0.0, + "anthropic.claude-v2:1": 0.0, + "anthropic.claude-3-sonnet-20240229-v1:0": 0.0, + "anthropic.claude-3-5-sonnet-20240620-v1:0": 0.0003, + "anthropic.claude-3-5-sonnet-20241022-v2:0": 0.0003, + "anthropic.claude-3-haiku-20240307-v1:0": 0.0, + "anthropic.claude-3-5-haiku-20241022-v1:0": 0.00008, + "anthropic.claude-3-opus-20240229-v1:0": 0.0, + "meta.llama3-70b-instruct-v1:0": 0.0, + "meta.llama3-8b-instruct-v1:0": 0.0, + "meta.llama2-13b-chat-v1": 0.0, + "meta.llama2-70b-chat-v1": 0.0, + "meta.llama3-1-70b-instruct-v1:0": 0.0, + "meta.llama3-1-8b-instruct-v1:0": 0.0, + "meta.llama3-2-11b-instruct-v1:0": 0.0, + "meta.llama3-2-1b-instruct-v1:0": 0.0, + "meta.llama3-2-3b-instruct-v1:0": 0.0, + "meta.llama3-2-90b-instruct-v1:0": 0.0, + "meta.llama3-3-70b-instruct-v1:0": 0.0, + "amazon.nova-pro-v1:0": 0.0002, + "amazon.nova-lite-v1:0": 0.000015, + "amazon.nova-micro-v1:0": 0.00000875, +} + + +def _get_token_cost( + prompt_tokens: int, + prompt_tokens_cache_write: int, + prompt_tokens_cache_read: int, + completion_tokens: int, + model_id: Union[str, None], +) -> float: + if model_id not in MODEL_COST_PER_1K_INPUT_TOKENS: + # The model ID can be a cross-region (system-defined) inference profile ID, + # which has a prefix indicating the region (e.g., 'us', 'eu') but + # shares the same token costs as the "base model". + # By extracting the "base model ID", by taking the last two segments + # of the model ID, we can map cross-region inference profile IDs to + # their corresponding cost entries. + base_model_id = model_id.split(".")[-2] + "." + model_id.split(".")[-1] + else: + base_model_id = model_id + """Get the cost of tokens for the model.""" + if base_model_id not in MODEL_COST_PER_1K_INPUT_TOKENS: + raise ValueError( + f"Failed to calculate token cost. Unknown model: {model_id}. " + "Please provide a valid model name. " + f"Known models are: {",".join(MODEL_COST_PER_1K_INPUT_TOKENS.keys())}" + ) + return round( + ((prompt_tokens - prompt_tokens_cache_read) / 1000) + * MODEL_COST_PER_1K_INPUT_TOKENS[base_model_id] + + (prompt_tokens_cache_write / 1000) + * MODEL_COST_PER_1K_INPUT_CACHE_WRITE_TOKENS[base_model_id] + + (prompt_tokens_cache_read / 1000) + * MODEL_COST_PER_1K_INPUT_CACHE_READ_TOKENS[base_model_id] + + (completion_tokens / 1000) * MODEL_COST_PER_1K_OUTPUT_TOKENS[base_model_id], + 7, + ) + + +class BedrockTokenUsageCallbackHandler(BaseCallbackHandler): + """Callback Handler that tracks bedrock anthropic info.""" + + total_tokens: int = 0 + prompt_tokens: int = 0 + prompt_tokens_cache_write: int = 0 + prompt_tokens_cache_read: int = 0 + completion_tokens: int = 0 + successful_requests: int = 0 + total_cost: float = 0.0 + + def __init__(self) -> None: + super().__init__() + self._lock = threading.Lock() + + def __repr__(self) -> str: + return ( + f"Tokens Used: {self.total_tokens}\n" + f"\tPrompt Tokens: {self.prompt_tokens}\n" + f"\t\tCache Write: {self.prompt_tokens_cache_write}\n" + f"\t\tCache Read: {self.prompt_tokens_cache_read}\n" + f"\tCompletion Tokens: {self.completion_tokens}\n" + f"Successful Requests: {self.successful_requests}\n" + f"Total Cost (USD): ${self.total_cost}" + ) + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return True + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Print out the prompts.""" + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Print out the token.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Collect token usage.""" + if response.llm_output is None: + return None + + if "usage" not in response.llm_output: + with self._lock: + self.successful_requests += 1 + return None + + # compute tokens and cost for this request + token_usage = response.llm_output["usage"] + prompt_tokens = token_usage.get("prompt_tokens", 0) + prompt_tokens_cache_write = token_usage.get("prompt_tokens_cache_write", 0) + prompt_tokens_cache_read = token_usage.get("prompt_tokens_cache_read", 0) + completion_tokens = token_usage.get("completion_tokens", 0) + total_tokens = token_usage.get("total_tokens", 0) + model_id = response.llm_output.get("model_id", None) + total_cost = _get_token_cost( + prompt_tokens=prompt_tokens, + prompt_tokens_cache_write=prompt_tokens_cache_write, + prompt_tokens_cache_read=prompt_tokens_cache_read, + completion_tokens=completion_tokens, + model_id=model_id, + ) + + # update shared state behind lock + with self._lock: + self.total_cost += total_cost + self.total_tokens += total_tokens + self.prompt_tokens += prompt_tokens + self.prompt_tokens_cache_write += prompt_tokens_cache_write + self.prompt_tokens_cache_read += prompt_tokens_cache_read + self.completion_tokens += completion_tokens + self.successful_requests += 1 + + def __copy__(self) -> "BedrockTokenUsageCallbackHandler": + """Return a copy of the callback handler.""" + return self + + def __deepcopy__(self, memo: Any) -> "BedrockTokenUsageCallbackHandler": + """Return a deep copy of the callback handler.""" + return self diff --git a/libs/aws/langchain_aws/callbacks/manager.py b/libs/aws/langchain_aws/callbacks/manager.py new file mode 100644 index 00000000..962f6d82 --- /dev/null +++ b/libs/aws/langchain_aws/callbacks/manager.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import logging +from contextlib import contextmanager +from contextvars import ContextVar +from typing import ( + Generator, + Optional, +) + +from langchain_core.tracers.context import register_configure_hook + +from langchain_aws.callbacks.bedrock_callback import ( + BedrockTokenUsageCallbackHandler, +) + +logger = logging.getLogger(__name__) + + +bedrock_callback_var: (ContextVar)[ + Optional[BedrockTokenUsageCallbackHandler] +] = ContextVar("bedrock_anthropic_callback", default=None) + +register_configure_hook(bedrock_callback_var, True) + + +@contextmanager +def get_bedrock_callback() -> ( + Generator[BedrockTokenUsageCallbackHandler, None, None] +): + """Get the Bedrock callback handler in a context manager. + which conveniently exposes token and cost information. + + Returns: + BedrockTokenUsageCallbackHandler: + The Bedrock callback handler. + + Example: + >>> with get_bedrock_callback() as cb: + ... # Use the Bedrock callback handler + """ + cb = BedrockTokenUsageCallbackHandler() + bedrock_callback_var.set(cb) + yield cb + bedrock_callback_var.set(None) \ No newline at end of file diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 11cb01fb..af8b6e28 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -30,7 +30,7 @@ HumanMessage, SystemMessage, ) -from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.ai import InputTokenDetails, UsageMetadata from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough @@ -578,11 +578,17 @@ def _generate( # usage metadata if usage := llm_output.get("usage"): input_tokens = usage.get("prompt_tokens", 0) + cache_creation = usage.get("prompt_tokens_cache_write", 0) + cache_read = usage.get("prompt_tokens_cache_read", 0) output_tokens = usage.get("completion_tokens", 0) usage_metadata = UsageMetadata( input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=usage.get("total_tokens", input_tokens + output_tokens), + input_token_details=InputTokenDetails( + cache_creation=cache_creation, + cache_read=cache_read, + ), ) else: usage_metadata = None diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 4748d8b6..f08ea176 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -34,7 +34,7 @@ ToolMessage, merge_message_runs, ) -from langchain_core.messages.ai import AIMessageChunk, UsageMetadata +from langchain_core.messages.ai import AIMessageChunk, InputTokenDetails, UsageMetadata from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call_chunk from langchain_core.output_parsers import JsonOutputKeyToolsParser, PydanticToolsParser @@ -383,9 +383,9 @@ class Joke(BaseModel): additionalModelResponseFieldPaths. """ - supports_tool_choice_values: Optional[ - Sequence[Literal["auto", "any", "tool"]] - ] = None + supports_tool_choice_values: Optional[Sequence[Literal["auto", "any", "tool"]]] = ( + None + ) """Which types of tool_choice values the model supports. Inferred if not specified. Inferred as ('auto', 'any', 'tool') if a 'claude-3' @@ -502,7 +502,29 @@ def _generate( messages=bedrock_messages, system=system, **params ) response_message = _parse_response(response) - return ChatResult(generations=[ChatGeneration(message=response_message)]) + + usage_metadata: UsageMetadata = ( + response_message.usage_metadata or UsageMetadata() + ) + input_token_details: InputTokenDetails = ( + usage_metadata.get("input_token_details") or InputTokenDetails() + ) + llm_output = { + "usage": { + "prompt_tokens": usage_metadata.get("input_tokens", 0), + "prompt_tokens_cache_write": input_token_details.get( + "cache_creation", 0 + ), + "prompt_tokens_cache_read": input_token_details.get("cache_read", 0), + "completion_tokens": usage_metadata.get("output_tokens", 0), + "total_tokens": usage_metadata.get("total_tokens", 0), + }, + "model_id": self.model_id, + } + return ChatResult( + generations=[ChatGeneration(message=response_message)], + llm_output=llm_output, + ) def _stream( self, @@ -742,13 +764,25 @@ def _extract_response_metadata(response: Dict[str, Any]) -> Dict[str, Any]: return response_metadata +def _extract_usage_metadata(response: Dict[str, Any]) -> UsageMetadata: + usage: Dict[str, int] = response.pop("usage") # type: ignore[misc] + return UsageMetadata( + input_tokens=usage.get("inputTokens", 0), + output_tokens=usage.get("outputTokens", 0), + total_tokens=usage.get("totalTokens", 0), + input_token_details=InputTokenDetails( + cache_creation=usage.get("cacheWriteInputTokensCount", 0), + cache_read=usage.get("cacheReadInputTokensCount", 0), + ), + ) + + def _parse_response(response: Dict[str, Any]) -> AIMessage: lc_content = _bedrock_to_lc(response.pop("output")["message"]["content"]) tool_calls = _extract_tool_calls(lc_content) - usage = UsageMetadata(_camel_to_snake_keys(response.pop("usage"))) # type: ignore[misc] return AIMessage( content=_str_if_single_text_block(lc_content), # type: ignore[arg-type] - usage_metadata=usage, + usage_metadata=_extract_usage_metadata(response), response_metadata=_extract_response_metadata(response), tool_calls=tool_calls, ) @@ -803,9 +837,11 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]: # TODO: snake case response metadata? return AIMessageChunk(content=[], response_metadata=event["messageStop"]) elif "metadata" in event: - usage = UsageMetadata(_camel_to_snake_keys(event["metadata"].pop("usage"))) # type: ignore[misc] + usage_metadata = _extract_usage_metadata(event["metadata"]) return AIMessageChunk( - content=[], response_metadata=event["metadata"], usage_metadata=usage + content=[], + response_metadata=event["metadata"], + usage_metadata=usage_metadata, ) elif "Exception" in list(event.keys())[0]: name, info = list(event.items())[0] diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index befb4c51..ff6cd32f 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -349,6 +349,9 @@ def prepare_output(cls, provider: str, response: Any) -> dict: headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) + # TODO what are the headers for cache write and read, if any? + prompt_tokens_cache_write = int(headers.get("x-amzn-bedrock-XXX", 0)) + prompt_tokens_cache_read = int(headers.get("x-amzn-bedrock-XXX", 0)) completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) return { "text": text, @@ -356,6 +359,8 @@ def prepare_output(cls, provider: str, response: Any) -> dict: "body": response_body, "usage": { "prompt_tokens": prompt_tokens, + "prompt_tokens_cache_write": prompt_tokens_cache_write, + "prompt_tokens_cache_read": prompt_tokens_cache_read, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, },