Skip to content

Commit

Permalink
Allow matching the Request object based on a closure
Browse files Browse the repository at this point in the history
  • Loading branch information
lipanski committed Nov 11, 2024
1 parent e115abb commit ad063c1
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 2 deletions.
24 changes: 24 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,30 @@
//! .create();
//! ```
//!
//! # Custom matchers
//!
//! If you need a more custom matcher, you can use the [`Mock::match_request`] function, which
//! takes a closure and exposes the [`Request`] object as an argument. The closure should return
//! a boolean value.
//!
//! ## Example
//!
//! ```
//! use mockito::Matcher;
//!
//! let mut s = mockito::Server::new();
//!
//! // This will match requests that have the x-test header set
//! // and contain the word "hello" inside the body
//! s.mock("GET", "/")
//! .match_request(|request| {
//! request.has_header("x-test") &&
//! request.utf8_lossy_body().unwrap().contains("hello")
//! })
//! .create();
//!
//! ```
//!
//! # Asserts
//!
//! You can use the [`Mock::assert`] method to **assert that a mock was called**. In other words,
Expand Down
32 changes: 32 additions & 0 deletions src/matcher.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::request::Request;
use assert_json_diff::{assert_json_matches_no_panic, CompareMode};
use http::header::HeaderValue;
use regex::Regex;
Expand All @@ -9,6 +10,7 @@ use std::io;
use std::io::Read;
use std::path::Path;
use std::string::ToString;
use std::sync::Arc;

///
/// Allows matching the request path, headers or body in multiple ways: by the exact value, by any value (as
Expand Down Expand Up @@ -281,3 +283,33 @@ impl fmt::Display for BinaryBody {
}
}
}

#[derive(Clone)]
pub(crate) struct RequestMatcher(Arc<dyn Fn(&Request) -> bool + Send + Sync>);

impl RequestMatcher {
pub(crate) fn matches(&self, value: &Request) -> bool {
self.0(value)
}
}

impl<F> From<F> for RequestMatcher
where
F: Fn(&Request) -> bool + Send + Sync + 'static,
{
fn from(value: F) -> Self {
Self(Arc::new(value))
}
}

impl Default for RequestMatcher {
fn default() -> Self {
RequestMatcher(Arc::new(|_| true))
}
}

impl fmt::Debug for RequestMatcher {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "(RequestMatcher)")
}
}
34 changes: 33 additions & 1 deletion src/mock.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::diff;
use crate::matcher::{Matcher, PathAndQueryMatcher};
use crate::matcher::{Matcher, PathAndQueryMatcher, RequestMatcher};
use crate::response::{Body, Response};
use crate::server::RemoteMock;
use crate::server::State;
Expand Down Expand Up @@ -67,6 +67,7 @@ pub struct InnerMock {
pub(crate) path: PathAndQueryMatcher,
pub(crate) headers: HeaderMap<Matcher>,
pub(crate) body: Matcher,
pub(crate) request_matcher: RequestMatcher,
pub(crate) response: Response,
pub(crate) hits: usize,
pub(crate) expected_hits_at_least: Option<usize>,
Expand Down Expand Up @@ -161,6 +162,7 @@ impl Mock {
path: PathAndQueryMatcher::Unified(path.into()),
headers: HeaderMap::<Matcher>::default(),
body: Matcher::Any,
request_matcher: RequestMatcher::default(),
response: Response::default(),
hits: 0,
expected_hits_at_least: None,
Expand Down Expand Up @@ -303,6 +305,36 @@ impl Mock {
self
}

///
/// Allows matching the entire request based on a closure that takes
/// the [`Request`] object as an argument and returns a boolean value.
///
/// ## Example
///
/// ```
/// use mockito::Matcher;
///
/// let mut s = mockito::Server::new();
///
/// // This will match requests that have the x-test header set
/// // and contain the word "hello" inside the body
/// s.mock("GET", "/")
/// .match_request(|request| {
/// request.has_header("x-test") &&
/// request.utf8_lossy_body().unwrap().contains("hello")
/// })
/// .create();
/// ```
///
pub fn match_request<F>(mut self, request_matcher: F) -> Self
where
F: Fn(&Request) -> bool + Send + Sync + 'static,
{
self.inner.request_matcher = request_matcher.into();

self
}

///
/// Sets the status code of the mock response. The default status code is 200.
///
Expand Down
9 changes: 8 additions & 1 deletion src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use http::header::{AsHeaderName, HeaderValue};
use http::Request as HttpRequest;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use std::borrow::Cow;

///
/// Stores a HTTP request
Expand Down Expand Up @@ -51,13 +52,19 @@ impl Request {
}

/// Returns the request body or an error, if the body hasn't been read
/// up to this moment.
/// yet.
pub fn body(&self) -> Result<&Vec<u8>, Error> {
self.body
.as_ref()
.ok_or_else(|| Error::new(ErrorKind::RequestBodyFailure))
}

/// Returns the request body as UTF8 or an error, if the body hasn't
/// been read yet.
pub fn utf8_lossy_body(&self) -> Result<Cow<'_, str>, Error> {
self.body().map(|body| String::from_utf8_lossy(body))
}

/// Reads the body (if it hasn't been read already) and returns it
pub(crate) async fn read_body(&mut self) -> &Vec<u8> {
if self.body.is_none() {
Expand Down
5 changes: 5 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ impl RemoteMock {
&& self.path_matches(other)
&& self.headers_match(other)
&& self.body_matches(other)
&& self.request_matches(other)
}

fn method_matches(&self, request: &Request) -> bool {
Expand All @@ -65,6 +66,10 @@ impl RemoteMock {
self.inner.body.matches_value(safe_body) || self.inner.body.matches_binary_value(body)
}

fn request_matches(&self, request: &Request) -> bool {
self.inner.request_matcher.matches(request)
}

#[allow(clippy::missing_const_for_fn)]
fn is_missing_hits(&self) -> bool {
match (
Expand Down
63 changes: 63 additions & 0 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,69 @@ fn test_anyof_exact_path_and_query_matcher() {
mock.assert();
}

#[test]
fn test_request_matcher_path() {
let mut s = Server::new();
let host = s.host_with_port();
let m = s
.mock("GET", Matcher::Any)
.match_request(|req| req.path().contains("hello"))
.with_body("world")
.create();

let (status_line, _, _) = request(&host, "GET /", "");
assert_eq!("HTTP/1.1 501 Not Implemented\r\n", status_line);

let (status_line, _, body) = request(host, "GET /hello", "");
assert_eq!("HTTP/1.1 200 OK\r\n", status_line);
assert_eq!("world", body);

m.assert();
}

#[test]
fn test_request_matcher_headers() {
let mut s = Server::new();
let host = s.host_with_port();
let m = s
.mock("GET", "/")
.match_request(|req| req.has_header("x-test"))
.with_body("world")
.create();

let (status_line, _, _) = request(&host, "GET /", "");
assert_eq!("HTTP/1.1 501 Not Implemented\r\n", status_line);

let (status_line, _, body) = request(host, "GET /", "x-test: 1\r\n");
assert_eq!("HTTP/1.1 200 OK\r\n", status_line);
assert_eq!("world", body);

m.assert();
}

#[test]
fn test_request_matcher_body() {
let mut s = Server::new();
let host = s.host_with_port();
let m = s
.mock("GET", "/")
.match_request(|req| {
let body = req.utf8_lossy_body().unwrap();
body.contains("hello")
})
.with_body("world")
.create();

let (status_line, _, _) = request_with_body(&host, "GET /", "", "bye");
assert_eq!("HTTP/1.1 501 Not Implemented\r\n", status_line);

let (status_line, _, body) = request_with_body(host, "GET /", "", "hello");
assert_eq!("HTTP/1.1 200 OK\r\n", status_line);
assert_eq!("world", body);

m.assert();
}

#[test]
fn test_default_headers() {
let mut s = Server::new();
Expand Down

0 comments on commit ad063c1

Please sign in to comment.