-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
all: mv funcs to agherr, mk system resolvers getter
- Loading branch information
1 parent
eb9526c
commit d0ee01c
Showing
8 changed files
with
338 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
package aghnet | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net" | ||
"sync" | ||
"time" | ||
|
||
"github.com/AdguardTeam/AdGuardHome/internal/agherr" | ||
"github.com/AdguardTeam/golibs/log" | ||
) | ||
|
||
// DefaultRefreshDur is the default period of time between refreshing cached | ||
// addresses. | ||
const DefaultRefreshDur = 5 * time.Minute | ||
|
||
// SystemResolvers provides methods to deal with local resolvers provided by system. | ||
type SystemResolvers interface { | ||
Get() (ss []string) | ||
Refresh(customHost ...string) (err error) | ||
} | ||
|
||
const ( | ||
// fakeDialErr is an error which dialFunc is expected to return. | ||
fakeDialErr agherr.Error = "this error signals the successful dialFunc work" | ||
|
||
// badAddrPassedErr is returned when dialFunc can't parse an IP address. | ||
badAddrPassedErr agherr.Error = "the passed string is not an IP address" | ||
) | ||
|
||
// unit is an alias for an existing map value. | ||
type unit = struct{} | ||
|
||
// systemResolvers is a default implementation of SystemResolvers interface. | ||
type systemResolvers struct { | ||
ipDetector *IPDetector | ||
|
||
resolver *net.Resolver | ||
|
||
// addrs is the map that contains cached local resolvers' addresses. | ||
addrs map[string]unit | ||
addrsLock sync.RWMutex | ||
} | ||
|
||
// Refresh refreshes the local resolvers' addresses cache. | ||
func (lr *systemResolvers) Refresh(customHost ...string) (err error) { | ||
var host string | ||
if len(customHost) == 0 { | ||
host = fmt.Sprintf("test%d.org", time.Now().UnixNano()) | ||
} else { | ||
host = customHost[0] | ||
} | ||
|
||
_, err = lr.resolver.LookupHost(context.Background(), host) | ||
dnserr := &net.DNSError{} | ||
if errors.As(err, &dnserr) && dnserr.Err == fakeDialErr.Error() { | ||
return nil | ||
} | ||
|
||
return err | ||
} | ||
|
||
// refreshWithTicker refreshes the cache after each tick form tickCh. | ||
func (lr *systemResolvers) refreshWithTicker(tickCh <-chan time.Time) { | ||
defer agherr.LogPanic("systemResolvers") | ||
|
||
// TODO(e.burkov): Implement a functionality to stop ticker. | ||
for range tickCh { | ||
err := lr.Refresh() | ||
if err != nil { | ||
log.Error("systemResolvers: error in refreshing goroutine: %s", err) | ||
|
||
continue | ||
} | ||
|
||
log.Debug("local addresses cache is refreshed") | ||
} | ||
} | ||
|
||
// NewSystemResolvers returns a SystemResolvers with ipd detecting the local | ||
// addresses. The cache refresh rate is defined by refreshDur and disables by | ||
// setting it to 0. | ||
func NewSystemResolvers(ipd *IPDetector, refreshDur time.Duration) (l SystemResolvers, err error) { | ||
if ipd == nil { | ||
return nil, agherr.Error("a non-nil ipdetector is required") | ||
} | ||
|
||
ldns := &systemResolvers{ | ||
ipDetector: ipd, | ||
resolver: &net.Resolver{ | ||
PreferGo: true, | ||
}, | ||
addrs: make(map[string]unit), | ||
} | ||
ldns.resolver.Dial = ldns.dialFunc | ||
|
||
// Fill cache. | ||
err = ldns.Refresh() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if refreshDur > 0 { | ||
ticker := time.NewTicker(refreshDur) | ||
|
||
go ldns.refreshWithTicker(ticker.C) | ||
} | ||
|
||
return ldns, nil | ||
} | ||
|
||
// dialFunc gets the resolver's address and puts it into internal cache. | ||
func (lr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) { | ||
// Just validate the passed address is a valid IP. | ||
var host string | ||
host, err = GetHost(address) | ||
if err != nil { | ||
return nil, badAddrPassedErr | ||
} | ||
if net.ParseIP(host) == nil { | ||
return nil, badAddrPassedErr | ||
} | ||
|
||
lr.addrsLock.Lock() | ||
defer lr.addrsLock.Unlock() | ||
|
||
lr.addrs[address] = unit{} | ||
|
||
return nil, fakeDialErr | ||
} | ||
|
||
// Get implements SystemResolvers interface for *systemResolvers. It is safe | ||
// for concurrent use. | ||
func (lr *systemResolvers) Get() (ss []string) { | ||
lr.addrsLock.RLock() | ||
defer lr.addrsLock.RUnlock() | ||
|
||
addrs := lr.addrs | ||
ss = make([]string, len(addrs)) | ||
var i int | ||
for addr := range addrs { | ||
ss[i] = addr | ||
i++ | ||
} | ||
|
||
return ss | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
package aghnet | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" | ||
"github.com/AdguardTeam/golibs/log" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
const zeroRefreshDur = 0 | ||
|
||
func createTestSystemResolvers(t *testing.T, refreshDur time.Duration) (sr SystemResolvers) { | ||
ipd, err := NewIPDetector() | ||
require.NoError(t, err) | ||
|
||
sr, err = NewSystemResolvers(ipd, refreshDur) | ||
require.NoError(t, err) | ||
require.NotNil(t, sr) | ||
|
||
return sr | ||
} | ||
|
||
func createTestSystemResolversImp(t *testing.T, refreshDur time.Duration) (imp *systemResolvers) { | ||
sr := createTestSystemResolvers(t, refreshDur) | ||
|
||
var ok bool | ||
imp, ok = sr.(*systemResolvers) | ||
require.True(t, ok) | ||
|
||
return imp | ||
} | ||
|
||
func TestSystemResolvers_Get(t *testing.T) { | ||
sr := createTestSystemResolvers(t, zeroRefreshDur) | ||
|
||
assert.NotEmpty(t, sr.Get()) | ||
} | ||
|
||
func TestSystemResolvers_Refresh(t *testing.T) { | ||
sr := createTestSystemResolvers(t, zeroRefreshDur) | ||
|
||
t.Run("expected_error", func(t *testing.T) { | ||
assert.NoError(t, sr.Refresh()) | ||
}) | ||
|
||
t.Run("unexpected_error", func(t *testing.T) { | ||
assert.NotNil(t, sr.Refresh("127.0.0.1::123")) | ||
}) | ||
} | ||
|
||
func TestSystemResolvers_Refresh_withTicker(t *testing.T) { | ||
t.Skip("TODO(e.burkov): The test now fails because of race in logger. Fix logger.") | ||
|
||
buf := &bytes.Buffer{} | ||
aghtest.ReplaceLogLevel(t, log.DEBUG) | ||
aghtest.ReplaceLogWriter(t, buf) | ||
|
||
const ( | ||
setRefreshDur = 1 * time.Second | ||
butWaitFor = 3 * time.Second | ||
|
||
wantEntrances = 2 | ||
) | ||
|
||
_ = createTestSystemResolvers(t, setRefreshDur) | ||
time.Sleep(butWaitFor) | ||
|
||
ents := strings.Count(buf.String(), "local addresses cache is refreshed") | ||
assert.GreaterOrEqual(t, ents, wantEntrances) | ||
} | ||
|
||
func TestSystemResolvers_DialFunc(t *testing.T) { | ||
imp := createTestSystemResolversImp(t, zeroRefreshDur) | ||
|
||
testCases := []struct { | ||
name, address string | ||
want error | ||
}{{ | ||
name: "valid", | ||
address: "127.0.0.1", | ||
want: fakeDialErr, | ||
}, { | ||
name: "invalid", | ||
address: "127.0.0.1::123", | ||
want: badAddrPassedErr, | ||
}, { | ||
name: "not_an_address", | ||
address: "not-ip", | ||
want: badAddrPassedErr, | ||
}} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
conn, err := imp.dialFunc(context.Background(), "", tc.address) | ||
|
||
require.Nil(t, conn) | ||
assert.Equal(t, tc.want, err) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.