Skip to content

Commit

Permalink
all: fix client upstreams, imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed May 27, 2021
1 parent 48b8579 commit 98f86c2
Show file tree
Hide file tree
Showing 16 changed files with 133 additions and 160 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ released by then.

### Fixed

- Custom upstreams selection for clients with client IDs in DNS-over-TLS and
DNS-over-HTTP ([#3186]).
- Incorrect client-based filtering applying logic ([#2875]).

### Removed
Expand All @@ -40,6 +42,7 @@ released by then.

[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
[#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185
[#3186]: https://github.com/AdguardTeam/AdGuardHome/issues/3186



Expand Down
13 changes: 13 additions & 0 deletions internal/aghnet/addr.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ import (
"golang.org/x/net/idna"
)

// IPFromAddr returns an IP address from addr. If addr is neither
// a *net.TCPAddr nor a *net.UDPAddr, it returns nil.
func IPFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.TCPAddr:
return addr.IP
case *net.UDPAddr:
return addr.IP
}

return nil
}

// IsValidHostOuterRune returns true if r is a valid initial or final rune for
// a hostname label.
func IsValidHostOuterRune(r rune) (ok bool) {
Expand Down
8 changes: 8 additions & 0 deletions internal/aghnet/addr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ import (
"github.com/stretchr/testify/require"
)

func TestIPFromAddr(t *testing.T) {
ip := net.IP{1, 2, 3, 4}
assert.Equal(t, net.IP(nil), IPFromAddr(nil))
assert.Equal(t, net.IP(nil), IPFromAddr(struct{ net.Addr }{}))
assert.Equal(t, ip, IPFromAddr(&net.TCPAddr{IP: ip}))
assert.Equal(t, ip, IPFromAddr(&net.UDPAddr{IP: ip}))
}

func TestValidateHardwareAddress(t *testing.T) {
testCases := []struct {
name string
Expand Down
13 changes: 13 additions & 0 deletions internal/aghstrings/strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ func CloneSlice(a []string) (b []string) {
return CloneSliceOrEmpty(a)
}

// Coalesce returns the first non-empty string. It is named after the function
// COALESCE in SQL except that since strings in Go are non-nullable, it uses an
// empty string as a NULL value. If strs is empty, it returns an empty string.
func Coalesce(strs ...string) (res string) {
for _, s := range strs {
if s != "" {
return s
}
}

return ""
}

// FilterOut returns a copy of strs with all strings for which f returned true
// removed.
func FilterOut(strs []string, f func(s string) (ok bool)) (filtered []string) {
Expand Down
8 changes: 8 additions & 0 deletions internal/aghstrings/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ func TestCloneSlice_family(t *testing.T) {
})
}

func TestCoalesce(t *testing.T) {
assert.Equal(t, "", Coalesce())
assert.Equal(t, "a", Coalesce("a"))
assert.Equal(t, "a", Coalesce("", "a"))
assert.Equal(t, "a", Coalesce("a", ""))
assert.Equal(t, "a", Coalesce("a", "b"))
}

func TestFilterOut(t *testing.T) {
strs := []string{
"1.2.3.4",
Expand Down
9 changes: 4 additions & 5 deletions internal/dnsforward/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@ type FilteringConfig struct {
// FilterHandler is an optional additional filtering callback.
FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"`

// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client
//
// TODO(e.burkov): Replace argument type with net.IP.
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
// GetCustomUpstreamByClient is a callback that returns upstreams
// configuration based on the client IP address or ClientID. It returns
// nil if there are no custom upstreams for the client.
GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`

// Protection configuration
// --
Expand Down
13 changes: 9 additions & 4 deletions internal/dnsforward/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
Expand Down Expand Up @@ -229,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
rc = resultCodeSuccess

var ip net.IP
if ip = IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
return rc
}

Expand Down Expand Up @@ -497,9 +498,13 @@ func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
}

if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
clientIP := IPStringFromAddr(d.Addr)
if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil {
log.Debug("dns: using custom upstreams for client %s", clientIP)
// Use the clientID first, since it has a higher priority.
id := aghstrings.Coalesce(ctx.clientID, ipStringFromAddr(d.Addr))
upsConf, err := s.conf.GetCustomUpstreamByClient(id)
if err != nil {
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
} else if upsConf != nil {
log.Debug("dns: using custom upstreams for client %s", id)
d.CustomUpstreamConfig = upsConf
}
}
Expand Down
20 changes: 10 additions & 10 deletions internal/dnsforward/dnsforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,16 +521,16 @@ func TestServerCustomClientUpstream(t *testing.T) {
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig {
return &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{
&aghtest.TestUpstream{
IPv4: map[string][]net.IP{
"host.": {{192, 168, 0, 1}},
},
},
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
ups := &aghtest.TestUpstream{
IPv4: map[string][]net.IP{
"host.": {{192, 168, 0, 1}},
},
}

return &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
}, nil
}
startDeferStop(t, s)

Expand Down Expand Up @@ -969,11 +969,11 @@ func TestIPStringFromAddr(t *testing.T) {
Port: 12345,
Zone: "eth0",
}
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
assert.Equal(t, ipStringFromAddr(&addr), addr.IP.String())
})

t.Run("nil", func(t *testing.T) {
assert.Empty(t, IPStringFromAddr(nil))
assert.Empty(t, ipStringFromAddr(nil))
})
}

Expand Down
6 changes: 3 additions & 3 deletions internal/dnsforward/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import (
"fmt"
"strings"

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"

"github.com/miekg/dns"
)

func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := IPFromAddr(d.Addr)
ip := aghnet.IPFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip)
Expand All @@ -39,7 +39,7 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig()
if s.conf.FilterHandler != nil {
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
}

return &setts
Expand Down
5 changes: 3 additions & 2 deletions internal/dnsforward/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"strings"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
Expand Down Expand Up @@ -37,7 +38,7 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: IPFromAddr(pctx.Addr),
ClientIP: aghnet.IPFromAddr(pctx.Addr),
ClientID: ctx.clientID,
}

Expand Down Expand Up @@ -79,7 +80,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri

if clientID := ctx.clientID; clientID != "" {
e.Client = clientID
} else if ip := IPFromAddr(pctx.Addr); ip != nil {
} else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil {
e.Client = ip.String()
}

Expand Down
19 changes: 3 additions & 16 deletions internal/dnsforward/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,9 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
)

// IPFromAddr gets IP address from addr.
func IPFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP
case *net.TCPAddr:
return addr.IP
}
return nil
}

// IPStringFromAddr extracts IP address from net.Addr.
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func IPStringFromAddr(addr net.Addr) (ipStr string) {
if ip := IPFromAddr(addr); ip != nil {
// ipStringFromAddr extracts an IP address string from net.Addr.
func ipStringFromAddr(addr net.Addr) (ipStr string) {
if ip := aghnet.IPFromAddr(addr); ip != nil {
return ip.String()
}

Expand Down
60 changes: 0 additions & 60 deletions internal/dnsforward/util_test.go

This file was deleted.

47 changes: 26 additions & 21 deletions internal/home/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,37 +335,42 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
return c, true
}

// FindUpstreams looks for upstreams configured for the client
// If no client found for this IP, or if no custom upstreams are configured,
// this method returns nil
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
// findUpstreams returns upstreams configured for the client, identified either
// by its IP address or its ClientID. pconf is nil if the client isn't found or
// if the client has no custom upstreams.
func (clients *clientsContainer) findUpstreams(id string) (pconf *proxy.UpstreamConfig, err error) {
clients.lock.Lock()
defer clients.lock.Unlock()

c, ok := clients.findLocked(ip)
c, ok := clients.findLocked(id)
if !ok {
return nil
return nil, nil
}

upstreams := aghstrings.FilterOut(c.Upstreams, aghstrings.IsCommentOrEmpty)
if len(upstreams) == 0 {
return nil
}

if c.upstreamConfig == nil {
conf, err := proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
Bootstrap: config.DNS.BootstrapDNS,
Timeout: dnsforward.DefaultTimeout,
},
)
if err == nil {
c.upstreamConfig = &conf
}
return nil, nil
}

if c.upstreamConfig != nil {
return c.upstreamConfig, nil
}

return c.upstreamConfig
var conf proxy.UpstreamConfig
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
Bootstrap: config.DNS.BootstrapDNS,
Timeout: dnsforward.DefaultTimeout,
},
)
if err != nil {
return nil, err
}

c.upstreamConfig = &conf

return &conf, nil
}

// findLocked searches for a client by its ID. For internal use only.
Expand Down
Loading

0 comments on commit 98f86c2

Please sign in to comment.