Skip to content

Commit

Permalink
Periodically close open request.Sessions to avoid buggy interaction…
Browse files Browse the repository at this point in the history
… with Docker Desktop (openai#478)

* Periodically refresh open `requests.Session`s to mitigate open filehandle issues (openai#179)

As reported, we create a `requests.Session` object on first request to the servers and then reuse it indefinitely. This can leave some open file handles on the OS (not a big deal), but can interact poorly with a bug in Docker Desktop which causes the SDK to entierly break connections to the server.
See openai#140 for more info.

The order of items in the API responses is intentional, and this order is clobbered by the rendering of `OpenAIObject`. This change removes the alphabetic sort of response keys
  • Loading branch information
jhallard authored and megamanics committed Aug 14, 2024
1 parent 653306c commit 778ef67
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 1 deletion.
10 changes: 10 additions & 0 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import platform
import sys
import threading
import time
import warnings
from contextlib import asynccontextmanager
from json import JSONDecodeError
Expand Down Expand Up @@ -32,6 +33,7 @@
from openai.util import ApiType

TIMEOUT_SECS = 600
MAX_SESSION_LIFETIME_SECS = 180
MAX_CONNECTION_RETRIES = 2

# Has one attribute per thread, 'session'.
Expand Down Expand Up @@ -516,6 +518,14 @@ def request_raw(

if not hasattr(_thread_context, "session"):
_thread_context.session = _make_session()
_thread_context.session_create_time = time.time()
elif (
time.time() - getattr(_thread_context, "session_create_time", 0)
>= MAX_SESSION_LIFETIME_SECS
):
_thread_context.session.close()
_thread_context.session = _make_session()
_thread_context.session_create_time = time.time()
try:
result = _thread_context.session.request(
method,
Expand Down
2 changes: 1 addition & 1 deletion openai/openai_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __repr__(self):

def __str__(self):
obj = self.to_dict_recursive()
return json.dumps(obj, sort_keys=True, indent=2)
return json.dumps(obj, indent=2)

def to_dict(self):
return dict(self)
Expand Down
32 changes: 32 additions & 0 deletions openai/tests/test_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,35 @@ def test_requestor_azure_ad_headers() -> None:
assert headers["Test_Header"] == "Unit_Test_Header"
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer test_key"


@pytest.mark.requestor
def test_requestor_cycle_sessions(mocker: MockerFixture) -> None:
# HACK: we need to purge the _thread_context to not interfere
# with other tests
from openai.api_requestor import _thread_context

delattr(_thread_context, "session")

api_requestor = APIRequestor(key="test_key", api_type="azure_ad")

mock_session = mocker.MagicMock()
mocker.patch("openai.api_requestor._make_session", lambda: mock_session)

# We don't call `session.close()` if not enough time has elapsed
api_requestor.request_raw("get", "http://example.com")
mock_session.request.assert_called()
api_requestor.request_raw("get", "http://example.com")
mock_session.close.assert_not_called()

mocker.patch("openai.api_requestor.MAX_SESSION_LIFETIME_SECS", 0)

# Due to 0 lifetime, the original session will be closed before the next call
# and a new session will be created
mock_session_2 = mocker.MagicMock()
mocker.patch("openai.api_requestor._make_session", lambda: mock_session_2)
api_requestor.request_raw("get", "http://example.com")
mock_session.close.assert_called()
mock_session_2.request.assert_called()

delattr(_thread_context, "session")
25 changes: 25 additions & 0 deletions openai/tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from tempfile import NamedTemporaryFile

import pytest
Expand Down Expand Up @@ -28,3 +29,27 @@ def test_openai_api_key_path_with_malformed_key(api_key_file) -> None:
api_key_file.flush()
with pytest.raises(ValueError, match="Malformed API key"):
util.default_api_key()


def test_key_order_openai_object_rendering() -> None:
sample_response = {
"id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
"object": "chat.completion",
"created": 1685855844,
"model": "gpt-3.5-turbo-0301",
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
"choices": [
{
"message": {
"role": "assistant",
"content": "The 2020 World Series was played at Globe Life Field in Arlington, Texas. It was the first time that the World Series was played at a neutral site because of the COVID-19 pandemic.",
},
"finish_reason": "stop",
"index": 0,
}
],
}

oai_object = util.convert_to_openai_object(sample_response)
# The `__str__` method was sorting while dumping to json
assert list(json.loads(str(oai_object)).keys()) == list(sample_response.keys())

0 comments on commit 778ef67

Please sign in to comment.