Skip to content

Commit

Permalink
all: mv funcs to agherr, mk system resolvers getter
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Mar 19, 2021
1 parent eb9526c commit d0ee01c
Show file tree
Hide file tree
Showing 8 changed files with 338 additions and 24 deletions.
16 changes: 16 additions & 0 deletions internal/agherr/agherr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package agherr
import (
"fmt"
"strings"

"github.com/AdguardTeam/golibs/log"
)

// Error is the constant error type.
Expand Down Expand Up @@ -107,3 +109,17 @@ func Annotate(msg string, errPtr *error, args ...interface{}) {
*errPtr = fmt.Errorf(msg, args...)
}
}

// LogPanic is a convinient helper function to log a panic in a goroutine. It
// should not be used where proper error handling is required.
func LogPanic(prefix string) {
if v := recover(); v != nil {
if prefix != "" {
log.Error("%s: recovered from panic: %v", prefix, v)

return
}

log.Error("recovered from panic: %v", v)
}
}
38 changes: 38 additions & 0 deletions internal/agherr/agherr_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package agherr

import (
"bytes"
"errors"
"fmt"
"testing"

"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -120,3 +122,39 @@ func TestAnnotate(t *testing.T) {
assert.Equal(t, wantMsg, err.Error())
})
}

func TestLogPanic(t *testing.T) {
buf := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, buf)

t.Run("prefix", func(t *testing.T) {
const (
panicMsg = "spooky!"
prefix = "packagename"
errWithNoPrefix = "[error] recovered from panic: spooky!"
errWithPrefix = "[error] packagename: recovered from panic: spooky!"
)

panicFunc := func(prefix string) {
defer LogPanic(prefix)

panic(panicMsg)
}

panicFunc("")
assert.Contains(t, buf.String(), errWithNoPrefix)
buf.Reset()

panicFunc(prefix)
assert.Contains(t, buf.String(), errWithPrefix)
buf.Reset()
})

t.Run("don't_panic", func(t *testing.T) {
require.NotPanics(t, func() {
defer LogPanic("")
})

assert.Empty(t, buf.String())
})
}
21 changes: 21 additions & 0 deletions internal/aghnet/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,24 @@ func ErrorIsAddrInUse(err error) bool {

return errErrno == syscall.EADDRINUSE
}

// GetHost tries to get host from hostport even if there is no port.
func GetHost(hostport string) (host string, err error) {
host, _, err = net.SplitHostPort(hostport)
if err != nil {
// Check for the missing port error. If it is that error, just
// use the host as is.
//
// See the source code for net.SplitHostPort.
const missingPort = "missing port in address"

addrErr := &net.AddrError{}
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
return "", err
}

host = hostport
}

return host, nil
}
149 changes: 149 additions & 0 deletions internal/aghnet/systemresolvers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package aghnet

import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/log"
)

// DefaultRefreshDur is the default period of time between refreshing cached
// addresses.
const DefaultRefreshDur = 5 * time.Minute

// SystemResolvers provides methods to deal with local resolvers provided by system.
type SystemResolvers interface {
Get() (ss []string)
Refresh(customHost ...string) (err error)
}

const (
// fakeDialErr is an error which dialFunc is expected to return.
fakeDialErr agherr.Error = "this error signals the successful dialFunc work"

// badAddrPassedErr is returned when dialFunc can't parse an IP address.
badAddrPassedErr agherr.Error = "the passed string is not an IP address"
)

// unit is an alias for an existing map value.
type unit = struct{}

// systemResolvers is a default implementation of SystemResolvers interface.
type systemResolvers struct {
ipDetector *IPDetector

resolver *net.Resolver

// addrs is the map that contains cached local resolvers' addresses.
addrs map[string]unit
addrsLock sync.RWMutex
}

// Refresh refreshes the local resolvers' addresses cache.
func (lr *systemResolvers) Refresh(customHost ...string) (err error) {
var host string
if len(customHost) == 0 {
host = fmt.Sprintf("test%d.org", time.Now().UnixNano())
} else {
host = customHost[0]
}

_, err = lr.resolver.LookupHost(context.Background(), host)
dnserr := &net.DNSError{}
if errors.As(err, &dnserr) && dnserr.Err == fakeDialErr.Error() {
return nil
}

return err
}

// refreshWithTicker refreshes the cache after each tick form tickCh.
func (lr *systemResolvers) refreshWithTicker(tickCh <-chan time.Time) {
defer agherr.LogPanic("systemResolvers")

// TODO(e.burkov): Implement a functionality to stop ticker.
for range tickCh {
err := lr.Refresh()
if err != nil {
log.Error("systemResolvers: error in refreshing goroutine: %s", err)

continue
}

log.Debug("local addresses cache is refreshed")
}
}

// NewSystemResolvers returns a SystemResolvers with ipd detecting the local
// addresses. The cache refresh rate is defined by refreshDur and disables by
// setting it to 0.
func NewSystemResolvers(ipd *IPDetector, refreshDur time.Duration) (l SystemResolvers, err error) {
if ipd == nil {
return nil, agherr.Error("a non-nil ipdetector is required")
}

ldns := &systemResolvers{
ipDetector: ipd,
resolver: &net.Resolver{
PreferGo: true,
},
addrs: make(map[string]unit),
}
ldns.resolver.Dial = ldns.dialFunc

// Fill cache.
err = ldns.Refresh()
if err != nil {
return nil, err
}

if refreshDur > 0 {
ticker := time.NewTicker(refreshDur)

go ldns.refreshWithTicker(ticker.C)
}

return ldns, nil
}

// dialFunc gets the resolver's address and puts it into internal cache.
func (lr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
// Just validate the passed address is a valid IP.
var host string
host, err = GetHost(address)
if err != nil {
return nil, badAddrPassedErr
}
if net.ParseIP(host) == nil {
return nil, badAddrPassedErr
}

lr.addrsLock.Lock()
defer lr.addrsLock.Unlock()

lr.addrs[address] = unit{}

return nil, fakeDialErr
}

// Get implements SystemResolvers interface for *systemResolvers. It is safe
// for concurrent use.
func (lr *systemResolvers) Get() (ss []string) {
lr.addrsLock.RLock()
defer lr.addrsLock.RUnlock()

addrs := lr.addrs
ss = make([]string, len(addrs))
var i int
for addr := range addrs {
ss[i] = addr
i++
}

return ss
}
106 changes: 106 additions & 0 deletions internal/aghnet/systemresolvers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package aghnet

import (
"bytes"
"context"
"strings"
"testing"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const zeroRefreshDur = 0

func createTestSystemResolvers(t *testing.T, refreshDur time.Duration) (sr SystemResolvers) {
ipd, err := NewIPDetector()
require.NoError(t, err)

sr, err = NewSystemResolvers(ipd, refreshDur)
require.NoError(t, err)
require.NotNil(t, sr)

return sr
}

func createTestSystemResolversImp(t *testing.T, refreshDur time.Duration) (imp *systemResolvers) {
sr := createTestSystemResolvers(t, refreshDur)

var ok bool
imp, ok = sr.(*systemResolvers)
require.True(t, ok)

return imp
}

func TestSystemResolvers_Get(t *testing.T) {
sr := createTestSystemResolvers(t, zeroRefreshDur)

assert.NotEmpty(t, sr.Get())
}

func TestSystemResolvers_Refresh(t *testing.T) {
sr := createTestSystemResolvers(t, zeroRefreshDur)

t.Run("expected_error", func(t *testing.T) {
assert.NoError(t, sr.Refresh())
})

t.Run("unexpected_error", func(t *testing.T) {
assert.NotNil(t, sr.Refresh("127.0.0.1::123"))
})
}

func TestSystemResolvers_Refresh_withTicker(t *testing.T) {
t.Skip("TODO(e.burkov): The test now fails because of race in logger. Fix logger.")

buf := &bytes.Buffer{}
aghtest.ReplaceLogLevel(t, log.DEBUG)
aghtest.ReplaceLogWriter(t, buf)

const (
setRefreshDur = 1 * time.Second
butWaitFor = 3 * time.Second

wantEntrances = 2
)

_ = createTestSystemResolvers(t, setRefreshDur)
time.Sleep(butWaitFor)

ents := strings.Count(buf.String(), "local addresses cache is refreshed")
assert.GreaterOrEqual(t, ents, wantEntrances)
}

func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImp(t, zeroRefreshDur)

testCases := []struct {
name, address string
want error
}{{
name: "valid",
address: "127.0.0.1",
want: fakeDialErr,
}, {
name: "invalid",
address: "127.0.0.1::123",
want: badAddrPassedErr,
}, {
name: "not_an_address",
address: "not-ip",
want: badAddrPassedErr,
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conn, err := imp.dialFunc(context.Background(), "", tc.address)

require.Nil(t, conn)
assert.Equal(t, tc.want, err)
})
}
}
4 changes: 2 additions & 2 deletions internal/aghtest/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package aghtest
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net"
"strings"
"sync"

"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/miekg/dns"
)

Expand Down Expand Up @@ -166,7 +166,7 @@ type TestErrUpstream struct{}

// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, agherr.Error("bad")
return nil, errors.New("bad")
}

// Address always returns an empty string.
Expand Down
Loading

0 comments on commit d0ee01c

Please sign in to comment.