chore(internal): share client instances between all tests (#314)
stainless-bot authored Jan 18, 2024
1 parent b13f824 commit ccf731b
Showing 4 changed files with 76 additions and 61 deletions.
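
Summary of the change: the per-class strict_client / loose_client instances previously built in each test module are removed. Tests now consume shared, session-scoped client and async_client fixtures defined in tests/conftest.py, parametrized indirectly on a boolean that selects _strict_response_validation (ids "loose" and "strict").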
42 changes: 18 additions & 24 deletions tests/api_resources/beta/test_messages.py
@@ -9,17 +9,13 @@
 
 from anthropic import Anthropic, AsyncAnthropic
 from tests.utils import assert_matches_type
-from anthropic._client import Anthropic, AsyncAnthropic
 from anthropic.types.beta import Message
 
 base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
-api_key = "my-anthropic-api-key"
 
 
 class TestMessages:
-    strict_client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-    loose_client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=False)
-    parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+    parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
 
     @parametrize
     def test_method_create_overload_1(self, client: Anthropic) -> None:
@@ -171,13 +167,11 @@ def test_streaming_response_create_overload_2(self, client: Anthropic) -> None:
 
 
 class TestAsyncMessages:
-    strict_client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-    loose_client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=False)
-    parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
 
     @parametrize
-    async def test_method_create_overload_1(self, client: AsyncAnthropic) -> None:
-        message = await client.beta.messages.create(
+    async def test_method_create_overload_1(self, async_client: AsyncAnthropic) -> None:
+        message = await async_client.beta.messages.create(
             max_tokens=1024,
             messages=[
                 {
@@ -190,8 +184,8 @@ async def test_method_create_overload_1(self, client: AsyncAnthropic) -> None:
         assert_matches_type(Message, message, path=["response"])
 
     @parametrize
-    async def test_method_create_with_all_params_overload_1(self, client: AsyncAnthropic) -> None:
-        message = await client.beta.messages.create(
+    async def test_method_create_with_all_params_overload_1(self, async_client: AsyncAnthropic) -> None:
+        message = await async_client.beta.messages.create(
             max_tokens=1024,
             messages=[
                 {
@@ -211,8 +205,8 @@ async def test_method_create_with_all_params_overload_1(self, client: AsyncAnthr
         assert_matches_type(Message, message, path=["response"])
 
     @parametrize
-    async def test_raw_response_create_overload_1(self, client: AsyncAnthropic) -> None:
-        response = await client.beta.messages.with_raw_response.create(
+    async def test_raw_response_create_overload_1(self, async_client: AsyncAnthropic) -> None:
+        response = await async_client.beta.messages.with_raw_response.create(
             max_tokens=1024,
             messages=[
                 {
@@ -229,8 +223,8 @@ async def test_raw_response_create_overload_1(self, client: AsyncAnthropic) -> N
         assert_matches_type(Message, message, path=["response"])
 
     @parametrize
-    async def test_streaming_response_create_overload_1(self, client: AsyncAnthropic) -> None:
-        async with client.beta.messages.with_streaming_response.create(
+    async def test_streaming_response_create_overload_1(self, async_client: AsyncAnthropic) -> None:
+        async with async_client.beta.messages.with_streaming_response.create(
             max_tokens=1024,
             messages=[
                 {
@@ -249,8 +243,8 @@ async def test_streaming_response_create_overload_1(self, client: AsyncAnthropic
         assert cast(Any, response.is_closed) is True
 
     @parametrize
-    async def test_method_create_overload_2(self, client: AsyncAnthropic) -> None:
-        message_stream = await client.beta.messages.create(
+    async def test_method_create_overload_2(self, async_client: AsyncAnthropic) -> None:
+        message_stream = await async_client.beta.messages.create(
             max_tokens=1024,
             messages=[
                 {
@@ -264,8 +258,8 @@ async def test_method_create_overload_2(self, client: AsyncAnthropic) -> None:
         await message_stream.response.aclose()
 
     @parametrize
-    async def test_method_create_with_all_params_overload_2(self, client: AsyncAnthropic) -> None:
-        message_stream = await client.beta.messages.create(
+    async def test_method_create_with_all_params_overload_2(self, async_client: AsyncAnthropic) -> None:
+        message_stream = await async_client.beta.messages.create(
             max_tokens=1024,
             messages=[
                 {
@@ -285,8 +279,8 @@ async def test_method_create_with_all_params_overload_2(self, client: AsyncAnthr
         await message_stream.response.aclose()
 
     @parametrize
-    async def test_raw_response_create_overload_2(self, client: AsyncAnthropic) -> None:
-        response = await client.beta.messages.with_raw_response.create(
+    async def test_raw_response_create_overload_2(self, async_client: AsyncAnthropic) -> None:
+        response = await async_client.beta.messages.with_raw_response.create(
             max_tokens=1024,
             messages=[
                 {
@@ -303,8 +297,8 @@ async def test_raw_response_create_overload_2(self, client: AsyncAnthropic) -> N
         await stream.close()
 
     @parametrize
-    async def test_streaming_response_create_overload_2(self, client: AsyncAnthropic) -> None:
-        async with client.beta.messages.with_streaming_response.create(
+    async def test_streaming_response_create_overload_2(self, async_client: AsyncAnthropic) -> None:
+        async with async_client.beta.messages.with_streaming_response.create(
             max_tokens=1024,
             messages=[
                 {
42 changes: 18 additions & 24 deletions tests/api_resources/test_completions.py
@@ -10,16 +10,12 @@
 from anthropic import Anthropic, AsyncAnthropic
 from tests.utils import assert_matches_type
 from anthropic.types import Completion
-from anthropic._client import Anthropic, AsyncAnthropic
 
 base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
-api_key = "my-anthropic-api-key"
 
 
 class TestCompletions:
-    strict_client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-    loose_client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=False)
-    parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+    parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
 
     @parametrize
     def test_method_create_overload_1(self, client: Anthropic) -> None:
@@ -129,22 +125,20 @@ def test_streaming_response_create_overload_2(self, client: Anthropic) -> None:
 
 
 class TestAsyncCompletions:
-    strict_client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-    loose_client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=False)
-    parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
 
     @parametrize
-    async def test_method_create_overload_1(self, client: AsyncAnthropic) -> None:
-        completion = await client.completions.create(
+    async def test_method_create_overload_1(self, async_client: AsyncAnthropic) -> None:
+        completion = await async_client.completions.create(
             max_tokens_to_sample=256,
             model="claude-2.1",
             prompt="\n\nHuman: Hello, world!\n\nAssistant:",
         )
         assert_matches_type(Completion, completion, path=["response"])
 
     @parametrize
-    async def test_method_create_with_all_params_overload_1(self, client: AsyncAnthropic) -> None:
-        completion = await client.completions.create(
+    async def test_method_create_with_all_params_overload_1(self, async_client: AsyncAnthropic) -> None:
+        completion = await async_client.completions.create(
             max_tokens_to_sample=256,
             model="claude-2.1",
             prompt="\n\nHuman: Hello, world!\n\nAssistant:",
@@ -158,8 +152,8 @@ async def test_method_create_with_all_params_overload_1(self, client: AsyncAnthr
         assert_matches_type(Completion, completion, path=["response"])
 
     @parametrize
-    async def test_raw_response_create_overload_1(self, client: AsyncAnthropic) -> None:
-        response = await client.completions.with_raw_response.create(
+    async def test_raw_response_create_overload_1(self, async_client: AsyncAnthropic) -> None:
+        response = await async_client.completions.with_raw_response.create(
             max_tokens_to_sample=256,
             model="claude-2.1",
             prompt="\n\nHuman: Hello, world!\n\nAssistant:",
@@ -171,8 +165,8 @@ async def test_raw_response_create_overload_1(self, client: AsyncAnthropic) -> N
         assert_matches_type(Completion, completion, path=["response"])
 
     @parametrize
-    async def test_streaming_response_create_overload_1(self, client: AsyncAnthropic) -> None:
-        async with client.completions.with_streaming_response.create(
+    async def test_streaming_response_create_overload_1(self, async_client: AsyncAnthropic) -> None:
+        async with async_client.completions.with_streaming_response.create(
             max_tokens_to_sample=256,
             model="claude-2.1",
             prompt="\n\nHuman: Hello, world!\n\nAssistant:",
@@ -186,8 +180,8 @@ async def test_streaming_response_create_overload_1(self, client: AsyncAnthropic
         assert cast(Any, response.is_closed) is True
 
     @parametrize
-    async def test_method_create_overload_2(self, client: AsyncAnthropic) -> None:
-        completion_stream = await client.completions.create(
+    async def test_method_create_overload_2(self, async_client: AsyncAnthropic) -> None:
+        completion_stream = await async_client.completions.create(
             max_tokens_to_sample=256,
             model="claude-2.1",
             prompt="\n\nHuman: Hello, world!\n\nAssistant:",
@@ -196,8 +190,8 @@ async def test_method_create_overload_2(self, client: AsyncAnthropic) -> None:
         await completion_stream.response.aclose()
 
     @parametrize
-    async def test_method_create_with_all_params_overload_2(self, client: AsyncAnthropic) -> None:
-        completion_stream = await client.completions.create(
+    async def test_method_create_with_all_params_overload_2(self, async_client: AsyncAnthropic) -> None:
+        completion_stream = await async_client.completions.create(
             max_tokens_to_sample=256,
             model="claude-2.1",
             prompt="\n\nHuman: Hello, world!\n\nAssistant:",
@@ -211,8 +205,8 @@ async def test_method_create_with_all_params_overload_2(self, client: AsyncAnthr
         await completion_stream.response.aclose()
 
     @parametrize
-    async def test_raw_response_create_overload_2(self, client: AsyncAnthropic) -> None:
-        response = await client.completions.with_raw_response.create(
+    async def test_raw_response_create_overload_2(self, async_client: AsyncAnthropic) -> None:
+        response = await async_client.completions.with_raw_response.create(
             max_tokens_to_sample=256,
             model="claude-2.1",
             prompt="\n\nHuman: Hello, world!\n\nAssistant:",
@@ -224,8 +218,8 @@ async def test_raw_response_create_overload_2(self, client: AsyncAnthropic) -> N
         await stream.close()
 
     @parametrize
-    async def test_streaming_response_create_overload_2(self, client: AsyncAnthropic) -> None:
-        async with client.completions.with_streaming_response.create(
+    async def test_streaming_response_create_overload_2(self, async_client: AsyncAnthropic) -> None:
+        async with async_client.completions.with_streaming_response.create(
             max_tokens_to_sample=256,
             model="claude-2.1",
             prompt="\n\nHuman: Hello, world!\n\nAssistant:",
18 changes: 6 additions & 12 deletions tests/api_resources/test_top_level.py
@@ -7,27 +7,21 @@
 import pytest
 
 from anthropic import Anthropic, AsyncAnthropic
-from anthropic._client import Anthropic, AsyncAnthropic
 
 base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
-api_key = "my-anthropic-api-key"
 
 
 class TestTopLevel:
-    strict_client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-    loose_client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=False)
-    parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+    parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
 
-    def test_count_tokens(self) -> None:
-        tokens = self.strict_client.count_tokens("hello world!")
+    def test_count_tokens(self, client: Anthropic) -> None:
+        tokens = client.count_tokens("hello world!")
         assert tokens == 3
 
 
 class TestAsyncTopLevel:
-    strict_client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-    loose_client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=False)
-    parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])
+    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
 
-    async def test_count_tokens(self) -> None:
-        tokens = await self.strict_client.count_tokens("hello world!")
+    async def test_count_tokens(self, async_client: AsyncAnthropic) -> None:
+        tokens = await async_client.count_tokens("hello world!")
         assert tokens == 3
35 changes: 34 additions & 1 deletion tests/conftest.py
@@ -1,9 +1,17 @@
 from __future__ import annotations
 
+import os
 import asyncio
 import logging
-from typing import Iterator
+from typing import TYPE_CHECKING, Iterator, AsyncIterator
 
 import pytest
 
+from anthropic import Anthropic, AsyncAnthropic
+
+if TYPE_CHECKING:
+    from _pytest.fixtures import FixtureRequest
+
+pytest.register_assert_rewrite("tests.utils")
+
 logging.getLogger("anthropic").setLevel(logging.DEBUG)
@@ -14,3 +22,28 @@ def event_loop() -> Iterator[asyncio.AbstractEventLoop]:
     loop = asyncio.new_event_loop()
     yield loop
     loop.close()
+
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+
+api_key = "my-anthropic-api-key"
+
+
+@pytest.fixture(scope="session")
+def client(request: FixtureRequest) -> Iterator[Anthropic]:
+    strict = getattr(request, "param", True)
+    if not isinstance(strict, bool):
+        raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}")
+
+    with Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client:
+        yield client
+
+
+@pytest.fixture(scope="session")
+async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncAnthropic]:
+    strict = getattr(request, "param", True)
+    if not isinstance(strict, bool):
+        raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}")
+
+    async with AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client:
+        yield client
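
How the indirect parametrization ties into these fixtures, as a minimal sketch rather than part of the diff: with indirect=True, each boolean in the parametrize list is delivered to the session-scoped fixture as request.param, which toggles _strict_response_validation, so every test that asks for the same strictness reuses one shared client instance. The test class, method, and assertion below are illustrative only and assume the client fixture from tests/conftest.py above is available.

# Illustrative sketch of the pattern introduced in this commit; not part of ccf731b.
import pytest

from anthropic import Anthropic

# With indirect=True, False/True go to the `client` fixture as `request.param`
# instead of being passed to the test function directly.
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])


class TestExample:
    @parametrize
    def test_reuses_shared_client(self, client: Anthropic) -> None:
        # `client` is the single instance created once per strictness value for
        # the whole session, not a per-test construction.
        assert isinstance(client, Anthropic)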
