From 948357378f2234e7ddc3843c0427cfa0b9914a21 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:40:50 +0000 Subject: [PATCH] chore(internal): add reflection helper function (#565) --- src/anthropic/_utils/__init__.py | 5 ++++- src/anthropic/_utils/_reflection.py | 34 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/anthropic/_utils/__init__.py b/src/anthropic/_utils/__init__.py index 667e2473..3efe66c8 100644 --- a/src/anthropic/_utils/__init__.py +++ b/src/anthropic/_utils/__init__.py @@ -49,4 +49,7 @@ maybe_transform as maybe_transform, async_maybe_transform as async_maybe_transform, ) -from ._reflection import function_has_argument as function_has_argument +from ._reflection import ( + function_has_argument as function_has_argument, + assert_signatures_in_sync as assert_signatures_in_sync, +) diff --git a/src/anthropic/_utils/_reflection.py b/src/anthropic/_utils/_reflection.py index e134f58e..9a53c7bd 100644 --- a/src/anthropic/_utils/_reflection.py +++ b/src/anthropic/_utils/_reflection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect from typing import Any, Callable @@ -6,3 +8,35 @@ def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool: """Returns whether or not the given function has a specific parameter""" sig = inspect.signature(func) return arg_name in sig.parameters + + +def assert_signatures_in_sync( + source_func: Callable[..., Any], + check_func: Callable[..., Any], + *, + exclude_params: set[str] = set(), +) -> None: + """Ensure that the signature of the second function matches the first.""" + + check_sig = inspect.signature(check_func) + source_sig = inspect.signature(source_func) + + errors: list[str] = [] + + for name, source_param in source_sig.parameters.items(): + if name in exclude_params: + continue + + custom_param = check_sig.parameters.get(name) + if not custom_param: + errors.append(f"the `{name}` param is missing") + continue + + if custom_param.annotation != source_param.annotation: + errors.append( + f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(source_param.annotation)}" + ) + continue + + if errors: + raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors))