Skip to content

Commit

Permalink
Add with_header_from_request function
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Johnson <[email protected]>
  • Loading branch information
alex-kattathra-johnson committed Nov 12, 2024
1 parent 0168e88 commit 23d2479
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 6 deletions.
40 changes: 38 additions & 2 deletions src/mock.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::diff;
use crate::matcher::{Matcher, PathAndQueryMatcher, RequestMatcher};
use crate::response::{Body, Response};
use crate::response::{Body, Header, Response};
use crate::server::RemoteMock;
use crate::server::State;
use crate::Request;
Expand Down Expand Up @@ -370,11 +370,47 @@ impl Mock {
self.inner
.response
.headers
.append(field.into_header_name(), value.to_owned());
.append(field.into_header_name(), Header::String(value.to_string()));

self
}

///
/// Sets the headers of the mock response dynamically while exposing the request object.
///
/// You can use this method to provide custom headers for every incoming request.
///
/// The function must be thread-safe. If it's a closure, it can't be borrowing its context.
/// Use `move` closures and `Arc` to share any data.
///
/// ### Example
///
/// ```
/// let mut s = mockito::Server::new();
///
/// let _m = s.mock("GET", mockito::Matcher::Any).with_header_from_request("user", |request| {
/// if request.path() == "/bob" {
/// "bob".into()
/// } else if request.path() == "/alice" {
/// "alice".into()
/// } else {
/// "everyone".into()
/// }
/// });
/// ```
///
pub fn with_header_from_request<T: IntoHeaderName>(
mut self,
field: T,
callback: impl Fn(&Request) -> String + Send + Sync + 'static,
) -> Self {
self.inner.response.headers.append(
field.into_header_name(),
Header::FnWithRequest(Arc::new(move |req| callback(req))),
);
self
}

///
/// Sets the body of the mock response. Its `Content-Length` is handled automatically.
///
Expand Down
34 changes: 32 additions & 2 deletions src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,40 @@ use tokio::sync::mpsc;
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct Response {
pub status: StatusCode,
pub headers: HeaderMap<String>,
pub headers: HeaderMap<Header>,
pub body: Body,
}

#[derive(Clone)]
pub(crate) enum Header {
String(String),
FnWithRequest(Arc<HeaderFnWithRequest>),
}

impl fmt::Debug for Header {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Header::String(ref s) => s.fmt(f),
Header::FnWithRequest(_) => f.write_str("<callback>"),
}
}
}

impl PartialEq for Header {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Header::String(ref a), Header::String(ref b)) => a == b,
(Header::FnWithRequest(ref a), Header::FnWithRequest(ref b)) => std::ptr::eq(
a.as_ref() as *const HeaderFnWithRequest as *const u8,
b.as_ref() as *const HeaderFnWithRequest as *const u8,
),
_ => false,
}
}
}

type HeaderFnWithRequest = dyn Fn(&Request) -> String + Send + Sync;

type BodyFnWithWriter = dyn Fn(&mut dyn io::Write) -> io::Result<()> + Send + Sync + 'static;
type BodyFnWithRequest = dyn Fn(&Request) -> Bytes + Send + Sync + 'static;

Expand Down Expand Up @@ -57,7 +87,7 @@ impl PartialEq for Body {
impl Default for Response {
fn default() -> Self {
let mut headers = HeaderMap::with_capacity(1);
headers.insert("connection", "close".parse().unwrap());
headers.insert("connection", Header::String("close".to_string()));
Self {
status: StatusCode::OK,
headers,
Expand Down
9 changes: 7 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::mock::InnerMock;
use crate::request::Request;
use crate::response::{Body as ResponseBody, ChunkedStream};
use crate::response::{Body as ResponseBody, ChunkedStream, Header};
use crate::ServerGuard;
use crate::{Error, ErrorKind, Matcher, Mock};
use bytes::Bytes;
Expand Down Expand Up @@ -559,7 +559,12 @@ fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Bod
let mut response = Response::builder().status(status);

for (name, value) in mock.inner.response.headers.iter() {
response = response.header(name, value);
match value {
Header::String(value) => response = response.header(name, value),
Header::FnWithRequest(header_fn) => {
response = response.header(name, header_fn(&request))
}
}
}

let body = if request.method() != "HEAD" {
Expand Down

0 comments on commit 23d2479

Please sign in to comment.