Skip to content

Commit

Permalink
mrp: Add retry to heartbeats (#978)
Browse files Browse the repository at this point in the history
Fixes #977
  • Loading branch information
postlund authored Feb 23, 2021
1 parent b0eaaef commit 7184717
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 7 deletions.
20 changes: 16 additions & 4 deletions pyatv/mrp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_LOGGER = logging.getLogger(__name__)

HEARTBEAT_INTERVAL = 30
HEARTBEAT_RETRIES = 1 # One regular attempt + retries

Listener = namedtuple("Listener", "func data")
OutstandingMessage = namedtuple("OutstandingMessage", "semaphore response")
Expand All @@ -23,22 +24,33 @@ async def heartbeat_loop(protocol):
"""Periodically send heartbeat messages to device."""
_LOGGER.debug("Starting heartbeat loop")
count = 0
attempts = 0
message = messages.create(protobuf.ProtocolMessage.SEND_COMMAND_MESSAGE)
while True:
try:
await asyncio.sleep(HEARTBEAT_INTERVAL)
# Re-attempts are made with no initial delay to more quickly
# recover a failed heartbeat (if possible)
if attempts == 0:
await asyncio.sleep(HEARTBEAT_INTERVAL)

_LOGGER.debug("Sending periodic heartbeat %d", count)
await protocol.send_and_receive(message)
_LOGGER.debug("Got heartbeat %d", count)
except asyncio.CancelledError:
break
except Exception:
_LOGGER.exception(f"heartbeat {count} failed")
protocol.connection.close()
break
attempts += 1
if attempts > HEARTBEAT_RETRIES:
_LOGGER.error(f"heartbeat {count} failed after {attempts} tries")
protocol.connection.close()
break
else:
_LOGGER.debug(f"heartbeat {count} failed")
else:
attempts = 0
finally:
count += 1

_LOGGER.debug("Stopping heartbeat loop at %d", count)


Expand Down
23 changes: 21 additions & 2 deletions tests/mrp/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
"""Unittests for pyatv.mrp.protocol."""
from unittest.mock import MagicMock

import pytest

from pyatv.conf import MrpService
from pyatv.const import Protocol
from pyatv.mrp.connection import MrpConnection
from pyatv.mrp.protocol import MrpProtocol
from pyatv.mrp.protocol import (
HEARTBEAT_INTERVAL,
HEARTBEAT_RETRIES,
MrpProtocol,
heartbeat_loop,
)
from pyatv.mrp.srp import SRPAuthHandler

from tests.utils import until, stub_sleep
from tests.utils import until, total_sleep_time
from tests.fake_device import FakeAppleTV


Expand Down Expand Up @@ -37,3 +44,15 @@ async def test_heartbeat_loop(mrp_atv, mrp_protocol):

mrp_state = mrp_atv.get_state(Protocol.MRP)
await until(lambda: mrp_state.heartbeat_count >= 3)


@pytest.mark.asyncio
async def test_heartbeat_fail_closes_connection(stub_sleep):
protocol = MagicMock()
protocol.send_and_receive.side_effect = Exception()

await heartbeat_loop(protocol)
assert protocol.send_and_receive.call_count == 1 + HEARTBEAT_RETRIES
assert total_sleep_time() == HEARTBEAT_INTERVAL

protocol.connection.close.assert_called_once()
12 changes: 11 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
async def _fake_sleep(time: float = None, loop=None):
async def dummy():
_fake_sleep._sleep_time.insert(0, time)
_fake_sleep._total_sleep += time

await asyncio.ensure_future(dummy())

Expand All @@ -28,20 +29,29 @@ def stub_sleep(fn=None) -> float:
if asyncio.sleep == _fake_sleep:
if not hasattr(asyncio.sleep, "_sleep_time"):
asyncio.sleep._sleep_time = [0.0]
asyncio.sleep._total_sleep = 0.0
if len(asyncio.sleep._sleep_time) == 1:
return asyncio.sleep._sleep_time[0]
return asyncio.sleep._sleep_time.pop()

return 0.0


def unstub_sleep():
def unstub_sleep() -> None:
"""Restore original asyncio.sleep method."""
if asyncio.sleep == _fake_sleep:
asyncio.sleep._sleep_time = [0.0]
asyncio.sleep._total_sleep = 0.0
asyncio.sleep = real_sleep


def total_sleep_time() -> float:
"""Return total amount of fake time slept."""
if asyncio.sleep == _fake_sleep:
return _fake_sleep._total_sleep
return 0.0


async def simple_get(url):
"""Perform a GET-request to a specified URL."""
async with ClientSession() as session:
Expand Down

0 comments on commit 7184717

Please sign in to comment.