Skip to content

Commit

Permalink
all: imp code, names
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed May 28, 2021
1 parent 98f86c2 commit a8dd0e2
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 120 deletions.
45 changes: 44 additions & 1 deletion internal/dnsforward/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"net/http"
"os"
"sort"
"strings"

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
Expand Down Expand Up @@ -383,10 +385,51 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
return nil
}

// isInSorted returns true if s is in the sorted slice strs.
func isInSorted(strs []string, s string) (ok bool) {
i := sort.SearchStrings(strs, s)
if i == len(strs) || strs[i] != s {
return false
}

return true
}

// isWildcard returns true if host is a wildcard hostname.
func isWildcard(host string) (ok bool) {
return len(host) >= 2 && host[0] == '*' && host[1] == '.'
}

// matchesDomainWildcard returns true if host matches the domain wildcard
// pattern pat.
func matchesDomainWildcard(host, pat string) (ok bool) {
return isWildcard(pat) && strings.HasSuffix(host, pat[1:])
}

// anyNameMatches returns true if sni, the client's SNI value, matches any of
// the DNS names and patterns from certificate. dnsNames must be sorted.
func anyNameMatches(dnsNames []string, sni string) (ok bool) {
if aghnet.ValidateDomainName(sni) != nil {
return false
}

if isInSorted(dnsNames, sni) {
return true
}

for _, dn := range dnsNames {
if matchesDomainWildcard(sni, dn) {
return true
}
}

return false
}

// Called by 'tls' package when Client Hello is received
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) {
if s.conf.StrictSNICheck && !anyNameMatches(s.conf.dnsNames, ch.ServerName) {
log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName)
return nil, fmt.Errorf("invalid SNI")
}
Expand Down
53 changes: 53 additions & 0 deletions internal/dnsforward/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package dnsforward

import (
"sort"
"testing"

"github.com/stretchr/testify/assert"
)

func TestAnyNameMatches(t *testing.T) {
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
sort.Strings(dnsNames)

testCases := []struct {
name string
dnsName string
want bool
}{{
name: "match",
dnsName: "host1",
want: true,
}, {
name: "match",
dnsName: "a.host2",
want: true,
}, {
name: "match",
dnsName: "b.a.host2",
want: true,
}, {
name: "match",
dnsName: "1.2.3.4",
want: true,
}, {
name: "mismatch",
dnsName: "host2",
want: false,
}, {
name: "mismatch",
dnsName: "",
want: false,
}, {
name: "mismatch",
dnsName: "*.host2",
want: false,
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, anyNameMatches(dnsNames, tc.dnsName))
})
}
}
9 changes: 9 additions & 0 deletions internal/dnsforward/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,15 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}

// ipStringFromAddr extracts an IP address string from net.Addr.
func ipStringFromAddr(addr net.Addr) (ipStr string) {
if ip := aghnet.IPFromAddr(addr); ip != nil {
return ip.String()
}

return ""
}

// processUpstream passes request to upstream servers and handles the response.
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
Expand Down
15 changes: 15 additions & 0 deletions internal/dnsforward/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,18 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
require.Empty(t, proxyCtx.Res.Answer)
})
}

func TestIPStringFromAddr(t *testing.T) {
t.Run("not_nil", func(t *testing.T) {
addr := net.UDPAddr{
IP: net.ParseIP("1:2:3::4"),
Port: 12345,
Zone: "eth0",
}
assert.Equal(t, ipStringFromAddr(&addr), addr.IP.String())
})

t.Run("nil", func(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil))
})
}
60 changes: 0 additions & 60 deletions internal/dnsforward/dnsforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"math/big"
"net"
"os"
"sort"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -962,65 +961,6 @@ func publicKey(priv interface{}) interface{} {
}
}

func TestIPStringFromAddr(t *testing.T) {
t.Run("not_nil", func(t *testing.T) {
addr := net.UDPAddr{
IP: net.ParseIP("1:2:3::4"),
Port: 12345,
Zone: "eth0",
}
assert.Equal(t, ipStringFromAddr(&addr), addr.IP.String())
})

t.Run("nil", func(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil))
})
}

func TestMatchDNSName(t *testing.T) {
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
sort.Strings(dnsNames)

testCases := []struct {
name string
dnsName string
want bool
}{{
name: "match",
dnsName: "host1",
want: true,
}, {
name: "match",
dnsName: "a.host2",
want: true,
}, {
name: "match",
dnsName: "b.a.host2",
want: true,
}, {
name: "match",
dnsName: "1.2.3.4",
want: true,
}, {
name: "mismatch",
dnsName: "host2",
want: false,
}, {
name: "mismatch",
dnsName: "",
want: false,
}, {
name: "mismatch",
dnsName: "*.host2",
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, matchDNSName(dnsNames, tc.dnsName))
})
}
}

type testDHCP struct{}

func (d *testDHCP) Enabled() (ok bool) { return true }
Expand Down
56 changes: 0 additions & 56 deletions internal/dnsforward/util.go

This file was deleted.

8 changes: 5 additions & 3 deletions internal/home/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,11 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
}

// findUpstreams returns upstreams configured for the client, identified either
// by its IP address or its ClientID. pconf is nil if the client isn't found or
// if the client has no custom upstreams.
func (clients *clientsContainer) findUpstreams(id string) (pconf *proxy.UpstreamConfig, err error) {
// by its IP address or its ClientID. upsConf is nil if the client isn't found
// or if the client has no custom upstreams.
func (clients *clientsContainer) findUpstreams(
id string,
) (upsConf *proxy.UpstreamConfig, err error) {
clients.lock.Lock()
defer clients.lock.Unlock()

Expand Down

0 comments on commit a8dd0e2

Please sign in to comment.