Skip to content

Commit

Permalink
aghnet: imp code quality
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Mar 19, 2021
1 parent d0ee01c commit d461ac8
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 66 deletions.
5 changes: 3 additions & 2 deletions internal/aghnet/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,9 @@ func ErrorIsAddrInUse(err error) bool {
return errErrno == syscall.EADDRINUSE
}

// GetHost tries to get host from hostport even if there is no port.
func GetHost(hostport string) (host string, err error) {
// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport
// does not necessarily contain a port.
func SplitHost(hostport string) (host string, err error) {
host, _, err = net.SplitHostPort(hostport)
if err != nil {
// Check for the missing port error. If it is that error, just
Expand Down
105 changes: 60 additions & 45 deletions internal/aghnet/systemresolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,50 @@ import (
"github.com/AdguardTeam/golibs/log"
)

// DefaultRefreshDur is the default period of time between refreshing cached
// DefaultRefreshIvl is the default period of time between refreshing cached
// addresses.
const DefaultRefreshDur = 5 * time.Minute
// const DefaultRefreshIvl = 5 * time.Minute

// HostGenFunc is the signature for functions generating fake hostnames. The
// implementation must be safe for concurrent use.
type HostGenFunc func() (host string)

// defaultHostGen is the default method of generating host for Refresh.
func defaultHostGen() (host string) {
// TODO(e.burkov): Use strings.Builder.
return fmt.Sprintf("test%d.org", time.Now().UnixNano())
}

// SystemResolvers provides methods to deal with local resolvers provided by system.
type SystemResolvers interface {
Get() (ss []string)
Refresh(customHost ...string) (err error)
Refresh() (err error)
}

const (
// fakeDialErr is an error which dialFunc is expected to return.
fakeDialErr agherr.Error = "this error signals the successful dialFunc work"

// badAddrPassedErr is returned when dialFunc can't parse an IP address.
badAddrPassedErr agherr.Error = "the passed string is not an IP address"
badAddrPassedErr agherr.Error = "the passed string is not a valid IP address"
)

// unit is an alias for an existing map value.
type unit = struct{}

// systemResolvers is a default implementation of SystemResolvers interface.
type systemResolvers struct {
ipDetector *IPDetector

resolver *net.Resolver
resolver *net.Resolver
hostGenFunc HostGenFunc

// addrs is the map that contains cached local resolvers' addresses.
addrs map[string]unit
addrsLock sync.RWMutex
}

// Refresh refreshes the local resolvers' addresses cache.
func (lr *systemResolvers) Refresh(customHost ...string) (err error) {
var host string
if len(customHost) == 0 {
host = fmt.Sprintf("test%d.org", time.Now().UnixNano())
} else {
host = customHost[0]
}

_, err = lr.resolver.LookupHost(context.Background(), host)
func (lr *systemResolvers) Refresh() (err error) {
_, err = lr.resolver.LookupHost(context.Background(), lr.hostGenFunc())
dnserr := &net.DNSError{}
if errors.As(err, &dnserr) && dnserr.Err == fakeDialErr.Error() {
return nil
Expand All @@ -79,20 +81,53 @@ func (lr *systemResolvers) refreshWithTicker(tickCh <-chan time.Time) {
}
}

// dialFunc gets the resolver's address and puts it into internal cache.
func (lr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
// Just validate the passed address is a valid IP.
var host string
host, err = SplitHost(address)
if err != nil {
// TODO(e.burkov): Maybe use a structured badAddrPassedErr to
// allow unwrapping of the real error.
return nil, fmt.Errorf(
"systemResolvers: %s: %w",
err,
badAddrPassedErr,
)
}

if net.ParseIP(host) == nil {
return nil, fmt.Errorf(
"systemResolvers: parsing %q: %w",
host,
badAddrPassedErr,
)
}

lr.addrsLock.Lock()
defer lr.addrsLock.Unlock()

lr.addrs[address] = unit{}

return nil, fakeDialErr
}

// NewSystemResolvers returns a SystemResolvers with ipd detecting the local
// addresses. The cache refresh rate is defined by refreshDur and disables by
// setting it to 0.
func NewSystemResolvers(ipd *IPDetector, refreshDur time.Duration) (l SystemResolvers, err error) {
if ipd == nil {
return nil, agherr.Error("a non-nil ipdetector is required")
// setting it to 0. If nil is passed for hostGenFunc, the default will be used.
func NewSystemResolvers(
refreshIvl time.Duration,
hostGenFunc HostGenFunc,
) (l SystemResolvers, err error) {
if hostGenFunc == nil {
hostGenFunc = defaultHostGen
}

ldns := &systemResolvers{
ipDetector: ipd,
resolver: &net.Resolver{
PreferGo: true,
},
addrs: make(map[string]unit),
hostGenFunc: hostGenFunc,
addrs: make(map[string]unit),
}
ldns.resolver.Dial = ldns.dialFunc

Expand All @@ -102,35 +137,15 @@ func NewSystemResolvers(ipd *IPDetector, refreshDur time.Duration) (l SystemReso
return nil, err
}

if refreshDur > 0 {
ticker := time.NewTicker(refreshDur)
if refreshIvl > 0 {
ticker := time.NewTicker(refreshIvl)

go ldns.refreshWithTicker(ticker.C)
}

return ldns, nil
}

// dialFunc gets the resolver's address and puts it into internal cache.
func (lr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
// Just validate the passed address is a valid IP.
var host string
host, err = GetHost(address)
if err != nil {
return nil, badAddrPassedErr
}
if net.ParseIP(host) == nil {
return nil, badAddrPassedErr
}

lr.addrsLock.Lock()
defer lr.addrsLock.Unlock()

lr.addrs[address] = unit{}

return nil, fakeDialErr
}

// Get implements SystemResolvers interface for *systemResolvers. It is safe
// for concurrent use.
func (lr *systemResolvers) Get() (ss []string) {
Expand Down
48 changes: 30 additions & 18 deletions internal/aghnet/systemresolvers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aghnet
import (
"bytes"
"context"
"errors"
"strings"
"testing"
"time"
Expand All @@ -13,21 +14,29 @@ import (
"github.com/stretchr/testify/require"
)

const zeroRefreshDur = 0
func createTestSystemResolvers(
t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc,
) (sr SystemResolvers) {
t.Helper()

func createTestSystemResolvers(t *testing.T, refreshDur time.Duration) (sr SystemResolvers) {
ipd, err := NewIPDetector()
require.NoError(t, err)

sr, err = NewSystemResolvers(ipd, refreshDur)
var err error
sr, err = NewSystemResolvers(refreshDur, hostGenFunc)
require.NoError(t, err)
require.NotNil(t, sr)

return sr
}

func createTestSystemResolversImp(t *testing.T, refreshDur time.Duration) (imp *systemResolvers) {
sr := createTestSystemResolvers(t, refreshDur)
func createTestSystemResolversImp(
t *testing.T,
refreshDur time.Duration,
hostGenFunc HostGenFunc,
) (imp *systemResolvers) {
t.Helper()

sr := createTestSystemResolvers(t, refreshDur, hostGenFunc)

var ok bool
imp, ok = sr.(*systemResolvers)
Expand All @@ -37,20 +46,22 @@ func createTestSystemResolversImp(t *testing.T, refreshDur time.Duration) (imp *
}

func TestSystemResolvers_Get(t *testing.T) {
sr := createTestSystemResolvers(t, zeroRefreshDur)

sr := createTestSystemResolvers(t, 0, nil)
assert.NotEmpty(t, sr.Get())
}

func TestSystemResolvers_Refresh(t *testing.T) {
sr := createTestSystemResolvers(t, zeroRefreshDur)

t.Run("expected_error", func(t *testing.T) {
sr := createTestSystemResolvers(t, 0, nil)

assert.NoError(t, sr.Refresh())
})

t.Run("unexpected_error", func(t *testing.T) {
assert.NotNil(t, sr.Refresh("127.0.0.1::123"))
_, err := NewSystemResolvers(0, func() string {
return "127.0.0.1::123"
})
assert.Error(t, err)
})
}

Expand All @@ -68,19 +79,20 @@ func TestSystemResolvers_Refresh_withTicker(t *testing.T) {
wantEntrances = 2
)

_ = createTestSystemResolvers(t, setRefreshDur)
_ = createTestSystemResolvers(t, setRefreshDur, nil)
time.Sleep(butWaitFor)

ents := strings.Count(buf.String(), "local addresses cache is refreshed")
assert.GreaterOrEqual(t, ents, wantEntrances)
}

func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImp(t, zeroRefreshDur)
imp := createTestSystemResolversImp(t, 0, nil)

testCases := []struct {
name, address string
want error
name string
address string
want error
}{{
name: "valid",
address: "127.0.0.1",
Expand All @@ -100,7 +112,7 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
conn, err := imp.dialFunc(context.Background(), "", tc.address)

require.Nil(t, conn)
assert.Equal(t, tc.want, err)
assert.True(t, errors.Is(err, tc.want))
})
}
}
3 changes: 3 additions & 0 deletions internal/aghtest/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ type TestErrUpstream struct{}

// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
// We don't use an agherr.Error to avoid the import cycle since aghtests
// used to provide the utilities for testing which agherr (and any other
// testable package) should be able to use.
return nil, errors.New("bad")
}

Expand Down
2 changes: 1 addition & 1 deletion internal/home/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
return true
}

host, err := aghnet.GetHost(r.Host)
host, err := aghnet.SplitHost(r.Host)
if err != nil {
httpError(w, http.StatusBadRequest, "bad host: %s", err)

Expand Down

0 comments on commit d461ac8

Please sign in to comment.