This repository has been archived by the owner on Apr 26, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace simple_async_mock with AsyncMock (#16180)
Python 3.8 has a native AsyncMock, use it instead of a custom implementation.
- Loading branch information
Showing
15 changed files
with
140 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Use `AsyncMock` instead of custom code. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from unittest.mock import Mock | ||
from unittest.mock import AsyncMock, Mock | ||
|
||
import pymacaroons | ||
|
||
|
@@ -35,7 +35,6 @@ | |
from synapse.util import Clock | ||
|
||
from tests import unittest | ||
from tests.test_utils import simple_async_mock | ||
from tests.unittest import override_config | ||
from tests.utils import mock_getRawHeaders | ||
|
||
|
@@ -60,16 +59,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |
# this is overridden for the appservice tests | ||
self.store.get_app_service_by_token = Mock(return_value=None) | ||
|
||
self.store.insert_client_ip = simple_async_mock(None) | ||
self.store.is_support_user = simple_async_mock(False) | ||
self.store.insert_client_ip = AsyncMock(return_value=None) | ||
self.store.is_support_user = AsyncMock(return_value=False) | ||
|
||
def test_get_user_by_req_user_valid_token(self) -> None: | ||
user_info = TokenLookupResult( | ||
user_id=self.test_user, token_id=5, device_id="device" | ||
) | ||
self.store.get_user_by_access_token = simple_async_mock(user_info) | ||
self.store.mark_access_token_as_used = simple_async_mock(None) | ||
self.store.get_user_locked_status = simple_async_mock(False) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=user_info) | ||
self.store.mark_access_token_as_used = AsyncMock(return_value=None) | ||
self.store.get_user_locked_status = AsyncMock(return_value=False) | ||
|
||
request = Mock(args={}) | ||
request.args[b"access_token"] = [self.test_token] | ||
|
@@ -78,7 +77,7 @@ def test_get_user_by_req_user_valid_token(self) -> None: | |
self.assertEqual(requester.user.to_string(), self.test_user) | ||
|
||
def test_get_user_by_req_user_bad_token(self) -> None: | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
request = Mock(args={}) | ||
request.args[b"access_token"] = [self.test_token] | ||
|
@@ -91,7 +90,7 @@ def test_get_user_by_req_user_bad_token(self) -> None: | |
|
||
def test_get_user_by_req_user_missing_token(self) -> None: | ||
user_info = TokenLookupResult(user_id=self.test_user, token_id=5) | ||
self.store.get_user_by_access_token = simple_async_mock(user_info) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=user_info) | ||
|
||
request = Mock(args={}) | ||
request.requestHeaders.getRawHeaders = mock_getRawHeaders() | ||
|
@@ -106,7 +105,7 @@ def test_get_user_by_req_appservice_valid_token(self) -> None: | |
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None | ||
) | ||
self.store.get_app_service_by_token = Mock(return_value=app_service) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
request = Mock(args={}) | ||
request.getClientAddress.return_value.host = "127.0.0.1" | ||
|
@@ -125,7 +124,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None: | |
ip_range_whitelist=IPSet(["192.168/16"]), | ||
) | ||
self.store.get_app_service_by_token = Mock(return_value=app_service) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
request = Mock(args={}) | ||
request.getClientAddress.return_value.host = "192.168.10.10" | ||
|
@@ -144,7 +143,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None: | |
ip_range_whitelist=IPSet(["192.168/16"]), | ||
) | ||
self.store.get_app_service_by_token = Mock(return_value=app_service) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
request = Mock(args={}) | ||
request.getClientAddress.return_value.host = "131.111.8.42" | ||
|
@@ -158,7 +157,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None: | |
|
||
def test_get_user_by_req_appservice_bad_token(self) -> None: | ||
self.store.get_app_service_by_token = Mock(return_value=None) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
request = Mock(args={}) | ||
request.args[b"access_token"] = [self.test_token] | ||
|
@@ -172,7 +171,7 @@ def test_get_user_by_req_appservice_bad_token(self) -> None: | |
def test_get_user_by_req_appservice_missing_token(self) -> None: | ||
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) | ||
self.store.get_app_service_by_token = Mock(return_value=app_service) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
request = Mock(args={}) | ||
request.requestHeaders.getRawHeaders = mock_getRawHeaders() | ||
|
@@ -190,8 +189,8 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None: | |
app_service.is_interested_in_user = Mock(return_value=True) | ||
self.store.get_app_service_by_token = Mock(return_value=app_service) | ||
# This just needs to return a truth-y value. | ||
self.store.get_user_by_id = simple_async_mock({"is_guest": False}) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
request = Mock(args={}) | ||
request.getClientAddress.return_value.host = "127.0.0.1" | ||
|
@@ -210,7 +209,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None: | |
) | ||
app_service.is_interested_in_user = Mock(return_value=False) | ||
self.store.get_app_service_by_token = Mock(return_value=app_service) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
request = Mock(args={}) | ||
request.getClientAddress.return_value.host = "127.0.0.1" | ||
|
@@ -234,10 +233,10 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None: | |
app_service.is_interested_in_user = Mock(return_value=True) | ||
self.store.get_app_service_by_token = Mock(return_value=app_service) | ||
# This just needs to return a truth-y value. | ||
self.store.get_user_by_id = simple_async_mock({"is_guest": False}) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
# This also needs to just return a truth-y value | ||
self.store.get_device = simple_async_mock({"hidden": False}) | ||
self.store.get_device = AsyncMock(return_value={"hidden": False}) | ||
|
||
request = Mock(args={}) | ||
request.getClientAddress.return_value.host = "127.0.0.1" | ||
|
@@ -266,10 +265,10 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None: | |
app_service.is_interested_in_user = Mock(return_value=True) | ||
self.store.get_app_service_by_token = Mock(return_value=app_service) | ||
# This just needs to return a truth-y value. | ||
self.store.get_user_by_id = simple_async_mock({"is_guest": False}) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
# This also needs to just return a falsey value | ||
self.store.get_device = simple_async_mock(None) | ||
self.store.get_device = AsyncMock(return_value=None) | ||
|
||
request = Mock(args={}) | ||
request.getClientAddress.return_value.host = "127.0.0.1" | ||
|
@@ -283,18 +282,18 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None: | |
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) | ||
|
||
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: | ||
self.store.get_user_by_access_token = simple_async_mock( | ||
TokenLookupResult( | ||
self.store.get_user_by_access_token = AsyncMock( | ||
return_value=TokenLookupResult( | ||
user_id="@baldrick:matrix.org", | ||
device_id="device", | ||
token_id=5, | ||
token_owner="@admin:matrix.org", | ||
token_used=True, | ||
) | ||
) | ||
self.store.insert_client_ip = simple_async_mock(None) | ||
self.store.mark_access_token_as_used = simple_async_mock(None) | ||
self.store.get_user_locked_status = simple_async_mock(False) | ||
self.store.insert_client_ip = AsyncMock(return_value=None) | ||
self.store.mark_access_token_as_used = AsyncMock(return_value=None) | ||
self.store.get_user_locked_status = AsyncMock(return_value=False) | ||
request = Mock(args={}) | ||
request.getClientAddress.return_value.host = "127.0.0.1" | ||
request.args[b"access_token"] = [self.test_token] | ||
|
@@ -304,18 +303,18 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> Non | |
|
||
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: | ||
self.auth._track_puppeted_user_ips = True | ||
self.store.get_user_by_access_token = simple_async_mock( | ||
TokenLookupResult( | ||
self.store.get_user_by_access_token = AsyncMock( | ||
return_value=TokenLookupResult( | ||
user_id="@baldrick:matrix.org", | ||
device_id="device", | ||
token_id=5, | ||
token_owner="@admin:matrix.org", | ||
token_used=True, | ||
) | ||
) | ||
self.store.get_user_locked_status = simple_async_mock(False) | ||
self.store.insert_client_ip = simple_async_mock(None) | ||
self.store.mark_access_token_as_used = simple_async_mock(None) | ||
self.store.get_user_locked_status = AsyncMock(return_value=False) | ||
self.store.insert_client_ip = AsyncMock(return_value=None) | ||
self.store.mark_access_token_as_used = AsyncMock(return_value=None) | ||
request = Mock(args={}) | ||
request.getClientAddress.return_value.host = "127.0.0.1" | ||
request.args[b"access_token"] = [self.test_token] | ||
|
@@ -324,7 +323,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: | |
self.assertEqual(self.store.insert_client_ip.call_count, 2) | ||
|
||
def test_get_user_from_macaroon(self) -> None: | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
user_id = "@baldrick:matrix.org" | ||
macaroon = pymacaroons.Macaroon( | ||
|
@@ -342,8 +341,8 @@ def test_get_user_from_macaroon(self) -> None: | |
) | ||
|
||
def test_get_guest_user_from_macaroon(self) -> None: | ||
self.store.get_user_by_id = simple_async_mock({"is_guest": True}) | ||
self.store.get_user_by_access_token = simple_async_mock(None) | ||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True}) | ||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | ||
|
||
user_id = "@baldrick:matrix.org" | ||
macaroon = pymacaroons.Macaroon( | ||
|
@@ -373,7 +372,7 @@ def test_blocking_mau(self) -> None: | |
|
||
self.auth_blocking._limit_usage_by_mau = True | ||
|
||
self.store.get_monthly_active_count = simple_async_mock(lots_of_users) | ||
self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) | ||
|
||
e = self.get_failure( | ||
self.auth_blocking.check_auth_blocking(), ResourceLimitError | ||
|
@@ -383,25 +382,27 @@ def test_blocking_mau(self) -> None: | |
self.assertEqual(e.value.code, 403) | ||
|
||
# Ensure does not throw an error | ||
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) | ||
self.store.get_monthly_active_count = AsyncMock( | ||
return_value=small_number_of_users | ||
) | ||
self.get_success(self.auth_blocking.check_auth_blocking()) | ||
|
||
def test_blocking_mau__depending_on_user_type(self) -> None: | ||
self.auth_blocking._max_mau_value = 50 | ||
self.auth_blocking._limit_usage_by_mau = True | ||
|
||
self.store.get_monthly_active_count = simple_async_mock(100) | ||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | ||
# Support users allowed | ||
self.get_success( | ||
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) | ||
) | ||
self.store.get_monthly_active_count = simple_async_mock(100) | ||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | ||
# Bots not allowed | ||
self.get_failure( | ||
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), | ||
ResourceLimitError, | ||
) | ||
self.store.get_monthly_active_count = simple_async_mock(100) | ||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | ||
# Real users not allowed | ||
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) | ||
|
||
|
@@ -412,9 +413,9 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips( | |
self.auth_blocking._limit_usage_by_mau = True | ||
self.auth_blocking._track_appservice_user_ips = False | ||
|
||
self.store.get_monthly_active_count = simple_async_mock(100) | ||
self.store.user_last_seen_monthly_active = simple_async_mock() | ||
self.store.is_trial_user = simple_async_mock() | ||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | ||
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) | ||
self.store.is_trial_user = AsyncMock(return_value=False) | ||
|
||
appservice = ApplicationService( | ||
"abcd", | ||
|
@@ -443,9 +444,9 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips( | |
self.auth_blocking._limit_usage_by_mau = True | ||
self.auth_blocking._track_appservice_user_ips = True | ||
|
||
self.store.get_monthly_active_count = simple_async_mock(100) | ||
self.store.user_last_seen_monthly_active = simple_async_mock() | ||
self.store.is_trial_user = simple_async_mock() | ||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | ||
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) | ||
self.store.is_trial_user = AsyncMock(return_value=False) | ||
|
||
appservice = ApplicationService( | ||
"abcd", | ||
|
@@ -473,7 +474,7 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips( | |
def test_reserved_threepid(self) -> None: | ||
self.auth_blocking._limit_usage_by_mau = True | ||
self.auth_blocking._max_mau_value = 1 | ||
self.store.get_monthly_active_count = simple_async_mock(2) | ||
self.store.get_monthly_active_count = AsyncMock(return_value=2) | ||
threepid = {"medium": "email", "address": "[email protected]"} | ||
unknown_threepid = {"medium": "email", "address": "[email protected]"} | ||
self.auth_blocking._mau_limits_reserved_threepids = [threepid] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.