Skip to content

Commit

Permalink
Ensure BLE device is disconnected after unhandled connect exception (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Nov 23, 2024
1 parent b2505a1 commit 71212d8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
18 changes: 17 additions & 1 deletion aioesphomeapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many
)
timeout_expired = False
connect_ok = False
unhandled_exception = False
try:
await connect_future
connect_ok = True
Expand All @@ -606,11 +607,18 @@ async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many
f"after {timeout}s, disconnect timed out: {disconnect_timed_out}, "
f" after {disconnect_timeout}s"
) from err
except BaseException:
unhandled_exception = True
raise
finally:
if not connect_ok and not timeout_expired:
if unhandled_exception or (not connect_ok and not timeout_expired):
unsub()
if not timeout_expired:
timeout_handle.cancel()
if unhandled_exception:
# Make sure to disconnect if we had an unhandled exception
# as otherwise the connection will be left open.
self._bluetooth_disconnect_no_wait(address)

return unsub

Expand Down Expand Up @@ -717,6 +725,14 @@ def _raise_for_ble_connection_change(
f"({response.error})"
)

def _bluetooth_disconnect_no_wait(self, address: int) -> None:
"""Disconnect from a Bluetooth device without waiting for a response."""
self._get_connection().send_message(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.DISCONNECT
)
)

async def _bluetooth_device_request(
self,
address: int,
Expand Down
7 changes: 7 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2194,6 +2194,13 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None
await connect_task
assert states == []

# Ensure the disconnect request is written
assert len(transport.writelines.mock_calls) == 2
req = BluetoothDeviceRequest(
address=1234, request_type=BluetoothDeviceRequestType.DISCONNECT
).SerializeToString()
assert transport.writelines.mock_calls[-1] == call([b"\x00", b"\x05", b"D", req])

handlers_after = len(list(itertools.chain(*connection._message_handlers.values())))
# Make sure we do not leak message handlers
assert handlers_after == handlers_before
Expand Down

0 comments on commit 71212d8

Please sign in to comment.