Add batch_size kwarg to the LLM start callback (#13483)
This makes it easier to use the token counts reported directly by the API endpoint when the batch size is 1.
hinthornw authored Nov 22, 2023
1 parent 23566cb commit 163bf16
Showing 6 changed files with 147 additions and 8 deletions.
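
As context for the diffs below, here is a minimal sketch of how a custom callback handler might consume the new batch_size kwarg — for instance, to decide whether provider-reported token counts can be attributed to a single prompt. The handler name is hypothetical, and the sketch assumes the extra start kwargs (including batch_size) are forwarded to the handler's on_llm_start via **kwargs:

from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler


class BatchAwareTokenHandler(BaseCallbackHandler):
    """Hypothetical handler: only trust per-request token counts for single-prompt runs."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # After this change, stream/invoke report batch_size=1, while
        # generate/agenerate report the number of prompts (or message lists) in the batch.
        batch_size = kwargs.get("batch_size")
        if batch_size == 1:
            print("Single-prompt run: API token usage maps 1:1 to this run.")
        else:
            print(f"Batched run (batch_size={batch_size}): token usage needs splitting.")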
4 changes: 4 additions & 0 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -206,6 +206,7 @@ def stream(
invocation_params=params,
options=options,
name=config.get("run_name"),
batch_size=1,
)
try:
generation: Optional[ChatGenerationChunk] = None
@@ -259,6 +260,7 @@ async def astream(
invocation_params=params,
options=options,
name=config.get("run_name"),
batch_size=1,
)
try:
generation: Optional[ChatGenerationChunk] = None
@@ -334,6 +336,7 @@ def generate(
invocation_params=params,
options=options,
name=run_name,
batch_size=len(messages),
)
results = []
for i, m in enumerate(messages):
@@ -396,6 +399,7 @@ async def agenerate(
invocation_params=params,
options=options,
name=run_name,
batch_size=len(messages),
)

results = await asyncio.gather(
6 changes: 6 additions & 0 deletions libs/core/langchain_core/language_models/llms.py
@@ -382,6 +382,7 @@ def stream(
invocation_params=params,
options=options,
name=config.get("run_name"),
batch_size=1,
)
try:
generation: Optional[GenerationChunk] = None
@@ -433,6 +434,7 @@ async def astream(
invocation_params=params,
options=options,
name=config.get("run_name"),
batch_size=1,
)
try:
generation: Optional[GenerationChunk] = None
@@ -645,6 +647,7 @@ def generate(
invocation_params=params,
options=options,
name=run_name,
batch_size=len(prompts),
)[0]
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
@@ -662,6 +665,7 @@ def generate(
invocation_params=params,
options=options,
name=run_name_list[idx],
batch_size=len(missing_prompts),
)[0]
for idx in missing_prompt_idxs
]
@@ -810,6 +814,7 @@ async def agenerate(
invocation_params=params,
options=options,
name=run_name,
batch_size=len(prompts),
)
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
@@ -830,6 +835,7 @@ async def agenerate(
invocation_params=params,
options=options,
name=run_name_list[idx],
batch_size=len(missing_prompts),
)
for idx in missing_prompt_idxs
]
Empty file.
@@ -0,0 +1,71 @@
"""Test base chat model."""
import pytest

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.chat_model import FakeListChatModel


@pytest.fixture
def messages() -> list:
return [
SystemMessage(content="You are a test user."),
HumanMessage(content="Hello, I am a test user."),
]


@pytest.fixture
def messages_2() -> list:
return [
SystemMessage(content="You are a test user."),
HumanMessage(content="Hello, I not a test user."),
]


def test_batch_size(messages: list, messages_2: list) -> None:
# The base endpoint doesn't support native batching,
# so we expect batch_size to always be 1
llm = FakeListChatModel(responses=[str(i) for i in range(100)])
with collect_runs() as cb:
llm.batch([messages, messages_2], {"callbacks": [cb]})
assert len(cb.traced_runs) == 2
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
with collect_runs() as cb:
llm.batch([messages], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert len(cb.traced_runs) == 1

with collect_runs() as cb:
llm.invoke(messages)
assert len(cb.traced_runs) == 1
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

with collect_runs() as cb:
list(llm.stream(messages))
assert len(cb.traced_runs) == 1
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1


async def test_async_batch_size(messages: list, messages_2: list) -> None:
llm = FakeListChatModel(responses=[str(i) for i in range(100)])
# The base endpoint doesn't support native batching,
# so we expect batch_size to always be 1
with collect_runs() as cb:
await llm.abatch([messages, messages_2], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert len(cb.traced_runs) == 2
with collect_runs() as cb:
await llm.abatch([messages], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert len(cb.traced_runs) == 1

with collect_runs() as cb:
await llm.ainvoke(messages)
assert len(cb.traced_runs) == 1
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

with collect_runs() as cb:
async for _ in llm.astream(messages):
pass
assert len(cb.traced_runs) == 1
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
58 changes: 58 additions & 0 deletions libs/core/tests/unit_tests/language_models/llms/test_base.py
@@ -1,3 +1,4 @@
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.llm import FakeListLLM


@@ -17,3 +18,60 @@ async def test_abatch() -> None:

output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
assert output == ["foo"] * 3


def test_batch_size() -> None:
llm = FakeListLLM(responses=["foo"] * 3)
with collect_runs() as cb:
llm.batch(["foo", "bar", "foo"], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
assert len(cb.traced_runs) == 3
llm = FakeListLLM(responses=["foo"])
with collect_runs() as cb:
llm.batch(["foo"], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert len(cb.traced_runs) == 1

llm = FakeListLLM(responses=["foo"])
with collect_runs() as cb:
llm.invoke("foo")
assert len(cb.traced_runs) == 1
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

llm = FakeListLLM(responses=["foo"])
with collect_runs() as cb:
list(llm.stream("foo"))
assert len(cb.traced_runs) == 1
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

llm = FakeListLLM(responses=["foo"] * 1)
with collect_runs() as cb:
llm.predict("foo")
assert len(cb.traced_runs) == 1
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1


async def test_async_batch_size() -> None:
llm = FakeListLLM(responses=["foo"] * 3)
with collect_runs() as cb:
await llm.abatch(["foo", "bar", "foo"], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
assert len(cb.traced_runs) == 3
llm = FakeListLLM(responses=["foo"])
with collect_runs() as cb:
await llm.abatch(["foo"], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert len(cb.traced_runs) == 1

llm = FakeListLLM(responses=["foo"])
with collect_runs() as cb:
await llm.ainvoke("foo")
assert len(cb.traced_runs) == 1
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

llm = FakeListLLM(responses=["foo"])
with collect_runs() as cb:
async for _ in llm.astream("foo"):
pass
assert len(cb.traced_runs) == 1
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
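
Tying the tests back to the motivation above, a hedged usage sketch: when a traced run records batch_size == 1, provider-reported usage can be attributed to that single run. The llm_output["token_usage"] location used here is an assumption (it is provider-specific, e.g. OpenAI-style) and is not part of this commit:

from typing import Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.tracers.context import collect_runs


def invoke_and_report_usage(llm: BaseLanguageModel, prompt: str) -> Optional[dict]:
    """Invoke `llm` once and return provider-reported token usage, if attributable."""
    with collect_runs() as cb:
        llm.invoke(prompt)
    run = cb.traced_runs[0]
    # After this commit, single invoke/stream calls record batch_size=1 in run.extra.
    if (run.extra or {}).get("batch_size") != 1:
        return None
    # Assumed, provider-specific location of token usage in the run outputs.
    return ((run.outputs or {}).get("llm_output") or {}).get("token_usage")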

