Skip to content

Commit

Permalink
do not route packet from relay addr
Browse files Browse the repository at this point in the history
  • Loading branch information
gfreezy committed Aug 30, 2024
1 parent ee132ee commit 21eaf06
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
26 changes: 10 additions & 16 deletions config/src/rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::net::IpAddr;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, Once};
use std::thread;

#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Rule {
Expand Down Expand Up @@ -43,10 +42,12 @@ impl ProxyRules {
let s = Self {
rules: Arc::new(RwLock::new(rules)),
geo_ip_db: Arc::new(Mutex::new(None)),
geo_ip_path,
geo_ip_path: geo_ip_path.clone(),
default_download_path: default_geo_ip_path(),
};
s.init_geo_ip_db(true);
if geo_ip_path.is_some() {
s.init_geo_ip_db();
}
s
}

Expand All @@ -60,7 +61,7 @@ impl ProxyRules {
geo_ip_path,
default_download_path: default_geo_ip_path(),
};
s.init_geo_ip_db(false);
s.init_geo_ip_db();
s
}

Expand Down Expand Up @@ -95,26 +96,18 @@ impl ProxyRules {
success
}

fn init_geo_ip_db(&self, background: bool) {
fn init_geo_ip_db(&self) {
let default_path = self.default_download_path.clone();
let path = match &self.geo_ip_path {
Some(path) => {
tracing::info!("geoip path: {:?}", path);
// Check if path is a valid http or https url
if path.starts_with("http://") || path.starts_with("https://") {
if !default_path.exists() {
let self_clone = self.clone();
static ONCE: std::sync::Once = Once::new();
if background {
ONCE.call_once(|| {
thread::spawn(move || {
let _ = self_clone.download_geoip_database();
});
});
return;
} else {
let _ = self_clone.download_geoip_database();
}
ONCE.call_once(|| {
let _ = self.download_geoip_database();
});
}
default_path
} else {
Expand Down Expand Up @@ -208,6 +201,7 @@ impl ProxyRules {

pub(crate) fn set_geo_ip_path(&mut self, path: Option<PathBuf>) {
self.geo_ip_path = path;
self.init_geo_ip_db();
}
}

Expand Down
10 changes: 6 additions & 4 deletions dnsserver/src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::io;
use std::io::Result;
use std::sync::Arc;
use store::Store;
use tracing::{debug, error};
use tracing::{debug, error, info};

/// A Forwarding DNS Resolver
///
Expand Down Expand Up @@ -55,7 +55,7 @@ impl RuleBasedDnsResolver {
.await
.map_err(|e| {
let msg = e.to_string();
error!("directly lookup host error: {}", &msg);
info!("directly lookup host error: {}", &msg);
io::Error::new(io::ErrorKind::Other, msg)
})?;
for record in lookup.record_iter() {
Expand Down Expand Up @@ -114,7 +114,7 @@ impl RuleBasedDnsResolver {
ttl: TransientTtl(record.ttl()),
},
other => {
tracing::error!("unsupported record type: {:?}", other);
tracing::info!("unsupported record type: {:?}", other);
continue;
}
};
Expand All @@ -129,7 +129,7 @@ impl RuleBasedDnsResolver {

async fn resolve(&self, domain: &str, qtype: QueryType) -> Result<DnsPacket> {
// We only support A record for now, for other records, we just forward them to upstream.
if !matches!(qtype, QueryType::A | QueryType::AAAA) {
if !matches!(qtype, QueryType::A) {
return self.resolve_real(domain, qtype).await;
}

Expand Down Expand Up @@ -182,6 +182,7 @@ impl RuleBasedDnsResolver {
match self.inner.rules.action_for_domain(Some(domain), ip) {
// Return real ip when `bypass_direct` is true.
Some(Action::Direct) if bypass_direct => {
tracing::info!("bypass_direct, domain: {:?}, ip: {:?}", domain, ip);
return real_packet;
}
// Do not return dns records when action is reject.
Expand All @@ -197,6 +198,7 @@ impl RuleBasedDnsResolver {
addr: ip,
ttl: TransientTtl(3),
});
tracing::info!("lookup domain: {:?}, ip: {:?}", domain, ip);
Ok(packet)
}
}
Expand Down
7 changes: 7 additions & 0 deletions tun_nat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ pub fn run_nat(
Ok(p) => p,
};

// convert relay_addr to bytes
let relay_addr_bytes = relay_addr.octets();
if ipv4_packet.dst_addr().as_bytes() == relay_addr_bytes {
tracing::info!("tun_nat: drop packet from relay_addr");
continue;
}

if let Some(packet) = match ipv4_packet.protocol() {
IpProtocol::Udp => route_packet!(
UdpPacket,
Expand Down

0 comments on commit 21eaf06

Please sign in to comment.