diff --git a/README.md b/README.md index 52faa96..ba1edbb 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # go-dnsmasq -*Version 1.0.1* +*Version 1.0.2* go-dnsmasq is a light weight (1.2 MB) DNS caching server/forwarder with minimal filesystem and runtime overhead. @@ -89,3 +89,14 @@ docker run -d -p 53:53/udp -p 53:53 janeczku/go-dnsmasq:latest ``` You can configure the container by passing the corresponding environmental variables with docker run's `--env` flag. + +#### Serving A/AAAA records from a hosts file +The `--hostsfile` parameter expects a standard plain text [hosts file](https://en.wikipedia.org/wiki/Hosts_(file)) with the only difference being that a wildcard `*` in the left-most label of hostnames is allowed. Wildcard entries will match any subdomain that is not explicitely defined. +For example, given a hosts file with the following content: + +``` +192.168.0.1 db1.db.local +192.168.0.2 *.db.local +``` + +Queries for `db2.db.local` would be answered with an A record pointing to 192.168.0.2, while queries for `db1.db.local` would yield an A record pointing to 192.168.0.1. diff --git a/hostsfile/hostsfile.go b/hostsfile/hostsfile.go index 741b756..bd7f8c0 100644 --- a/hostsfile/hostsfile.go +++ b/hostsfile/hostsfile.go @@ -67,13 +67,7 @@ func (h *Hostsfile) FindHosts(name string) (addrs []net.IP, err error) { name = strings.TrimSuffix(name, ".") h.hostMutex.RLock() defer h.hostMutex.RUnlock() - - for _, hostname := range *h.hosts { - if hostname.domain == name { - addrs = append(addrs, hostname.ip) - } - } - + addrs = h.hosts.FindHosts(name); return } diff --git a/hostsfile/hostsfile_test.go b/hostsfile/hostsfile_test.go index fc59582..2fba33f 100644 --- a/hostsfile/hostsfile_test.go +++ b/hostsfile/hostsfile_test.go @@ -3,8 +3,6 @@ package hosts import ( "fmt" "net" - "runtime" - "strings" "testing" ) @@ -27,6 +25,7 @@ const ipv4Fail = ` const domain = "localhost" const ip = "127.0.0.1" const ipv6 = false +const wildcard = false func Diff(expected, actual string) string { return fmt.Sprintf(` @@ -47,6 +46,38 @@ func (h *hostlist) Contains(b *hostname) bool { return false } +func TestEquality(t *testing.T) { + var host1 *hostname + var host2 *hostname + + host1 = newHostname("hello", net.ParseIP("255.255.255.255"), false, false); + host2 = newHostname("hello", net.ParseIP("255.255.255.255"), false, false); + if !host1.Equal(host2) { + t.Error("Hosts are expected equal, got: ", host1, host2); + } + + host2 = newHostname("hello2", net.ParseIP("255.255.255.255"), false, false); + if host1.Equal(host2) { + t.Error("Hosts are expected different, got: ", host1, host2); + } + + host2 = newHostname("hello1", net.ParseIP("255.255.255.254"), false, false); + if host1.Equal(host2) { + t.Error("Hosts are expected different, got: ", host1, host2); + } + + host2 = newHostname("hello1", net.ParseIP("255.255.255.255"), true, false); + if host1.Equal(host2) { + t.Error("Hosts are expected different, got: ", host1, host2); + } + + host2 = newHostname("hello1", net.ParseIP("255.255.255.255"), false, true); + if host1.Equal(host2) { + t.Error("Hosts are expected different, got: ", host1, host2); + } + +} + func TestParseLine(t *testing.T) { var hosts hostlist @@ -73,31 +104,89 @@ func TestParseLine(t *testing.T) { t.Error("Expected to find zero hostnames when line is commented out") } + var err error; + err = hosts.add(newHostname("aaa", net.ParseIP("192.168.0.1"), false, false)); + if err != nil { + t.Error("Did not expect error on first hostname"); + } + err = hosts.add(newHostname("aaa", net.ParseIP("192.168.0.1"), false, false)); + if err == nil { + t.Error("Expected error on duplicate host"); + } + // Not Commented stuff - hosts = parseLine("255.255.255.255 broadcasthost test.domain.com domain.com") - if !hosts.Contains(newHostname("broadcasthost", net.ParseIP("255.255.255.255"), false)) || - !hosts.Contains(newHostname("test.domain.com", net.ParseIP("255.255.255.255"), false)) || - !hosts.Contains(newHostname("domain.com", net.ParseIP("255.255.255.255"), false)) || + hosts = parseLine("192.168.0.1 broadcasthost test.domain.com domain.com") + if !hosts.Contains(newHostname("broadcasthost", net.ParseIP("192.168.0.1"), false, false)) || + !hosts.Contains(newHostname("test.domain.com", net.ParseIP("192.168.0.1"), false, false)) || + !hosts.Contains(newHostname("domain.com", net.ParseIP("192.168.0.1"), false, false)) || len(hosts) != 3 { t.Error("Expected to find broadcasthost, domain.com, and test.domain.com") } + // Wildcard stuff + hosts = parseLine("192.168.0.1 *.domain.com mail.domain.com serenity") + if !hosts.Contains(newHostname("domain.com", net.ParseIP("192.168.0.1"), false, true)) || + !hosts.Contains(newHostname("mail.domain.com", net.ParseIP("192.168.0.1"), false, false)) || + !hosts.Contains(newHostname("serenity", net.ParseIP("192.168.0.1"), false, false)) || + len(hosts) != 3 { + t.Error("Expected to find *.domain.com, mail.domain.com and serenity.") + } + + var ip net.IP; + + ip = hosts.FindHost("api.domain.com"); + if !net.ParseIP("192.168.0.1").Equal(ip) { + t.Error("Can't match wildcard host api.domain.com"); + } + + ip = hosts.FindHost("google.com") + if ip != nil { + t.Error("We shouldn't resolve google.com"); + } + + hosts = *newHostlistString(`192.168.0.1 *.domain.com mail.domain.com serenity + 192.168.0.2 api.domain.com`); + + if (!net.ParseIP("192.168.0.2").Equal(hosts.FindHost("api.domain.com"))) { + t.Error("Failed matching api.domain.com explicitly"); + } + if (!net.ParseIP("192.168.0.1").Equal(hosts.FindHost("mail.domain.com"))) { + t.Error("Failed matching api.domain.com explicitly"); + } + if (!net.ParseIP("192.168.0.1").Equal(hosts.FindHost("wildcard.domain.com"))) { + t.Error("Failed matching wildcard.domain.com explicitly"); + } + if (net.ParseIP("192.168.0.1").Equal(hosts.FindHost("sub.wildcard.domain.com"))) { + t.Error("Failed not matching sub.wildcard.domain.com explicitly"); + } + + // IPv6 (not link-local) + hosts = parseLine("2a02:7a8:1:250::80:1 rtvslo.si img.rtvslo.si") + if !hosts.Contains(newHostname("img.rtvslo.si", net.ParseIP("2a02:7a8:1:250::80:1"), true, false)) || + len(hosts) != 2 { + t.Error("Expected to find rtvslo.si ipv6, two hosts") + } + + /* the following all fails since the addressses are link-local */ + + /* // Ipv6 stuff - hosts = hostess.parseLine("::1 localhost") - if !hosts.Contains(newHostname("localhost", net.ParseIP("::1"), true)) || + hosts = parseLine("::1 localhost") + if !hosts.Contains(newHostname("localhost", net.ParseIP("::1"), true, false)) || len(hosts) != 1 { - t.Error("Expected to find localhost ipv6 (enabled)") + t.Error("Expected to find localhost ipv6") } - hosts = hostess.parseLine("ff02::1 ip6-allnodes") - if !hosts.Contains(newHostname("ip6-allnodes", net.ParseIP("ff02::1"), true)) || + hosts = parseLine("ff02::1 ip6-allnodes") + if !hosts.Contains(newHostname("ip6-allnodes", net.ParseIP("ff02::1"), true, false)) || len(hosts) != 1 { - t.Error("Expected to find ip6-allnodes ipv6 (enabled)") + t.Error("Expected to find ip6-allnodes ipv6") } + */ } func TestHostname(t *testing.T) { - h := newHostname(domain, net.ParseIP(ip), ipv6) + h := newHostname(domain, net.ParseIP(ip), ipv6, wildcard) if h.domain != domain { t.Errorf("Domain should be %s", domain) @@ -105,7 +194,10 @@ func TestHostname(t *testing.T) { if !h.ip.Equal(net.ParseIP(ip)) { t.Errorf("IP should be %s", ip) } - if h.ipv6 != enabled { - t.Errorf("Enabled should be %t", enabled) + if h.ipv6 != ipv6 { + t.Errorf("IPv6 should be %t", ipv6) + } + if h.wildcard != wildcard { + t.Errorf("Wildcard should be %t", wildcard) } } diff --git a/hostsfile/utils.go b/hostsfile/utils.go index db3d59f..590b071 100644 --- a/hostsfile/utils.go +++ b/hostsfile/utils.go @@ -17,15 +17,20 @@ import ( type hostlist []*hostname type hostname struct { - domain string - ip net.IP - ipv6 bool + domain string + ip net.IP + ipv6 bool + wildcard bool } // newHostlist creates a hostlist by parsing a file func newHostlist(data []byte) *hostlist { + return newHostlistString(string(data)); +} + +func newHostlistString(data string) *hostlist { hostlist := hostlist{} - for _, v := range strings.Split(string(data), "\n") { + for _, v := range strings.Split(data, "\n") { for _, hostname := range parseLine(v) { err := hostlist.add(hostname) if err != nil { @@ -36,12 +41,61 @@ func newHostlist(data []byte) *hostlist { return &hostlist } +func (h *hostname) Equal(hostnamev *hostname) bool { + if (h.wildcard != hostnamev.wildcard || h.ipv6 != hostnamev.ipv6) { + return false + } + if (!h.ip.Equal(hostnamev.ip)) { + return false + } + if (h.domain != hostnamev.domain) { + return false + } + return true +} + +// return first match +func (h *hostlist) FindHost(name string) (addr net.IP) { + var ips []net.IP; + ips = h.FindHosts(name) + if len(ips) > 0 { + addr = ips[0]; + } + return +} + +// return exact matches, if existing -> else, return wildcard +func (h *hostlist) FindHosts(name string) (addrs []net.IP) { + for _, hostname := range *h { + if hostname.wildcard == false && hostname.domain == name { + addrs = append(addrs, hostname.ip) + } + } + + if len(addrs) == 0 { + var domain_match string; + for _, hostname := range *h { + if hostname.wildcard == true && len(hostname.domain) < len(name) { + domain_match = strings.Join([]string{".", hostname.domain}, ""); + if name[len(name)-len(domain_match):] == domain_match { + var left string; + left = name[0:len(name)-len(domain_match)] + if !strings.Contains(left, ".") { + addrs = append(addrs, hostname.ip) + } + } + } + } + } + + return +} + func (h *hostlist) add(hostnamev *hostname) error { - hostname := newHostname(hostnamev.domain, hostnamev.ip, hostnamev.ipv6) + hostname := newHostname(hostnamev.domain, hostnamev.ip, hostnamev.ipv6, hostnamev.wildcard) for _, found := range *h { - if found.domain == hostname.domain && found.ip.Equal(hostname.ip) { - return fmt.Errorf("Duplicate hostname entry for %s -> %s", - hostname.domain, hostname.ip) + if found.Equal(hostname) { + return fmt.Errorf("Duplicate hostname entry for %#v", hostname) } } *h = append(*h, hostname) @@ -49,9 +103,9 @@ func (h *hostlist) add(hostnamev *hostname) error { } // newHostname creates a new Hostname struct -func newHostname(domain string, ip net.IP, ipv6 bool) (host *hostname) { +func newHostname(domain string, ip net.IP, ipv6 bool, wildcard bool) (host *hostname) { domain = strings.ToLower(domain) - host = &hostname{domain, ip, ipv6} + host = &hostname{domain, ip, ipv6, wildcard} return } @@ -114,8 +168,14 @@ func parseLine(line string) hostlist { return hostnames } + var isWildcard bool for _, v := range domains { - hostname := newHostname(v, ip, isIPv6) + isWildcard = false + if v[0:2] == "*." { + v = v[2:] + isWildcard = true + } + hostname := newHostname(v, ip, isIPv6, isWildcard) hostnames = append(hostnames, hostname) } diff --git a/main.go b/main.go index b5ddf31..3b2a0bb 100644 --- a/main.go +++ b/main.go @@ -26,7 +26,7 @@ import ( ) // var Version string -const Version = "1.0.1" +const Version = "1.0.2" var ( nameservers = []string{}