From e9a649ec314122d34d1b408c83800e1f1e263554 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 25 Jul 2022 15:01:06 -0400
Subject: [PATCH 01/10] Add an API for listing threads in a room.

---
 changelog.d/13394.feature                   |  1 +
 synapse/config/experimental.py              |  3 +
 synapse/handlers/relations.py               | 63 +++++++++++++++
 synapse/rest/client/relations.py            | 44 +++++++++++
 synapse/storage/databases/main/events.py    |  7 +-
 synapse/storage/databases/main/relations.py | 87 +++++++++++++++++++++
 tests/rest/client/test_relations.py         | 74 ++++++++++++++++++
 7 files changed, 277 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/13394.feature

diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature
new file mode 100644
index 000000000000..68de079cf317
--- /dev/null
+++ b/changelog.d/13394.feature
@@ -0,0 +1 @@
+Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 1902222d7b67..5613c8e038fe 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -93,3 +93,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
 
         # MSC3848: Introduce errcodes for specific event sending failures
         self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
+
+        # MSC3856: Threads list API
+        self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 8f797e3ae9c3..af83ac1a7bc9 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -480,3 +480,66 @@ async def get_bundled_aggregations(
                 results.setdefault(event_id, BundledAggregations()).replace = edit
 
         return results
+
+    async def get_threads(
+        self,
+        requester: Requester,
+        room_id: str,
+        limit: int = 5,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
+    ) -> JsonDict:
+        """Get the threads in a room, ordered by the topological ordering of
+        their latest reply.
+
+        Args:
+            requester: The user requesting the threads.
+            room_id: The room to fetch threads from.
+            limit: Only fetch the most recent `limit` threads.
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        # TODO Properly handle a user leaving a room.
+        (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+            room_id, user_id, allow_departed_users=True
+        )
+
+        # Note that ignored users are not passed into get_threads
+        # below. Ignored users are handled in filter_events_for_client (and by
+        # not passing them in here we should get a better cache hit rate).
+        thread_roots, next_token = await self._main_store.get_threads(
+            room_id=room_id, limit=limit, from_token=from_token, to_token=to_token
+        )
+
+        events = await self._main_store.get_events_as_list(thread_roots)
+
+        events = await filter_events_for_client(
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
+        )
+
+        now = self._clock.time_msec()
+
+        aggregations = await self.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value: JsonDict = {"chunk": serialized_events}
+
+        if next_token:
+            return_value["next_batch"] = await next_token.to_string(self._main_store)
+
+        if from_token:
+            return_value["prev_batch"] = await from_token.to_string(self._main_store)
+
+        return return_value
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index ce970800136a..faa962e3a8ee 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import re
 from typing import TYPE_CHECKING, Optional, Tuple
 
 from synapse.http.server import HttpServer
@@ -91,5 +92,48 @@ async def on_GET(
 
         return 200, result
 
 
+class ThreadsServlet(RestServlet):
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
+        self._relations_handler = hs.get_relations_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        limit = parse_integer(request, "limit", default=5)
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
+
+        # Parse pagination tokens.
+        from_token = None
+        if from_token_str:
+            from_token = await StreamToken.from_string(self.store, from_token_str)
+        to_token = None
+        if to_token_str:
+            to_token = await StreamToken.from_string(self.store, to_token_str)
+
+        result = await self._relations_handler.get_threads(
+            requester=requester,
+            room_id=room_id,
+            limit=limit,
+            from_token=from_token,
+            to_token=to_token,
+        )
+
+        return 200, result
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RelationPaginationServlet(hs).register(http_server)
+    if hs.config.experimental.msc3856_enabled:
+        ThreadsServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f600f119029..cf18dd179cb6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1594,7 +1594,7 @@ def _update_metadata_tables_txn(
         )
 
         # Remove from relations table.
-        self._handle_redact_relations(txn, event.redacts)
+        self._handle_redact_relations(txn, event.room_id, event.redacts)
 
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
@@ -1909,6 +1909,7 @@ def _handle_event_relations(
                 self.store.get_thread_participated.invalidate,
                 (relation.parent_id, event.sender),
             )
+            txn.call_after(self.store.get_threads.invalidate, (event.room_id,))
 
     def _handle_insertion_event(
         self, txn: LoggingTransaction, event: EventBase
@@ -2033,13 +2034,14 @@ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None
         txn.execute(sql, (batch_id,))
 
     def _handle_redact_relations(
-        self, txn: LoggingTransaction, redacted_event_id: str
+        self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
    ) -> None:
         """Handles receiving a redaction and checking whether the redacted event
         has any relations which must be removed from the database.
 
         Args:
             txn
+            room_id: The room ID of the event that was redacted.
             redacted_event_id: The event that was redacted.
         """
@@ -2068,6 +2070,7 @@ def _handle_redact_relations(
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_thread_participated, (redacted_relates_to,)
             )
+            txn.call_after(self.store.get_threads.invalidate, (room_id,))
             self.store._invalidate_cache_and_stream(
                 txn,
                 self.store.get_mutual_event_relations_for_rel_type,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7bd27790ebfe..57b2f7c188c2 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -814,6 +814,93 @@ def _get_event_relations(
             "get_event_relations", _get_event_relations
         )
 
+    @cached(tree=True)
+    async def get_threads(
+        self,
+        room_id: str,
+        limit: int = 5,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
+    ) -> Tuple[List[str], Optional[StreamToken]]:
+        """Get a list of thread IDs, ordered by topological ordering of their
+        latest reply.
+
+        Args:
+            room_id: The room to fetch threads from.
+            limit: Only fetch the most recent `limit` threads.
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
+
+        Returns:
+            A tuple of:
+                A list of thread root event IDs.
+
+                The next stream token, if one exists.
+        """
+        pagination_clause = generate_pagination_where_clause(
+            direction="b",
+            column_names=("topological_ordering", "stream_ordering"),
+            from_token=from_token.room_key.as_historical_tuple()
+            if from_token
+            else None,
+            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if pagination_clause:
+            pagination_clause = "AND " + pagination_clause
+
+        sql = f"""
+            SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                room_id = ? AND
+                relation_type = '{RelationTypes.THREAD}'
+                {pagination_clause}
+            GROUP BY relates_to_id
+            ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC
+            LIMIT ?
+        """
+
+        def _get_threads_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], Optional[StreamToken]]:
+            txn.execute(sql, [room_id, limit + 1])
+
+            last_topo_id = None
+            last_stream_id = None
+            thread_ids = []
+            for thread_id, topo_id, stream_id in txn:
+                thread_ids.append(thread_id)
+                last_topo_id = topo_id
+                last_stream_id = stream_id
+
+            # If there are more events, generate the next pagination key.
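+            # (limit + 1 rows were requested, so reaching the branch below with
+            # an extra row means at least one more page exists; the key is
+            # built from the last row iterated above.)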
+            next_token = None
+            if len(thread_ids) > limit and last_topo_id and last_stream_id:
+                next_key = RoomStreamToken(last_topo_id, last_stream_id)
+                if from_token:
+                    next_token = from_token.copy_and_replace(
+                        StreamKeyType.ROOM, next_key
+                    )
+                else:
+                    next_token = StreamToken(
+                        room_key=next_key,
+                        presence_key=0,
+                        typing_key=0,
+                        receipt_key=0,
+                        account_data_key=0,
+                        push_rules_key=0,
+                        to_device_key=0,
+                        device_list_key=0,
+                        groups_key=0,
+                    )
+
+            return thread_ids[:limit], next_token
+
+        return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+
 
 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index ad03eee17bc8..0666bec479c7 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1677,3 +1677,77 @@ def test_redact_parent_thread(self) -> None:
             relations[RelationTypes.THREAD]["latest_event"]["event_id"],
             related_event_id,
         )
+
+
+class ThreadsTestCase(BaseRelationsTestCase):
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_threads(self) -> None:
+        """Create threads and ensure they are ordered by their latest reply."""
+        # Create 2 threads.
+        thread_1 = self.parent_id
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+        thread_2 = res["event_id"]
+
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Request the threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2, thread_1])
+
+        # Update the first thread, the ordering should swap.
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1, thread_2])
+
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_pagination(self) -> None:
+        """Create threads and paginate through them."""
+        # Create 2 threads.
+        thread_1 = self.parent_id
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+        thread_2 = res["event_id"]
+
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Request the threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2])
+
+        # Make sure next_batch has something in it that looks like it could be a
+        # valid token.
+        next_batch = channel.json_body.get("next_batch")
+        self.assertIsInstance(next_batch, str, channel.json_body)
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1], channel.json_body)
+
+        self.assertNotIn("next_batch", channel.json_body, channel.json_body)
+
+    # XXX Test ignoring users.

From 8dcdb4efa9a22c21e426718ca95aba86ee10586d Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Tue, 26 Jul 2022 12:07:01 -0400
Subject: [PATCH 02/10] Allow limiting threads by participation.

---
 synapse/handlers/relations.py       | 18 +++++++++++
 synapse/rest/client/relations.py    |  4 +++
 tests/rest/client/test_relations.py | 46 +++++++++++++++++++++++++++++
 3 files changed, 68 insertions(+)

diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index af83ac1a7bc9..8f17ee429062 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -485,6 +485,7 @@ async def get_threads(
         self,
         requester: Requester,
         room_id: str,
+        include: str,
         limit: int = 5,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
@@ -494,6 +495,8 @@ async def get_threads(
         Args:
             requester: The user requesting the threads.
             room_id: The room to fetch threads from.
+            include: One of "all" or "participated" to indicate which threads should
+                be returned.
             limit: Only fetch the most recent `limit` threads.
             from_token: Fetch rows from the given token, or from the start if None.
             to_token: Fetch rows up to the given token, or up to the end if None.
@@ -518,6 +521,21 @@ async def get_threads(
 
         events = await self._main_store.get_events_as_list(thread_roots)
 
+        if include == "participated":
+            # Pre-seed thread participation with whether the requester sent the event.
+            participated = {event.event_id: event.sender == user_id for event in events}
+            # For events the requester did not send, check the database for whether
+            # the requester sent a threaded reply.
+            participated.update(
+                await self._main_store.get_threads_participated(
+                    [eid for eid, p in participated.items() if not p],
+                    user_id,
+                )
+            )
+
+            # Limit the returned threads to those the user has participated in.
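+            # (A user "participates" in a thread if they sent its root event or
+            # any threaded reply within it.)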
+            events = [event for event in events if participated[event.event_id]]
+
         events = await filter_events_for_client(
             self._storage_controllers,
             user_id,
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index faa962e3a8ee..8d1fd4fb9873 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -113,6 +113,9 @@ async def on_GET(
         limit = parse_integer(request, "limit", default=5)
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")
+        include = parse_string(
+            request, "include", default="all", allowed_values=["all", "participated"]
+        )
 
         # Parse pagination tokens.
@@ -125,6 +128,7 @@ async def on_GET(
         result = await self._relations_handler.get_threads(
             requester=requester,
             room_id=room_id,
+            include=include,
             limit=limit,
             from_token=from_token,
             to_token=to_token,
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 0666bec479c7..6b302d90bfee 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1750,4 +1750,46 @@ def test_pagination(self) -> None:
 
         self.assertNotIn("next_batch", channel.json_body, channel.json_body)
 
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_include(self) -> None:
+        """Filtering threads by "all" or "participated" should work."""
+        # Thread 1 has the requesting user as the sender of the root event.
+        thread_1 = self.parent_id
+        self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
+
+        # Thread 2 has the requesting user replying.
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+        thread_2 = res["event_id"]
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # The requesting user does not participate in thread 3.
+        res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
+        thread_3 = res["event_id"]
+        self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.test",
+            access_token=self.user2_token,
+            parent_id=thread_3,
+        )
+
+        # All threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(
+            thread_roots, [thread_3, thread_2, thread_1], channel.json_body
+        )
+
+        # Only participated threads.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
 
     # XXX Test ignoring users.

From d510975b2f136ced7ae11b4a1553be3eac5eae7d Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Tue, 26 Jul 2022 12:50:15 -0400
Subject: [PATCH 03/10] Test ignored users.
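
Thread roots sent by ignored users must not be returned; the filtering happens
in filter_events_for_client rather than in the storage layer. Conceptually the
effect on the threads list is roughly the following (a minimal sketch only;
the Event and drop_ignored names are illustrative, not Synapse's API):

    from dataclasses import dataclass
    from typing import List, Set

    @dataclass
    class Event:
        event_id: str
        sender: str

    def drop_ignored(thread_roots: List[Event], ignored: Set[str]) -> List[Event]:
        # Roots whose sender is ignored never reach the response chunk; a
        # reply from an ignored user does not hide an otherwise-visible thread.
        return [ev for ev in thread_roots if ev.sender not in ignored]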
---
 tests/rest/client/test_relations.py | 33 ++++++++++++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 6b302d90bfee..9176047ed403 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1796,4 +1796,35 @@ def test_include(self) -> None:
         thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
         self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
 
-    # XXX Test ignoring users.
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_ignored_user(self) -> None:
+        """Threads whose root event is from an ignored user should not be returned."""
+        # Thread 1 has a reply from an ignored user.
+        thread_1 = self.parent_id
+        self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
+
+        # Thread 2 is created by an ignored user.
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+        thread_2 = res["event_id"]
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Ignore user2.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {self.user2_id: {}}},
+            )
+        )
+
+        # Only thread 1 is returned.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1], channel.json_body)

From f6267b1abefc22559c6a7a88fa4df6cb03e71481 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 1 Aug 2022 09:04:33 -0400
Subject: [PATCH 04/10] Add an enum.

---
 synapse/handlers/relations.py    | 12 ++++++++++--
 synapse/rest/client/relations.py |  8 ++++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 8f17ee429062..e60a5693c60c 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
 
@@ -30,6 +31,13 @@
 logger = logging.getLogger(__name__)
 
 
+class ThreadsListInclude(str, enum.Enum):
+    """Valid values for the 'include' flag of /threads."""
+
+    all = "all"
+    participated = "participated"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
     # The latest event in the thread.
@@ -485,7 +493,7 @@ async def get_threads(
         self,
         requester: Requester,
         room_id: str,
-        include: str,
+        include: ThreadsListInclude,
         limit: int = 5,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
@@ -521,7 +529,7 @@ async def get_threads(
 
         events = await self._main_store.get_events_as_list(thread_roots)
 
-        if include == "participated":
+        if include == ThreadsListInclude.participated:
             # Pre-seed thread participation with whether the requester sent the event.
             participated = {event.event_id: event.sender == user_id for event in events}
             # For events the requester did not send, check the database for whether
             # the requester sent a threaded reply.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 8d1fd4fb9873..d787aeaae160 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -16,6 +16,7 @@
 import re
 from typing import TYPE_CHECKING, Optional, Tuple
 
+from synapse.handlers.relations import ThreadsListInclude
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
@@ -114,7 +115,10 @@ async def on_GET(
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")
         include = parse_string(
-            request, "include", default="all", allowed_values=["all", "participated"]
+            request,
+            "include",
+            default=ThreadsListInclude.all.value,
+            allowed_values=[v.value for v in ThreadsListInclude],
         )
 
         # Parse pagination tokens.
@@ -128,7 +132,7 @@ async def on_GET(
         result = await self._relations_handler.get_threads(
             requester=requester,
             room_id=room_id,
-            include=include,
+            include=ThreadsListInclude(include),
             limit=limit,
             from_token=from_token,
             to_token=to_token,

From 18571aee6bb2556d5717c4e49f086df5b058ae6a Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Fri, 16 Sep 2022 07:19:16 -0400
Subject: [PATCH 05/10] Fix call to check_user_in_room_or_world_readable.

---
 synapse/handlers/relations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index a917f2c870c8..2587efe4e101 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -519,7 +519,7 @@ async def get_threads(
 
         # TODO Properly handle a user leaving a room.
         (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
-            room_id, user_id, allow_departed_users=True
+            room_id, requester, allow_departed_users=True
         )

From 3061cd9a08e19bc2ae2d77a679cb23c25264d580 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Thu, 6 Oct 2022 08:52:00 -0400
Subject: [PATCH 06/10] Add a threads table and start filling it in.

---
 synapse/storage/databases/main/events.py      | 30 ++++++++++++++++++-
 .../schema/main/delta/73/09threads_table.sql  | 25 ++++++++++++++++
 2 files changed, 54 insertions(+), 1 deletion(-)
 create mode 100644 synapse/storage/schema/main/delta/73/09threads_table.sql

diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8a0f630a2155..678a2b4866b2 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -35,7 +35,7 @@
 from prometheus_client import Counter
 
 import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
@@ -1866,6 +1866,34 @@ def _handle_event_relations(
                 },
             )
 
+            if relation.rel_type == RelationTypes.THREAD:
+                # Upsert into the threads table, but only overwrite the value if the
+                # new event is of a later topological order OR if the topological
+                # ordering is equal, but the stream ordering is later.
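+                # (Strictly, the WHERE clause below requires both conditions to
+                # hold: the new event's topological ordering must be at least
+                # the existing one AND its stream ordering strictly later.)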
+                sql = """
+                INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
+                VALUES (?, ?, ?, ?, ?)
+                ON CONFLICT (room_id, thread_id)
+                DO UPDATE SET
+                    latest_event_id = excluded.latest_event_id,
+                    topological_ordering = excluded.topological_ordering,
+                    stream_ordering = excluded.stream_ordering
+                WHERE
+                    threads.topological_ordering <= excluded.topological_ordering AND
+                    threads.stream_ordering < excluded.stream_ordering
+                """
+
+                txn.execute(
+                    sql,
+                    (
+                        event.room_id,
+                        relation.parent_id,
+                        event.event_id,
+                        event.depth,
+                        event.internal_metadata.stream_ordering,
+                    ),
+                )
+
     def _handle_insertion_event(
         self, txn: LoggingTransaction, event: EventBase
     ) -> None:
diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql
new file mode 100644
index 000000000000..5887c7cfc5f3
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/09threads_table.sql
@@ -0,0 +1,25 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE threads (
+    room_id TEXT NOT NULL,
+    -- The event ID of the root event in the thread.
+    thread_id TEXT NOT NULL,
+    -- The latest event ID and corresponding topo / stream ordering.
+    latest_event_id TEXT NOT NULL,
+    topological_ordering BIGINT NOT NULL,
+    stream_ordering BIGINT NOT NULL,
+    CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id)
+);

From e3357f7a594fb7812f281ea9194504cd4aab5ef2 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Thu, 6 Oct 2022 09:22:47 -0400
Subject: [PATCH 07/10] Backfill old threads.
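
Rooms may already contain threads from before the threads table existed, so a
background update walks event_relations and inserts one row per pre-existing
thread root. For background, here is a standalone sketch of how the table's
upsert (from the previous commit) keeps only the latest reply per thread. It
uses Python's bundled sqlite3 (3.24+ for upsert support) purely for
illustration; Synapse drives this through its own database layer:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        """CREATE TABLE threads (
            room_id TEXT NOT NULL,
            thread_id TEXT NOT NULL,
            latest_event_id TEXT NOT NULL,
            topological_ordering BIGINT NOT NULL,
            stream_ordering BIGINT NOT NULL,
            CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id))"""
    )
    upsert = """
        INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT (room_id, thread_id) DO UPDATE SET
            latest_event_id = excluded.latest_event_id,
            topological_ordering = excluded.topological_ordering,
            stream_ordering = excluded.stream_ordering
        WHERE threads.topological_ordering <= excluded.topological_ordering
          AND threads.stream_ordering < excluded.stream_ordering
    """
    conn.execute(upsert, ("!room:test", "$root", "$reply1", 5, 10))
    conn.execute(upsert, ("!room:test", "$root", "$reply2", 6, 11))  # newer: overwrites
    conn.execute(upsert, ("!room:test", "$root", "$stale", 4, 9))    # older: ignored
    print(conn.execute("SELECT latest_event_id FROM threads").fetchone())  # ('$reply2',)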
---
 synapse/storage/databases/main/relations.py   | 81 ++++++++++++++++++-
 .../schema/main/delta/73/09threads_table.sql  |  3 +
 2 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index e96f832cf8de..890d35eac467 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,6 +14,7 @@
 import logging
 from typing import (
+    TYPE_CHECKING,
     Collection,
     Dict,
     FrozenSet,
@@ -31,12 +32,20 @@
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -56,6 +65,76 @@ class _RelatedEvent:
 
 
 class RelationsWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_update_handler(
+            "threads_backfill", self._backfill_threads
+        )
+
+    async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
+        """Backfill the threads table."""
+
+        def threads_backfill_txn(txn: LoggingTransaction) -> int:
+            last_thread_id = progress.get("last_thread_id", "")
+
+            # Get the latest event in each thread by topo ordering / stream ordering.
+            #
+            # Note that the MAX(event_id) is needed to abide by the rules of group by,
+            # but doesn't actually do anything since there should only be a single event
+            # ID per topo/stream ordering pair.
+            sql = f"""
+            SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                relates_to_id > ? AND
+                relation_type = '{RelationTypes.THREAD}'
+            GROUP BY room_id, relates_to_id
+            ORDER BY relates_to_id
+            LIMIT ?
+            """
+            txn.execute(sql, (last_thread_id, batch_size))
+
+            # No more rows to process.
+            rows = txn.fetchall()
+            if not rows:
+                return 0
+
+            # Insert the rows into the threads table. If a matching thread already exists,
+            # assume it is from a newer event.
+            sql = """
+            INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
+            VALUES %s
+            ON CONFLICT (room_id, thread_id)
+            DO NOTHING
+            """
+            if isinstance(txn.database_engine, PostgresEngine):
+                txn.execute_values(sql % ("?",), rows, fetch=False)
+            else:
+                txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows)
+
+            # Mark the progress.
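+            # (rows is ordered by relates_to_id, so rows[-1][1] is the last
+            # thread root processed in this batch; the next run resumes after it.)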
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
+            )
+
+            return txn.rowcount
+
+        result = await self.db_pool.runInteraction(
+            "threads_backfill", threads_backfill_txn
+        )
+
+        if not result:
+            await self.db_pool.updates._end_background_update("threads_backfill")
+
+        return result
+
     @cached(uncached_args=("event",), tree=True)
     async def get_relations_for_event(
         self,
diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql
index 5887c7cfc5f3..060f623eb9e1 100644
--- a/synapse/storage/schema/main/delta/73/09threads_table.sql
+++ b/synapse/storage/schema/main/delta/73/09threads_table.sql
@@ -23,3 +23,6 @@ CREATE TABLE threads (
     stream_ordering BIGINT NOT NULL,
     CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id)
 );
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7309, 'threads_backfill', '{}');

From b1f54e5a7246475fae099e8ae48836d442a8f4a7 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Thu, 6 Oct 2022 11:17:59 -0400
Subject: [PATCH 08/10] Use the threads table to select a list of threads.

Updates the pagination token based on the MSC.

---
 synapse/handlers/relations.py               |  21 ++--
 synapse/rest/client/relations.py            |   8 +-
 synapse/storage/databases/main/relations.py | 108 ++++++++++----------
 3 files changed, 63 insertions(+), 74 deletions(-)

diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index a7781ef3be07..fc6a3b93e080 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -21,7 +21,7 @@
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.logging.opentracing import trace
-from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
 
@@ -498,8 +498,7 @@ async def get_threads(
         room_id: str,
         include: ThreadsListInclude,
         limit: int = 5,
-        from_token: Optional[StreamToken] = None,
-        to_token: Optional[StreamToken] = None,
+        from_token: Optional[ThreadsNextBatch] = None,
     ) -> JsonDict:
         """Get the threads in a room, ordered by the topological ordering of
         their latest reply.
@@ -510,7 +509,6 @@ async def get_threads(
                 be returned.
             limit: Only fetch the most recent `limit` threads.
             from_token: Fetch rows from the given token, or from the start if None.
-            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
             The pagination chunk.
@@ -526,8 +524,8 @@ async def get_threads(
         # Note that ignored users are not passed into get_threads
         # below. Ignored users are handled in filter_events_for_client (and by
         # not passing them in here we should get a better cache hit rate).
-        thread_roots, next_token = await self._main_store.get_threads(
-            room_id=room_id, limit=limit, from_token=from_token, to_token=to_token
+        thread_roots, next_batch = await self._main_store.get_threads(
+            room_id=room_id, limit=limit, from_token=from_token
         )
 
         events = await self._main_store.get_events_as_list(thread_roots)
@@ -554,21 +552,18 @@ async def get_threads(
             is_peeking=(member_event_id is None),
         )
 
-        now = self._clock.time_msec()
-
         aggregations = await self.get_bundled_aggregations(
             events, requester.user.to_string()
         )
+
+        now = self._clock.time_msec()
         serialized_events = self._event_serializer.serialize_events(
             events, now, bundle_aggregations=aggregations
         )
 
         return_value: JsonDict = {"chunk": serialized_events}
 
-        if next_token:
-            return_value["next_batch"] = await next_token.to_string(self._main_store)
-
-        if from_token:
-            return_value["prev_batch"] = await from_token.to_string(self._main_store)
+        if next_batch:
+            return_value["next_batch"] = str(next_batch)
 
         return return_value
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 93dd8f0bb7db..5fbe73f329aa 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -21,6 +21,7 @@
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
+from synapse.storage.databases.main.relations import ThreadsNextBatch
 from synapse.types import JsonDict, StreamToken
 
 if TYPE_CHECKING:
@@ -125,7 +126,6 @@ async def on_GET(
 
         limit = parse_integer(request, "limit", default=5)
         from_token_str = parse_string(request, "from")
-        to_token_str = parse_string(request, "to")
         include = parse_string(
             request,
             "include",
@@ -136,10 +136,7 @@ async def on_GET(
         # Parse pagination tokens.
         from_token = None
         if from_token_str:
-            from_token = await StreamToken.from_string(self.store, from_token_str)
-        to_token = None
-        if to_token_str:
-            to_token = await StreamToken.from_string(self.store, to_token_str)
+            from_token = ThreadsNextBatch.from_string(from_token_str)
 
         result = await self._relations_handler.get_threads(
             requester=requester,
@@ -147,7 +144,6 @@ async def on_GET(
             include=ThreadsListInclude(include),
             limit=limit,
             from_token=from_token,
-            to_token=to_token,
         )
 
         return 200, result
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 890d35eac467..da89af86102c 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -30,6 +30,7 @@
 import attr
 
 from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -49,6 +50,26 @@
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadsNextBatch:
+    topological_ordering: int
+    stream_ordering: int
+
+    def __str__(self) -> str:
+        return f"{self.topological_ordering}_{self.stream_ordering}"
+
+    @classmethod
+    def from_string(cls, string: str) -> "ThreadsNextBatch":
+        """
+        Creates a ThreadsNextBatch from its textual representation.
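+
+        The expected form is "<topological_ordering>_<stream_ordering>" (e.g.
+        "5_42", matching __str__ above); anything else is rejected with a 400
+        error.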
+        """
+        try:
+            keys = (int(s) for s in string.split("_"))
+            return cls(*keys)
+        except Exception:
+            raise SynapseError(400, "Invalid threads token")
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _RelatedEvent:
     """
@@ -916,83 +937,60 @@ async def get_threads(
         self,
         room_id: str,
         limit: int = 5,
-        from_token: Optional[StreamToken] = None,
-        to_token: Optional[StreamToken] = None,
-    ) -> Tuple[List[str], Optional[StreamToken]]:
+        from_token: Optional[ThreadsNextBatch] = None,
+    ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
         """Get a list of thread IDs, ordered by topological ordering of their
         latest reply.
 
         Args:
             room_id: The room to fetch threads from.
             limit: Only fetch the most recent `limit` threads.
-            from_token: Fetch rows from the given token, or from the start if None.
-            to_token: Fetch rows up to the given token, or up to the end if None.
+            from_token: Fetch rows from a previous next_batch, or from the start if None.
 
         Returns:
             A tuple of:
                 A list of thread root event IDs.
 
-                The next stream token, if one exists.
+                The next_batch, if one exists.
         """
-        pagination_clause = generate_pagination_where_clause(
-            direction="b",
-            column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.room_key.as_historical_tuple()
-            if from_token
-            else None,
-            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
-            engine=self.database_engine,
-        )
-
-        if pagination_clause:
-            pagination_clause = "AND " + pagination_clause
+        # Generate the pagination clause, if necessary.
+        #
+        # Find any threads where the latest reply is equal / before the last
+        # thread's topo ordering and earlier in stream ordering.
+        pagination_clause = ""
+        pagination_args: tuple = ()
+        if from_token:
+            pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?"
+            pagination_args = (
+                from_token.topological_ordering,
+                from_token.stream_ordering,
+            )
 
         sql = f"""
-            SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering)
-            FROM event_relations
-            INNER JOIN events USING (event_id)
+            SELECT thread_id, topological_ordering, stream_ordering
+            FROM threads
             WHERE
-                room_id = ? AND
-                relation_type = '{RelationTypes.THREAD}'
+                room_id = ?
                 {pagination_clause}
-            GROUP BY relates_to_id
-            ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC
+            ORDER BY topological_ordering DESC, stream_ordering DESC
             LIMIT ?
         """
 
         def _get_threads_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[str], Optional[StreamToken]]:
-            txn.execute(sql, [room_id, limit + 1])
+        ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+            txn.execute(sql, (room_id, *pagination_args, limit + 1))
+
+            rows = cast(List[Tuple[str, int, int]], txn.fetchall())
+            thread_ids = [r[0] for r in rows]
+
+            # If there are more events, generate the next pagination key from the
+            # last thread which will be returned.
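+            # (limit + 1 rows were requested, so when more rows exist rows[-1]
+            # belongs to the next page and rows[-2] is the last thread actually
+            # returned; the token is built from that row.)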
next_token = None - if len(thread_ids) > limit and last_topo_id and last_stream_id: - next_key = RoomStreamToken(last_topo_id, last_stream_id) - if from_token: - next_token = from_token.copy_and_replace( - StreamKeyType.ROOM, next_key - ) - else: - next_token = StreamToken( - room_key=next_key, - presence_key=0, - typing_key=0, - receipt_key=0, - account_data_key=0, - push_rules_key=0, - to_device_key=0, - device_list_key=0, - groups_key=0, - ) + if len(thread_ids) > limit: + last_topo_id = rows[-2][1] + last_stream_id = rows[-2][2] + next_token = ThreadsNextBatch(last_topo_id, last_stream_id) return thread_ids[:limit], next_token From 66da339417d194f816591c759be03a2d8ce521d0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 6 Oct 2022 14:42:12 -0400 Subject: [PATCH 09/10] Fix portdb. --- synapse/_scripts/synapse_port_db.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 5fa599e70e90..d850e54e1751 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -72,6 +72,7 @@ RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, ) +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomBackgroundUpdateStore from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.databases.main.search import SearchBackgroundUpdateStore @@ -206,6 +207,7 @@ class Store( PusherWorkerStore, PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, + RelationsWorkerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) From d5866873f4f9bf03274e25185a5aee9fc1b29fed Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 14:22:35 -0400 Subject: [PATCH 10/10] Add missing index. --- synapse/storage/schema/main/delta/73/09threads_table.sql | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql index 060f623eb9e1..aa7c5e9a2eb3 100644 --- a/synapse/storage/schema/main/delta/73/09threads_table.sql +++ b/synapse/storage/schema/main/delta/73/09threads_table.sql @@ -24,5 +24,7 @@ CREATE TABLE threads ( CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id) ); +CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering); + INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (7309, 'threads_backfill', '{}');