From 3b5a8053d9b9547deca6817fcac99f4cc1b24323 Mon Sep 17 00:00:00 2001 From: zu1k Date: Sun, 10 Dec 2023 17:58:08 +0800 Subject: [PATCH 1/2] feat: Add Connector trait --- Cargo.lock | 11 +++++++++ Cargo.toml | 1 + examples/bind_connect.rs | 52 ++++++++++++++++++++++++++++++++++++++++ src/agent.rs | 15 ++++++++++++ src/connect.rs | 47 ++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 ++ src/stream.rs | 4 ++-- src/unit.rs | 5 ++++ 8 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 examples/bind_connect.rs create mode 100644 src/connect.rs diff --git a/Cargo.lock b/Cargo.lock index a708ff82..f1419fd3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -662,6 +662,16 @@ dependencies = [ "serde", ] +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "socks" version = "0.3.4" @@ -811,6 +821,7 @@ dependencies = [ "rustls-webpki", "serde", "serde_json", + "socket2", "socks", "url", "webpki-roots", diff --git a/Cargo.toml b/Cargo.toml index e2e1d116..d2609dff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,7 @@ serde = { version = "1", features = ["derive"] } env_logger = "0.10" rustls = { version = "0.22.0" } rustls-pemfile = { version = "2.0" } +socket2 = "0.5" [[example]] name = "cureq" diff --git a/examples/bind_connect.rs b/examples/bind_connect.rs new file mode 100644 index 00000000..57e327c4 --- /dev/null +++ b/examples/bind_connect.rs @@ -0,0 +1,52 @@ +use socket2::{Domain, Socket, Type}; +use std::net::SocketAddr; +use ureq::Connector; + +#[derive(Debug)] +pub(crate) struct BindConnector { + bind_addr: SocketAddr, +} + +impl BindConnector { + pub fn new_bind(bind_addr: SocketAddr) -> Self { + Self { bind_addr } + } +} + +impl Connector for BindConnector { + fn connect(&self, addr: &std::net::SocketAddr) -> std::io::Result { + let socket = Socket::new(Domain::for_address(addr.to_owned()), Type::STREAM, None)?; + socket.bind(&self.bind_addr.into())?; + socket.connect(&addr.to_owned().into())?; + Ok(socket.into()) + } + + fn connect_timeout( + &self, + addr: &std::net::SocketAddr, + timeout: std::time::Duration, + ) -> std::io::Result { + let socket = Socket::new(Domain::for_address(addr.to_owned()), Type::STREAM, None)?; + socket.bind(&self.bind_addr.into())?; + socket.connect_timeout(&addr.to_owned().into(), timeout)?; + Ok(socket.into()) + } +} + +pub fn main() { + let agent = ureq::builder() + .connector(BindConnector::new_bind("127.0.0.1:54321".parse().unwrap())) + .build(); + + let result = agent.get("http://127.0.0.1:8080/").call(); + + match result { + Err(err) => { + println!("{:?}", err); + std::process::exit(1); + } + Ok(response) => { + assert_eq!(response.status(), 200); + } + } +} diff --git a/src/agent.rs b/src/agent.rs index 986d3f92..fbf4da4a 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::time::Duration; use url::Url; +use crate::connect::{ArcConnector, StdTcpConnector}; use crate::middleware::Middleware; use crate::pool::ConnectionPool; use crate::proxy::Proxy; @@ -45,6 +46,7 @@ pub struct AgentBuilder { #[cfg(feature = "cookies")] cookie_store: Option, resolver: ArcResolver, + connector: ArcConnector, middleware: Vec>, } @@ -126,6 +128,7 @@ pub(crate) struct AgentState { #[cfg(feature = "cookies")] pub(crate) cookie_tin: CookieTin, pub(crate) resolver: ArcResolver, + pub(crate) connector: ArcConnector, pub(crate) middleware: Vec>, } @@ -271,6 +274,7 @@ impl AgentBuilder { max_idle_connections: DEFAULT_MAX_IDLE_CONNECTIONS, max_idle_connections_per_host: DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST, resolver: StdResolver.into(), + connector: StdTcpConnector.into(), #[cfg(feature = "cookies")] cookie_store: None, middleware: vec![], @@ -298,6 +302,7 @@ impl AgentBuilder { #[cfg(feature = "cookies")] cookie_tin: CookieTin::new(self.cookie_store.unwrap_or_else(CookieStore::default)), resolver: self.resolver, + connector: self.connector, middleware: self.middleware, }), } @@ -402,6 +407,16 @@ impl AgentBuilder { self } + /// Configures a custom connector to be used by this agent. By default, + /// tcp-connect is done by std::net::TcpStream. This allows you + /// to override that connection with your own alternative. + /// + /// See `examples/bind_connect.rs` for example. + pub fn connector(mut self, connector: impl crate::Connector + 'static) -> Self { + self.connector = connector.into(); + self + } + /// Timeout for the socket connection to be successful. /// If both this and `.timeout()` are both set, `.timeout_connect()` /// takes precedence. diff --git a/src/connect.rs b/src/connect.rs new file mode 100644 index 00000000..475964c6 --- /dev/null +++ b/src/connect.rs @@ -0,0 +1,47 @@ +use std::fmt; +use std::io::Result as IoResult; +use std::net::{SocketAddr, TcpStream}; +use std::sync::Arc; +use std::time::Duration; + +/// A custom Connector to override the default TcpStream connector. +pub trait Connector: Send + Sync { + fn connect(&self, addr: &SocketAddr) -> IoResult { + TcpStream::connect(addr) + } + + fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> IoResult { + TcpStream::connect_timeout(addr, timeout) + } +} + +#[derive(Debug)] +pub(crate) struct StdTcpConnector; + +impl Connector for StdTcpConnector {} + +#[derive(Clone)] +pub(crate) struct ArcConnector(Arc); + +impl From for ArcConnector +where + R: Connector + 'static, +{ + fn from(r: R) -> Self { + Self(Arc::new(r)) + } +} + +impl fmt::Debug for ArcConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ArcConnector(...)") + } +} + +impl std::ops::Deref for ArcConnector { + type Target = dyn Connector; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} diff --git a/src/lib.rs b/src/lib.rs index 70d42e49..5a9b67f4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -357,6 +357,7 @@ mod agent; mod body; mod chunked; +mod connect; mod error; mod header; mod middleware; @@ -429,6 +430,7 @@ mod http_crate; pub use crate::agent::Agent; pub use crate::agent::AgentBuilder; pub use crate::agent::RedirectAuthHeaders; +pub use crate::connect::Connector; pub use crate::error::{Error, ErrorKind, OrAnyStatus, Transport}; pub use crate::middleware::{Middleware, MiddlewareNext}; pub use crate::proxy::Proxy; diff --git a/src/stream.rs b/src/stream.rs index 8e7a36ee..ed0d77ad 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -406,9 +406,9 @@ pub(crate) fn connect_host( proto.unwrap(), ) } else if let Some(timeout) = timeout { - TcpStream::connect_timeout(&sock_addr, timeout) + unit.connector().connect_timeout(&sock_addr, timeout) } else { - TcpStream::connect(sock_addr) + unit.connector().connect(&sock_addr) }; if let Ok(stream) = stream { diff --git a/src/unit.rs b/src/unit.rs index 13c3c9a6..ba70d1a3 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -12,6 +12,7 @@ use cookie::Cookie; use crate::agent::RedirectAuthHeaders; use crate::body::{self, BodySize, Payload, SizedReader}; +use crate::connect::ArcConnector; use crate::error::{Error, ErrorKind}; use crate::header; use crate::header::{get_header, Header}; @@ -115,6 +116,10 @@ impl Unit { self.agent.state.resolver.clone() } + pub fn connector(&self) -> ArcConnector { + self.agent.state.connector.clone() + } + #[cfg(test)] pub fn header(&self, name: &str) -> Option<&str> { header::get_header(&self.headers, name) From 940e9f9ae6c2cb97f28665cfb9992ada9a5fee59 Mon Sep 17 00:00:00 2001 From: zu1k Date: Sun, 10 Dec 2023 19:15:16 +0800 Subject: [PATCH 2/2] Rename Connector to TcpConnector --- examples/bind_connect.rs | 6 +++--- src/agent.rs | 20 ++++++++++---------- src/connect.rs | 18 +++++++++--------- src/lib.rs | 2 +- src/stream.rs | 4 ++-- src/unit.rs | 6 +++--- 6 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/bind_connect.rs b/examples/bind_connect.rs index 57e327c4..f585e36c 100644 --- a/examples/bind_connect.rs +++ b/examples/bind_connect.rs @@ -1,6 +1,6 @@ use socket2::{Domain, Socket, Type}; use std::net::SocketAddr; -use ureq::Connector; +use ureq::TcpConnector; #[derive(Debug)] pub(crate) struct BindConnector { @@ -13,7 +13,7 @@ impl BindConnector { } } -impl Connector for BindConnector { +impl TcpConnector for BindConnector { fn connect(&self, addr: &std::net::SocketAddr) -> std::io::Result { let socket = Socket::new(Domain::for_address(addr.to_owned()), Type::STREAM, None)?; socket.bind(&self.bind_addr.into())?; @@ -35,7 +35,7 @@ impl Connector for BindConnector { pub fn main() { let agent = ureq::builder() - .connector(BindConnector::new_bind("127.0.0.1:54321".parse().unwrap())) + .tcp_connector(BindConnector::new_bind("127.0.0.1:54321".parse().unwrap())) .build(); let result = agent.get("http://127.0.0.1:8080/").call(); diff --git a/src/agent.rs b/src/agent.rs index fbf4da4a..05cd22b9 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use std::time::Duration; use url::Url; -use crate::connect::{ArcConnector, StdTcpConnector}; +use crate::connect::{ArcTcpConnector, StdTcpConnector}; use crate::middleware::Middleware; use crate::pool::ConnectionPool; use crate::proxy::Proxy; @@ -46,7 +46,7 @@ pub struct AgentBuilder { #[cfg(feature = "cookies")] cookie_store: Option, resolver: ArcResolver, - connector: ArcConnector, + tcp_connector: ArcTcpConnector, middleware: Vec>, } @@ -128,7 +128,7 @@ pub(crate) struct AgentState { #[cfg(feature = "cookies")] pub(crate) cookie_tin: CookieTin, pub(crate) resolver: ArcResolver, - pub(crate) connector: ArcConnector, + pub(crate) tcp_connector: ArcTcpConnector, pub(crate) middleware: Vec>, } @@ -274,7 +274,7 @@ impl AgentBuilder { max_idle_connections: DEFAULT_MAX_IDLE_CONNECTIONS, max_idle_connections_per_host: DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST, resolver: StdResolver.into(), - connector: StdTcpConnector.into(), + tcp_connector: StdTcpConnector.into(), #[cfg(feature = "cookies")] cookie_store: None, middleware: vec![], @@ -302,7 +302,7 @@ impl AgentBuilder { #[cfg(feature = "cookies")] cookie_tin: CookieTin::new(self.cookie_store.unwrap_or_else(CookieStore::default)), resolver: self.resolver, - connector: self.connector, + tcp_connector: self.tcp_connector, middleware: self.middleware, }), } @@ -407,13 +407,13 @@ impl AgentBuilder { self } - /// Configures a custom connector to be used by this agent. By default, - /// tcp-connect is done by std::net::TcpStream. This allows you - /// to override that connection with your own alternative. + /// Configures a custom TCP connector to be used by this agent. + /// By default, tcp-connect is done by std::net::TcpStream. + /// This allows you to override that connection with your own alternative. /// /// See `examples/bind_connect.rs` for example. - pub fn connector(mut self, connector: impl crate::Connector + 'static) -> Self { - self.connector = connector.into(); + pub fn tcp_connector(mut self, tcp_connector: impl crate::TcpConnector + 'static) -> Self { + self.tcp_connector = tcp_connector.into(); self } diff --git a/src/connect.rs b/src/connect.rs index 475964c6..378b49be 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use std::time::Duration; /// A custom Connector to override the default TcpStream connector. -pub trait Connector: Send + Sync { +pub trait TcpConnector: Send + Sync { fn connect(&self, addr: &SocketAddr) -> IoResult { TcpStream::connect(addr) } @@ -18,28 +18,28 @@ pub trait Connector: Send + Sync { #[derive(Debug)] pub(crate) struct StdTcpConnector; -impl Connector for StdTcpConnector {} +impl TcpConnector for StdTcpConnector {} #[derive(Clone)] -pub(crate) struct ArcConnector(Arc); +pub(crate) struct ArcTcpConnector(Arc); -impl From for ArcConnector +impl From for ArcTcpConnector where - R: Connector + 'static, + R: TcpConnector + 'static, { fn from(r: R) -> Self { Self(Arc::new(r)) } } -impl fmt::Debug for ArcConnector { +impl fmt::Debug for ArcTcpConnector { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "ArcConnector(...)") + write!(f, "ArcTcpConnector(...)") } } -impl std::ops::Deref for ArcConnector { - type Target = dyn Connector; +impl std::ops::Deref for ArcTcpConnector { + type Target = dyn TcpConnector; fn deref(&self) -> &Self::Target { self.0.as_ref() diff --git a/src/lib.rs b/src/lib.rs index 5a9b67f4..29c5ca86 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -430,7 +430,7 @@ mod http_crate; pub use crate::agent::Agent; pub use crate::agent::AgentBuilder; pub use crate::agent::RedirectAuthHeaders; -pub use crate::connect::Connector; +pub use crate::connect::TcpConnector; pub use crate::error::{Error, ErrorKind, OrAnyStatus, Transport}; pub use crate::middleware::{Middleware, MiddlewareNext}; pub use crate::proxy::Proxy; diff --git a/src/stream.rs b/src/stream.rs index ed0d77ad..4240189e 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -406,9 +406,9 @@ pub(crate) fn connect_host( proto.unwrap(), ) } else if let Some(timeout) = timeout { - unit.connector().connect_timeout(&sock_addr, timeout) + unit.tcp_connector().connect_timeout(&sock_addr, timeout) } else { - unit.connector().connect(&sock_addr) + unit.tcp_connector().connect(&sock_addr) }; if let Ok(stream) = stream { diff --git a/src/unit.rs b/src/unit.rs index ba70d1a3..a02d7fb9 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -12,7 +12,7 @@ use cookie::Cookie; use crate::agent::RedirectAuthHeaders; use crate::body::{self, BodySize, Payload, SizedReader}; -use crate::connect::ArcConnector; +use crate::connect::ArcTcpConnector; use crate::error::{Error, ErrorKind}; use crate::header; use crate::header::{get_header, Header}; @@ -116,8 +116,8 @@ impl Unit { self.agent.state.resolver.clone() } - pub fn connector(&self) -> ArcConnector { - self.agent.state.connector.clone() + pub fn tcp_connector(&self) -> ArcTcpConnector { + self.agent.state.tcp_connector.clone() } #[cfg(test)]