Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Delete push actions on receipt #13834

Closed
1 change: 1 addition & 0 deletions changelog.d/13834.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Clear out pending push actions when processing read receipts, removing extra checks during push action processing. Contributed by Nick @ Beeper (@fizzadar).
87 changes: 0 additions & 87 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,32 +119,6 @@
]


@attr.s(slots=True, auto_attribs=True)
class _RoomReceipt:
"""
HttpPushAction instances include the information used to generate HTTP
requests to a push gateway.
"""

unthreaded_stream_ordering: int = 0
# threaded_stream_ordering includes the main pseudo-thread.
threaded_stream_ordering: Dict[str, int] = attr.Factory(dict)

def is_unread(self, thread_id: str, stream_ordering: int) -> bool:
"""Returns True if the stream ordering is unread according to the receipt information."""

# Only include push actions with a stream ordering after both the unthreaded
# and threaded receipt. Properly handles a user without any receipts present.
return (
self.unthreaded_stream_ordering < stream_ordering
and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering
)


# A _RoomReceipt with no receipts in it.
MISSING_ROOM_RECEIPT = _RoomReceipt()


@attr.s(slots=True, frozen=True, auto_attribs=True)
class HttpPushAction:
"""
Expand Down Expand Up @@ -843,49 +817,6 @@ def f(txn: LoggingTransaction) -> List[str]:

return await self.db_pool.runInteraction("get_push_action_users_in_range", f)

def _get_receipts_by_room_txn(
self, txn: LoggingTransaction, user_id: str
) -> Dict[str, _RoomReceipt]:
"""
Generate a map of room ID to the latest stream ordering that has been
read by the given user.

Args:
txn:
user_id: The user to fetch receipts for.

Returns:
A map including all rooms the user is in with a receipt. It maps
room IDs to _RoomReceipt instances
"""
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)

sql = f"""
SELECT room_id, thread_id, MAX(stream_ordering)
FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause}
AND user_id = ?
GROUP BY room_id, thread_id
"""

args.extend((user_id,))
txn.execute(sql, args)

result: Dict[str, _RoomReceipt] = {}
for room_id, thread_id, stream_ordering in txn:
room_receipt = result.setdefault(room_id, _RoomReceipt())
if thread_id is None:
room_receipt.unthreaded_stream_ordering = stream_ordering
else:
room_receipt.threaded_stream_ordering[thread_id] = stream_ordering

return result

async def get_unread_push_actions_for_user_in_range_for_http(
self,
user_id: str,
Expand All @@ -909,12 +840,6 @@ async def get_unread_push_actions_for_user_in_range_for_http(
The list will have between 0~limit entries.
"""

receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)

def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, str, bool]]:
Expand Down Expand Up @@ -944,9 +869,6 @@ def get_push_actions_txn(
actions=_deserialize_action(actions, highlight),
)
for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions
if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
thread_id, stream_ordering
)
]

# Now sort it so it's ordered correctly, since currently it will
Expand Down Expand Up @@ -982,12 +904,6 @@ async def get_unread_push_actions_for_user_in_range_for_email(
The list will have between 0~limit entries.
"""

receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)

def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, str, bool, int]]:
Expand Down Expand Up @@ -1020,9 +936,6 @@ def get_push_actions_txn(
received_ts=received_ts,
)
for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions
if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
thread_id, stream_ordering
)
]

# Now sort it so it's ordered correctly, since currently it will
Expand Down
23 changes: 22 additions & 1 deletion synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
cast,
)

from synapse.api.constants import EduTypes
from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
Expand Down Expand Up @@ -707,6 +707,27 @@ def _insert_linearized_receipt_txn(
lock=False,
)

if stream_ordering is not None and receipt_type in (
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
):
args = [room_id, user_id, stream_ordering]
where_clause = ""

if thread_id is not None:
where_clause = "AND thread_id = ?"
args.append(thread_id)

sql = f"""
DELETE FROM event_push_actions
WHERE room_id = ?
AND user_id = ?
AND stream_ordering <= ?
AND highlight = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment explaining why we are special casing highlights here?

{where_clause}
"""
txn.execute(sql, args)

return rx_ts

def _graph_to_linear(
Expand Down
73 changes: 73 additions & 0 deletions tests/storage/test_receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,76 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
)
self.assertEqual(res, event2_1_id)

def test_receipts_clear_push_actions(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a docstring explaining what this test tests please? Same with the other test.

event_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
)
event_2_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
)

def _assert_push_action_count(expected: int):
result = self.get_success(
self.store.db_pool.simple_select_list(
table="event_push_actions",
keyvalues={"1": 1},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also just set keyvalues to None which generates an SQL query with no WHERE clause

retcols=("event_id",),
desc="",
)
)
self.assertEqual(len(result), expected)

# Check we have 2 push actions pending
_assert_push_action_count(2)

# Send a read receipt for the first event
self.get_success(
self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event_1_id], None, {}
)
)

# Check that we now have a single push action pending
_assert_push_action_count(1)

# Send a read receipt for the second event
self.get_success(
self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event_2_id], None, {}
)
)

# Check that we have no push actions pending
_assert_push_action_count(0)

def test_receipts_not_clear_highlight_push_actions(self) -> None:
event_1_id = self.create_and_send_event(
self.room_id1,
UserID.from_string(OTHER_USER_ID),
content="our",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
content="our",
content=UserID.from_string(OUR_USER_ID).localpart,

Bit more boilerplatey, but it makes it easier to understand what "our" is supposed to be here (and that it's not just a magic value we're not pulling out of nowhere).

)

def _assert_push_action_count(expected: int):
result = self.get_success(
self.store.db_pool.simple_select_list(
table="event_push_actions",
keyvalues={"1": 1},
retcols=("*",),
desc="",
)
)
self.assertEqual(len(result), expected)

# Check we have 1 push actions pending
_assert_push_action_count(1)

# Send a read receipt for the first event
self.get_success(
self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event_1_id], None, {}
)
)

# Check that we now have a single push action pending
_assert_push_action_count(1)
6 changes: 5 additions & 1 deletion tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ def create_and_send_event(
user: UserID,
soft_failed: bool = False,
prev_event_ids: Optional[List[str]] = None,
content: Optional[str] = None,
) -> str:
"""
Create and send an event.
Expand All @@ -724,7 +725,10 @@ def create_and_send_event(
"type": EventTypes.Message,
"room_id": room_id,
"sender": user.to_string(),
"content": {"body": secrets.token_hex(), "msgtype": "m.text"},
"content": {
"body": content or secrets.token_hex(),
"msgtype": "m.text",
},
},
prev_event_ids=prev_event_ids,
)
Expand Down