Skip to content

Commit

Permalink
all: imp external client restriction
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Mar 24, 2021
1 parent 1208a31 commit 1dbacfc
Show file tree
Hide file tree
Showing 13 changed files with 446 additions and 201 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ require (
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.0.0-20210309074719-68d13333faf2
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d // indirect
golang.org/x/text v0.3.5 // indirect
golang.org/x/text v0.3.5
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v2 v2.4.0
howett.net/plist v0.0.0-20201203080718-1454fab16a06
Expand Down
17 changes: 10 additions & 7 deletions internal/aghnet/ipdetector.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
type IPDetector struct {
// spNets is the slice of special-purpose address registries as defined
// by RFC-6890 (https://tools.ietf.org/html/rfc6890).
spNets []*net.IPNet
spNetsMu sync.Mutex
spNets []*net.IPNet

// locServedNets is the slice of locally-served networks as defined by
// RFC-6303 (https://tools.ietf.org/html/rfc6303).
locServedNets []*net.IPNet
locServedNets []*net.IPNet

spNetsMu sync.Mutex
locServedNetsMu sync.Mutex
}

Expand Down Expand Up @@ -133,6 +134,8 @@ func NewIPDetector() (ipd *IPDetector, err error) {

// detectLocked ranges through the given ipnets slice searching for the one
// which contains the ip. For internal use only.
//
// TODO(e.burkov): Think about memoization.
func detectLocked(ipnets *[]*net.IPNet, ip net.IP) (is bool) {
for _, ipnet := range *ipnets {
if ipnet.Contains(ip) {
Expand All @@ -143,17 +146,17 @@ func detectLocked(ipnets *[]*net.IPNet, ip net.IP) (is bool) {
return false
}

// DetectSpecialNetwork returns true if IP address is contained by any of
// IsSpecialNetwork returns true if IP address is contained by any of
// special-purpose IP address registries. It's safe for concurrent use.
func (ipd *IPDetector) DetectSpecialNetwork(ip net.IP) (is bool) {
func (ipd *IPDetector) IsSpecialNetwork(ip net.IP) (is bool) {
ipd.spNetsMu.Lock()
defer ipd.spNetsMu.Unlock()
return detectLocked(&ipd.spNets, ip)
}

// DetectLocallyServedNetwork returns true if IP address is contained by any of
// IsLocallyServedNetwork returns true if IP address is contained by any of
// locally-served IP address registries. It's safe for concurrent use.
func (ipd *IPDetector) DetectLocallyServedNetwork(ip net.IP) (is bool) {
func (ipd *IPDetector) IsLocallyServedNetwork(ip net.IP) (is bool) {
ipd.locServedNetsMu.Lock()
defer ipd.locServedNetsMu.Unlock()
return detectLocked(&ipd.locServedNets, ip)
Expand Down
38 changes: 18 additions & 20 deletions internal/aghnet/ipdetector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func TestIPDetector_DetectSpecialNetwork(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, ipd.DetectSpecialNetwork(tc.ip))
assert.Equal(t, tc.want, ipd.IsSpecialNetwork(tc.ip))
})
}
}
Expand Down Expand Up @@ -209,7 +209,7 @@ func TestIPDetector_DetectLocallyServedNetwork(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, ipd.DetectLocallyServedNetwork(tc.ip))
assert.Equal(t, tc.want, ipd.IsLocallyServedNetwork(tc.ip))
})
}
}
Expand All @@ -220,27 +220,25 @@ func TestIPDetector_Detect_parallel(t *testing.T) {
ipd, err := NewIPDetector()
require.Nil(t, err)

testFunc := func(t *testing.T) {
assert.NotPanics(t, func() {
for _, ip := range []net.IP{
net.IPv4allrouter,
net.IPv4allsys,
net.IPv4bcast,
net.IPv4zero,
net.IPv6interfacelocalallnodes,
net.IPv6linklocalallnodes,
net.IPv6linklocalallrouters,
net.IPv6loopback,
net.IPv6unspecified,
} {
ipd.DetectSpecialNetwork(ip)
ipd.DetectLocallyServedNetwork(ip)
}
})
testFunc := func() {
for _, ip := range []net.IP{
net.IPv4allrouter,
net.IPv4allsys,
net.IPv4bcast,
net.IPv4zero,
net.IPv6interfacelocalallnodes,
net.IPv6linklocalallnodes,
net.IPv6linklocalallrouters,
net.IPv6loopback,
net.IPv6unspecified,
} {
_ = ipd.IsSpecialNetwork(ip)
_ = ipd.IsLocallyServedNetwork(ip)
}
}

const goroutinesNum = 50
for i := 0; i < goroutinesNum; i++ {
go testFunc(t)
go testFunc()
}
}
71 changes: 71 additions & 0 deletions internal/aghnet/localresolvers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// This is not the best place for this functionality, but since we need to use
// it in both rDNS (home) and dnsServer (dnsforward) we put it here.

package aghnet

import (
"time"

"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)

// LocalResolvers is used to perform exchanging PTR requests for addresses from
// locally-served networks.
//
// TODO(e.burkov): Maybe expand with method like ExchangeParallel to be able to
// use user's upstream mode settings.
type LocalResolvers interface {
Exchange(req *dns.Msg) (resp *dns.Msg, err error)
}

// localResolvers is the default implementation of LocalResolvers interface.
type localResolvers struct {
ups []upstream.Upstream
}

// NewLocalResolvers creates a LocalResolvers instance from passed local
// resolvers addresses. It returns an error if any of addrs failed to become an
// upstream.
func NewLocalResolvers(addrs []string, timeout time.Duration) (lr LocalResolvers, err error) {
defer agherr.Annotate("localResolvers: %w", &err)

if len(addrs) == 0 {
return &localResolvers{ups: nil}, nil
}

var ups []upstream.Upstream
for _, addr := range addrs {
var u upstream.Upstream
u, err = upstream.AddressToUpstream(addr, upstream.Options{Timeout: timeout})
if err != nil {
return nil, err
}

ups = append(ups, u)
}

return &localResolvers{ups: ups}, nil
}

// Exсhange performs a query to each resolver until first response.
func (lr *localResolvers) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
defer agherr.Annotate("localResolvers", &err)

var errs []error
for _, u := range lr.ups {
resp, err = u.Exchange(req)
if err != nil {
errs = append(errs, err)

continue
}

if resp != nil {
return resp, nil
}
}

return nil, agherr.Many("can't exchange", errs...)
}
64 changes: 64 additions & 0 deletions internal/aghnet/localresolvers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package aghnet

import (
"testing"

"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewLocalResolvers(t *testing.T) {
var lr LocalResolvers
var err error

t.Run("empty", func(t *testing.T) {
lr, err = NewLocalResolvers([]string{}, 0)
require.NoError(t, err)
assert.NotNil(t, lr)
})

t.Run("successful", func(t *testing.T) {
lr, err = NewLocalResolvers([]string{"www.example.com"}, 0)
require.NoError(t, err)
assert.NotNil(t, lr)
})

t.Run("unsuccessful", func(t *testing.T) {
lr, err = NewLocalResolvers([]string{"tlss://www.example.com"}, 0)
require.Error(t, err)
assert.Nil(t, lr)
})
}

func TestLocalResolvers_Exchange(t *testing.T) {
lr := &localResolvers{}

t.Run("error", func(t *testing.T) {
lr.ups = []upstream.Upstream{&aghtest.TestErrUpstream{}}

resp, err := lr.Exchange(nil)
require.Error(t, err)
assert.Nil(t, resp)
})

t.Run("all_right", func(t *testing.T) {
lr.ups = []upstream.Upstream{&aghtest.TestUpstream{
Reverse: map[string][]string{
"abc": {"cba"},
},
}}

resp, err := lr.Exchange(&dns.Msg{
Question: []dns.Question{{
Name: "abc",
Qtype: dns.TypePTR,
}},
})
require.NoError(t, err)
require.Len(t, resp.Answer, 1)
assert.Equal(t, "cba", resp.Answer[0].Header().Name)
})
}
3 changes: 3 additions & 0 deletions internal/aghnet/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ func SplitHost(hostport string) (host string, err error) {
return host, nil
}

// TODO(e.burkov): Inspect the charToHex, ipParseARPA6, ipReverse and
// UnreverseAddr and maybe refactor it.

// charToHex converts character to a hexadecimal.
func charToHex(n byte) int8 {
if n >= '0' && n <= '9' {
Expand Down
21 changes: 21 additions & 0 deletions internal/aghtest/localresolvers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package aghtest

import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)

// LocalResolvers is an implementor aghnet.LocalResolvers interface to
// simplify testing.
type LocalResolvers struct {
Ups upstream.Upstream
}

// Exchange implements aghnet.LocalResolvers interface for *LocalResolvers.
func (lr *LocalResolvers) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
if lr.Ups == nil {
lr.Ups = &TestErrUpstream{}
}

return lr.Ups.Exchange(req)
}
2 changes: 1 addition & 1 deletion internal/aghtest/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
for _, n := range names {
resp.Answer = append(resp.Answer, &dns.PTR{
Hdr: dns.RR_Header{
Name: name,
Name: n,
Rrtype: rrType,
},
Ptr: n,
Expand Down
Loading

0 comments on commit 1dbacfc

Please sign in to comment.