Skip to content

Commit

Permalink
Add registry support (#416)
Browse files Browse the repository at this point in the history
Add registry support which enables different response
matching and selection logic.

Co-authored-by: Mark Story <[email protected]>
  • Loading branch information
beliaev-maksim and markstory authored Jan 7, 2022
1 parent 476ad67 commit 3dccae8
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 49 deletions.
3 changes: 3 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
* Deprecate ``match_querystring`` argument in ``Response`` and ``CallbackResponse``.
Use `responses.matchers.query_param_matcher` or `responses.matchers.query_string_matcher`
* Added support for non-UTF-8 bytes in `responses.matchers.multipart_matcher`
* Added `responses.registries`. Now user can create custom registries to
manipulate the order of responses in the match algorithm
`responses.activate(registry=CustomRegistry)`

0.16.0
------
Expand Down
43 changes: 43 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,49 @@ include any additional headers.
resp = session.send(prepped)
assert resp.text == "hello world"
Response Registry
---------------------------

By default, ``responses`` will search all registered``Response`` objects and
return a match. If only one ``Response`` is registered, the registry is kept unchanged.
However, if multiple matches are found for the same request, then first match is returned and
removed from registry.

Such behavior is suitable for most of use cases, but to handle special conditions, you can
implement custom registry which must follow interface of ``registries.FirstMatchRegistry``.
Redefining the ``find`` method will allow you to create custom search logic and return
appropriate ``Response``

Example that shows how to set custom registry

.. code-block:: python
import responses
from responses import registries
class CustomRegistry(registries.FirstMatchRegistry):
pass
""" Before tests: <responses.registries.FirstMatchRegistry object> """
# using function decorator
@responses.activate(registry=CustomRegistry)
def run():
""" Within test: <__main__.CustomRegistry object> """
run()
""" After test: <responses.registries.FirstMatchRegistry object> """
# using context manager
with responses.RequestsMock(registry=CustomRegistry) as rsps:
""" In context manager: <__main__.CustomRegistry object> """
"""
After exit from context manager: <responses.registries.FirstMatchRegistry object>
"""
Dynamic Responses
-----------------

Expand Down
70 changes: 38 additions & 32 deletions responses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from requests.utils import cookiejar_from_dict
from responses.matchers import json_params_matcher as _json_params_matcher
from responses.matchers import urlencoded_params_matcher as _urlencoded_params_matcher
from responses.registries import FirstMatchRegistry
from responses.matchers import query_string_matcher as _query_string_matcher
from warnings import warn

Expand Down Expand Up @@ -164,7 +165,7 @@ def wrapper%(wrapper_args)s:
"""


def get_wrapped(func, responses):
def get_wrapped(func, responses, registry=None):
if six.PY2:
args, a, kw, defaults = inspect.getargspec(func)
wrapper_args = inspect.formatargspec(args, a, kw, defaults)
Expand Down Expand Up @@ -201,6 +202,9 @@ def get_wrapped(func, responses):
signature = signature.replace(parameters=params_without_defaults)
func_args = str(signature)

if registry is not None:
responses._set_registry(registry)

evaldict = {"func": func, "responses": responses}
six.exec_(
_wrapper_template % {"wrapper_args": wrapper_args, "func_args": func_args},
Expand Down Expand Up @@ -563,18 +567,33 @@ def __init__(
response_callback=None,
passthru_prefixes=(),
target="requests.adapters.HTTPAdapter.send",
registry=FirstMatchRegistry,
):
self._calls = CallList()
self.reset()
self._registry = registry() # call only after reset
self.assert_all_requests_are_fired = assert_all_requests_are_fired
self.response_callback = response_callback
self.passthru_prefixes = tuple(passthru_prefixes)
self.target = target
self._patcher = None
self._matches = []

def _get_registry(self):
return self._registry

def _set_registry(self, new_registry):
if self.registered():
err_msg = (
"Cannot replace Registry, current registry has responses.\n"
"Run 'responses.registry.reset()' first"
)
raise AttributeError(err_msg)

self._registry = new_registry()

def reset(self):
self._matches = []
self._registry = FirstMatchRegistry()
self._calls.reset()
self.passthru_prefixes = ()

Expand Down Expand Up @@ -616,13 +635,13 @@ def add(
"""
if isinstance(method, BaseResponse):
self._matches.append(method)
self._registry.add(method)
return

if adding_headers is not None:
kwargs.setdefault("headers", adding_headers)

self._matches.append(Response(method=method, url=url, body=body, **kwargs))
self._registry.add(Response(method=method, url=url, body=body, **kwargs))

def add_passthru(self, prefix):
"""
Expand Down Expand Up @@ -657,8 +676,7 @@ def remove(self, method_or_response=None, url=None):
else:
response = BaseResponse(method=method_or_response, url=url)

while response in self._matches:
self._matches.remove(response)
self._registry.remove(response)

def replace(self, method_or_response=None, url=None, body="", *args, **kwargs):
"""
Expand All @@ -676,11 +694,7 @@ def replace(self, method_or_response=None, url=None, body="", *args, **kwargs):
else:
response = Response(method=method_or_response, url=url, body=body, **kwargs)

try:
index = self._matches.index(response)
except ValueError:
raise ValueError("Response is not registered for URL %s" % url)
self._matches[index] = response
self._registry.replace(response)

def upsert(self, method_or_response=None, url=None, body="", *args, **kwargs):
"""
Expand Down Expand Up @@ -709,7 +723,7 @@ def add_callback(
# ensure the url has a default path set if the url is a string
# url = _ensure_url_default_path(url, match_querystring)

self._matches.append(
self._registry.add(
CallbackResponse(
url=url,
method=method,
Expand All @@ -721,7 +735,7 @@ def add_callback(
)

def registered(self):
return self._matches
return self._registry.registered

@property
def calls(self):
Expand All @@ -737,8 +751,14 @@ def __exit__(self, type, value, traceback):
self.reset()
return success

def activate(self, func):
return get_wrapped(func, self)
def activate(self, func=None, registry=None):
if func is not None:
return get_wrapped(func, self)

def deco_activate(func):
return get_wrapped(func, self, registry)

return deco_activate

def _find_match(self, request):
"""
Expand All @@ -749,21 +769,7 @@ def _find_match(self, request):
(Response) found match. If multiple found, then remove & return the first match.
(list) list with reasons why other matches don't match
"""
found = None
found_match = None
match_failed_reasons = []
for i, match in enumerate(self._matches):
match_result, reason = match.matches(request)
if match_result:
if found is None:
found = i
found_match = match
else:
# Multiple matches found. Remove & return the first match.
return self._matches.pop(found), match_failed_reasons
else:
match_failed_reasons.append(reason)
return found_match, match_failed_reasons
return self._registry.find(request)

def _parse_request_params(self, url):
params = {}
Expand Down Expand Up @@ -802,7 +808,7 @@ def _on_request(self, adapter, request, **kwargs):
"- %s %s\n\n"
"Available matches:\n" % (request.method, request.url)
)
for i, m in enumerate(self._matches):
for i, m in enumerate(self.registered()):
error_msg += "- {} {} {}\n".format(
m.method, m.url, match_failed_reasons[i]
)
Expand Down Expand Up @@ -846,7 +852,7 @@ def stop(self, allow_assert=True):
if not allow_assert:
return

not_called = [m for m in self._matches if m.call_count == 0]
not_called = [m for m in self.registered() if m.call_count == 0]
if not_called:
raise AssertionError(
"Not all requests have been executed {0!r}".format(
Expand Down
18 changes: 10 additions & 8 deletions responses/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from unittest import mock as std_mock
from urllib.parse import quote as quote
from urllib3.response import HTTPHeaderDict
from .matchers import urlencoded_params_matcher, json_params_matcher
from .registries import FirstMatchRegistry


def _clean_unicode(url: str) -> str: ...
Expand All @@ -38,7 +39,7 @@ def _handle_body(
def _has_unicode(s: str) -> bool: ...
def _is_string(s: Union[Pattern[str], str]) -> bool: ...
def get_wrapped(
func: Callable[..., Any], responses: RequestsMock
func: Callable[..., Any], responses: RequestsMock, registry: Optional[Any]
) -> Callable[..., Any]: ...


Expand Down Expand Up @@ -146,6 +147,8 @@ class OriginalResponseShim:
) -> None: ...
def isclosed(self) -> bool: ...

_F = TypeVar("_F", bound=Callable[..., Any])

class RequestsMock:
DELETE: Literal["DELETE"]
GET: Literal["GET"]
Expand All @@ -165,6 +168,7 @@ class RequestsMock:
response_callback: Optional[Callable[[Any], Any]] = ...,
passthru_prefixes: Tuple[str, ...] = ...,
target: str = ...,
registry: Any = ...,
) -> None:
self._patcher = Callable[[Any], Any]
self._calls = CallList
Expand All @@ -184,19 +188,17 @@ class RequestsMock:
def calls(self) -> CallList: ...
def __enter__(self) -> RequestsMock: ...
def __exit__(self, type: Any, value: Any, traceback: Any) -> bool: ...
activate: _Activate
def activate(self, func: Optional[_F], registry: Optional[Any]) -> _F: ...
def start(self) -> None: ...
def stop(self, allow_assert: bool = ...) -> None: ...
def assert_call_count(self, url: str, count: int) -> bool: ...
def registered(self) -> List[Any]: ...
def _set_registry(self, registry: Any) -> None: ...
def _get_registry(self) -> Any: ...

_F = TypeVar("_F", bound=Callable[..., Any])

HeaderSet = Optional[Union[Mapping[str, str], List[Tuple[str, str]]]]

class _Activate(Protocol):
def __call__(self, func: _F) -> _F: ...

class _Add(Protocol):
def __call__(
self,
Expand Down Expand Up @@ -273,7 +275,7 @@ class _Registered(Protocol):
def __call__(self) -> List[Response]: ...


activate: _Activate
activate: Any
add: _Add
add_callback: _AddCallback
add_passthru: _AddPassthru
Expand Down Expand Up @@ -328,5 +330,5 @@ __all__ = [
"start",
"stop",
"target",
"upsert"
"upsert",
]
43 changes: 43 additions & 0 deletions responses/registries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
class FirstMatchRegistry(object):
def __init__(self):
self._responses = []

@property
def registered(self):
return self._responses

def reset(self):
self._responses = []

def find(self, request):
found = None
found_match = None
match_failed_reasons = []
for i, response in enumerate(self.registered):
match_result, reason = response.matches(request)
if match_result:
if found is None:
found = i
found_match = response
else:
# Multiple matches found. Remove & return the first response.
return self.registered.pop(found), match_failed_reasons
else:
match_failed_reasons.append(reason)
return found_match, match_failed_reasons

def add(self, response):
self.registered.append(response)

def remove(self, response):
while response in self.registered:
self.registered.remove(response)

def replace(self, response):
try:
index = self.registered.index(response)
except ValueError:
raise ValueError(
"Response is not registered for URL {}".format(response.url)
)
self.registered[index] = response
17 changes: 17 additions & 0 deletions responses/registries.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import (
List,
Tuple,
)
from requests.adapters import PreparedRequest
from responses import BaseResponse

class FirstMatchRegistry:
_responses = List[BaseResponse]
def __init__(self) -> None: ...
@property
def registered(self) -> List[BaseResponse]: ...
def reset(self) -> None: ...
def find(self, request: PreparedRequest) -> Tuple[BaseResponse, List[str]]: ...
def add(self, response: BaseResponse) -> None: ...
def remove(self, response: BaseResponse) -> None: ...
def replace(self, response: BaseResponse) -> None: ...
2 changes: 1 addition & 1 deletion responses/test_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def assert_response(resp, body=None, content_type="text/plain"):


def assert_reset():
assert len(responses._default_mock._matches) == 0
assert len(responses._default_mock.registered()) == 0
assert len(responses.calls) == 0


Expand Down
Loading

0 comments on commit 3dccae8

Please sign in to comment.